#include #include #include #include #include "TChain.h" #include "TFile.h" #include "TTree.h" #include "TString.h" #include "TObjString.h" #include "TSystem.h" #include "TROOT.h" #include "TMVA/CrossValidation.h" #include "TMVA/DataLoader.h" #include "TMVA/Factory.h" #include "TMVA/Tools.h" #include "TMVA/TMVAGui.h" TTree *genTree(TTree *track, Int_t nPoints) { TTree* tr = track->CloneTree(); Int_t eventID=0; TBranch *_eventID=tr->Branch("eventID", &eventID, "eventID/I"); for(Int_t n=0; nFill(); } track->ResetBranchAddresses(); return tr; } int TMVACrossValidation(bool useRandomSplitting = false) { TMVA::Tools::Instance(); const char* input1="/besfs5/groups/tauqcd/souvik/mc/signal_mc/ppbargamma_training/4230/ana/root/"; const char* input2="/besfs5/groups/tauqcd/souvik/mc/signal_mc/pppi0/4230/bdt_training/isrgamma/ana/inc_mc/root/"; const char* input3="/besfs5/users/souvik/analysis/tagged_ppbargamma/bdtg/cross_validation_check/new_training/root/"; const char* input4="/besfs5/groups/tauqcd/souvik/mc/signal_mc/etappbar/bdt_training/4230/ana/inc_mc/root/"; const char* input5="/besfs5/groups/tauqcd/souvik/mc/signal_mc/pipi/bdt_training/4230/ana/inc_mc/root/"; const char* input6="/besfs5/groups/tauqcd/souvik/mc/signal_mc/dimuon/bdt_training/"; const char* input7="/besfs5/groups/tauqcd/souvik/mc/signal_mc/Bhaba/bdt_training/"; const char* input8="/besfs5/groups/tauqcd/souvik/mc/signal_mc/BesTwoGam/bdt_training/"; const char* sig="*.root"; const char* bkg="*.root"; TChain *ch1 = new TChain("track"); ch1->Add(Form("%s/%s", input1, sig)); TTree* tr1=ch1; Int_t nevents_sig=ch1->GetEntries(); TChain *ch2 = new TChain("track"); ch2->Add(Form("%s/%s", input2, bkg)); ch2->Add(Form("%s/%s", input3, bkg)); ch2->Add(Form("%s/%s", input4, bkg)); ch2->Add(Form("%s/%s", input5, bkg)); ch2->Add(Form("%s/%s", input6, bkg)); ch2->Add(Form("%s/%s", input7, bkg)); ch2->Add(Form("%s/%s", input8, bkg)); TTree* tr2=ch2; Int_t nevents_bkg=ch2->GetEntries(); TTree *sigTree = genTree(tr1, nevents_sig); TTree *bkgTree = genTree(tr2, nevents_bkg); TString outfileName("TMVA.root"); TFile *outputFile = TFile::Open(outfileName, "RECREATE"); TMVA::DataLoader *dataloader = new TMVA::DataLoader("dataset"); dataloader->AddVariable( "shwrI", 'F' ); dataloader->AddVariable( "pCosInppc", 'F' ); dataloader->AddVariable( "ppbarLabAng", 'F' ); dataloader->AddVariable( "ppgFitchi2", 'F' ); dataloader->AddVariable( "ppmFitchi2", 'F' ); dataloader->AddSpectator( "eventID", "Spectator5", "", 'I'); dataloader->AddSpectator( "ppbarInvMass", "Spectator 3", "units", 'F' ); Double_t signalWeight = 1.0; Double_t backgroundWeight = 1.0; dataloader->AddSignalTree(sigTree, signalWeight); dataloader->AddBackgroundTree(bkgTree, backgroundWeight); TCut mycuts = "ppgFitchi2<100 && ppbarInvMass>1.8 && ppbarInvMass<4.5"; TCut mycutb = "ppgFitchi2<100 && ppbarInvMass>1.8 && ppbarInvMass<4.5"; dataloader->PrepareTrainingAndTestTree(mycuts, mycutb, "nTest_Signal=1" ":nTest_Background=1" ":SplitMode=Random" ":NormMode=NumEvents" ":!V"); UInt_t numFolds = 2; TString analysisType = "Classification"; TString splitType = (useRandomSplitting) ? "Random" : "Deterministic"; TString splitExpr = (!useRandomSplitting) ? "int(fabs([eventID]))%int([NumFolds])" : ""; TString cvOptions = Form("!V" ":!Silent" ":ModelPersistence" ":AnalysisType=%s" ":SplitType=%s" ":NumFolds=%i" ":SplitExpr=%s", analysisType.Data(), splitType.Data(), numFolds, splitExpr.Data()); TMVA::CrossValidation cv{"TMVACrossValidation", dataloader, outputFile, cvOptions}; cv.BookMethod(TMVA::Types::kBDT, "BDTG", "!H:!V:NTrees=1000:MinNodeSize=2.5%:BoostType=Grad" ":NegWeightTreatment=Pray:Shrinkage=0.10:nCuts=50" ":MaxDepth=3"); cv.Evaluate(); size_t iMethod = 0; for (auto && result : cv.GetResults()) { std::cout << "Summary for method " << cv.GetMethods()[iMethod++].GetValue("MethodName") << std::endl; for (UInt_t iFold = 0; iFoldClose(); std::cout << "==> Wrote root file: " << outputFile->GetName() << std::endl; std::cout << "==> TMVACrossValidation is done!" << std::endl; if (!gROOT->IsBatch()) { cv.GetResults()[0].DrawAvgROCCurve(kTRUE, "Avg ROC for BDTG"); //cv.GetResults()[0].DrawAvgROCCurve(kTRUE, "Avg ROC for BDT"); TMVA::TMVAGui(outfileName); } return 0; } int main(int argc, char **argv) { TMVACrossValidation(); }