#ifndef TMVACLASSIFICATION_C #define TMVACLASSIFICATION_C #include #include #include #include #include #include "root/TFile.h" #include "root/TTree.h" #include "root/TMVA/MethodBase.h" #include "root/TMVA/MethodCuts.h" #include "root/TMVA/Factory.h" #include "root/TMVA/DataLoader.h" #include "root/TMVA/Types.h" #include "root/TMVA/TMVAGui.h" void TMVAClassification() { TFile* outFile = new TFile("../Data/TMVA/Classification.root", "RECREATE"); TFile* sigSrc = new TFile("../Data/tr_ph_selected.root"); TFile* bkgSrc = new TFile("../Data/tr_ph_selected_KsKl.root"); //Get trees from files TTree* sigTree = (TTree*)sigSrc->Get("selected"); TTree* bkgTree = (TTree*)bkgSrc->Get("selected"); //Weights of trees Double_t sigWeight = 1.0; Double_t bkgWeight = 1.0; //Create Factory object --- the only TMVA object to contact with TMVA::Factory* factory = new TMVA::Factory( "EtotPtotClf", outFile, "AnalysisType=Classification:Transformations=N,G,D" ); //Create DataLoader TMVA::DataLoader* dataloader = new TMVA::DataLoader( "dataset" ); //Add trees to the analysis dataloader->AddSignalTree( sigTree, sigWeight ); dataloader->AddBackgroundTree( bkgTree, bkgWeight ); //Add varibles of interest dataloader->AddVariable( "eTotal", "Total Energy of Photons", "MeV", 'F' ); dataloader->AddVariable( "pTotal", "Total Momentum of Photons", "MeV/c", 'F' ); //Preparing trees for processing dataloader->PrepareTrainingAndTestTree("", "SplitMode=Random"); //Book method!!! //Let's use k nearest neighbors and Rectangular cuts! What's up? =) factory->BookMethod( dataloader, TMVA::Types::kKNN, "kNN", "nkNN=40"); TMVA::MethodBase* methodCutsBase = factory->BookMethod( dataloader, TMVA::Types::kCuts, "Cuts", "FitMethod=GA:EffMethod=EffSel" ); TMVA::MethodCuts* methodCuts = dynamic_cast(methodCutsBase); //Train methods factory->TrainAllMethods(); //Get cuts from Cuts method std::vector minCuts, maxCuts; Double_t trueEffS = methodCuts->GetCuts( 0.5, minCuts, maxCuts ); //Test methods factory->TestAllMethods(); //Evaluate methods factory->EvaluateAllMethods(); //Clean out delete factory; delete dataloader; //Close file outFile->Close(); //Launch gui if ( !gROOT->IsBatch() ) TMVA::TMVAGui( "../Data/TMVA/Classification.root" ); //for ( auto cut: minCuts ) //{ // std::cout << cut << "\n"; //} } #endif