#include <cstdlib>
#include <fstream>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>

#include "TChain.h"
#include "TFile.h"
#include "TObjString.h"
#include "TROOT.h"
#include "TString.h"
#include "TSystem.h"
#include "TTree.h"

#include "TMVA/DataLoader.h"
#include "TMVA/Factory.h"
#include "TMVA/TMVAGui.h"
#include "TMVA/Tools.h"
using namespace std;
//mylinux~> root -l TMVAtest.C\(\"myMethod1,myMethod2,myMethod3\"\)
// Parse one whitespace-separated line per event from `in`.
// Column 0 is skipped (assumed row label/id); columns 1, 2 and 3 are appended
// to `length`, `width` and `test` respectively (same layout as the original
// signal/background readers). Tokens are converted with atof(), so a
// non-numeric token silently becomes 0.0 — same behavior as before.
static void ReadIrisColumns(std::istream& in,
                            std::vector<double>& length,
                            std::vector<double>& width,
                            std::vector<double>& test)
{
   std::string line;
   while (std::getline(in, line)) {
      std::istringstream tokens(line);
      std::string tok;
      int column = 0;
      while (tokens >> tok) {
         switch (column) {
            case 1: length.push_back(std::atof(tok.c_str())); break;
            case 2: width.push_back(std::atof(tok.c_str()));  break;
            case 3: test.push_back(std::atof(tok.c_str()));   break;
            default: break; // column 0 and any trailing columns are ignored
         }
         ++column;
      }
   }
}

/// Train and evaluate a set of TMVA classifiers on data read from
/// ./test_00.txt (signal) and ./test_01.txt (background).
///
/// @param myMethodList  Comma-separated list of method names to enable
///                      (empty string = use the built-in defaults below).
/// @return 0 on success, 1 on error (unknown method name or missing input file).
int TMVAtest(TString myMethodList = "")
{
   // This loads the TMVA library.
   TMVA::Tools::Instance();

   // Default set of methods: value 1 = enabled. A non-empty myMethodList
   // overrides these defaults completely.
   std::map<std::string, int> Use;
   Use["KNN"] = 1;           // k-nearest neighbour method
   // Fisher discriminants
   Use["Fisher"] = 0;
   Use["FisherG"] = 0;
   Use["BoostedFisher"] = 0; // uses generalised MVA method boosting
   // Artificial neural networks
   Use["MLP"] = 0;           // Recommended ANN
   Use["MLPBFGS"] = 0;       // Recommended ANN with optional training method
   Use["MLPBNN"] = 1;        // Recommended ANN with BFGS training method and bayesian regulator
   Use["CFMlpANN"] = 0;      // Deprecated ANN from ALEPH
   Use["TMlpANN"] = 0;       // ROOT's own ANN
   // Boosted Decision Trees
   Use["BDT"] = 1;           // uses Adaptive Boost
   Use["BDTG"] = 0;          // uses Gradient Boost
   Use["BDTB"] = 0;          // uses Bagging
   Use["BDTD"] = 0;          // decorrelation + Adaptive Boost
   Use["BDTF"] = 0;          // allow usage of fisher discriminant for node splitting
   //**********HAVE NO RANDOM FOREST**************

   std::cout << std::endl;
   std::cout << "==> Start TMVAtest" << std::endl;

   // Select methods requested on the command line / macro argument.
   if (myMethodList != "") {
      for (std::map<std::string, int>::iterator it = Use.begin(); it != Use.end(); ++it)
         it->second = 0;
      std::vector<TString> mlist = TMVA::gTools().SplitString(myMethodList, ',');
      for (UInt_t i = 0; i < mlist.size(); i++) {
         std::string regMethod(mlist[i]);
         if (Use.find(regMethod) == Use.end()) {
            std::cout << "Method \"" << regMethod << "\" not known in TMVA under this name. Choose among the following:" << std::endl;
            for (std::map<std::string, int>::iterator it = Use.begin(); it != Use.end(); ++it)
               std::cout << it->first << " ";
            std::cout << std::endl;
            return 1;
         }
         Use[regMethod] = 1;
      }
   }

   // Open the input text files. Return (instead of exit()) so an interactive
   // ROOT session survives a missing file — consistent with the error path above.
   std::ifstream inputfile_s("./test_00.txt");
   if (!inputfile_s) {
      std::cout << "ERROR: could not open s_data file" << std::endl;
      return 1;
   }
   std::ifstream inputfile_b("./test_01.txt");
   if (!inputfile_b) {
      std::cout << "ERROR: could not open p_data file" << std::endl;
      return 1;
   }

   // Create the tree model shared by signal and background.
   // NOTE: the leaflist must carry the "/D" type suffix — without it ROOT
   // assumes Float_t for the first leaf, which silently corrupts the
   // Double_t variables attached here.
   TTree* TreeS = new TTree("TreeS", "Tree for Signal and Background");
   Double_t var1, var2, var3;
   Double_t w; // per-event weight (constant 1.0 here)
   TreeS->Branch("var1", &var1, "var1/D");
   TreeS->Branch("var2", &var2, "var2/D");
   TreeS->Branch("var3", &var3, "var3/D");
   TreeS->Branch("w", &w, "w/D");
   TTree* TreeB = TreeS->CloneTree(0); // same structure, no entries

   // Read the feature columns of both samples.
   std::vector<double> length_s, width_s, test_s;
   std::vector<double> length_b, width_b, test_b;
   ReadIrisColumns(inputfile_s, length_s, width_s, test_s);
   ReadIrisColumns(inputfile_b, length_b, width_b, test_b);

   // Fill the signal tree.
   for (std::size_t i = 0; i < length_s.size(); ++i) {
      w = 1.0;
      var1 = length_s.at(i);
      var2 = width_s.at(i);
      var3 = test_s.at(i);
      TreeS->Fill();
   }
   // Fill the background tree.
   // BUG FIX: the original filled var1/var2 from the *signal* vectors
   // (length_s/width_s), so the background tree contained signal data.
   for (std::size_t i = 0; i < length_b.size(); ++i) {
      w = 1.0;
      var1 = length_b.at(i);
      var2 = width_b.at(i);
      var3 = test_b.at(i);
      TreeB->Fill();
   }

   // Output file for ntuples, histograms, weights, etc.
   TString outfileName("TMVAoutput.root");
   TFile* outputFile = TFile::Open(outfileName, "RECREATE");

   TMVA::Factory* factory = new TMVA::Factory("TMVAtest", outputFile,
      "!V:!Silent:Color:DrawProgressBar:Transformations=I;D;P;G,D:AnalysisType=Classification");
   TMVA::DataLoader* dataloader = new TMVA::DataLoader("dataset");
   dataloader->AddVariable("var1", "Variable 1", "units", 'F');
   dataloader->AddVariable("var2", "Variable 2", "units", 'F');
   dataloader->AddVariable("var3", "Variable 3", "units", 'F');
   // there can add "spectator variables"
   dataloader->AddSignalTree(TreeS);
   dataloader->AddBackgroundTree(TreeB);
   dataloader->SetSignalWeightExpression("w");
   dataloader->SetBackgroundWeightExpression("w");

   // Apply additional cuts on the signal and background samples (can be different)
   TCut mycuts = ""; // for example: TCut mycuts = "abs(var1)<0.5 && abs(var2-0.5)<1";
   TCut mycutb = ""; // for example: TCut mycutb = "abs(var1)<0.5";
   dataloader->PrepareTrainingAndTestTree(mycuts, mycutb,
      "nTrain_Signal=25:nTrain_Background=25:SplitMode=Random:NormMode=NumEvents:!V");

   // ---- Book MVA methods -------------------------------------------------
   // KNN
   if (Use["KNN"])
      factory->BookMethod(dataloader, TMVA::Types::kKNN, "KNN",
         "H:nkNN=20:ScaleFrac=0.8:SigmaFact=1.0:Kernel=Gaus:UseKernel=F:UseWeight=T:!Trim");
   // Fisher
   if (Use["Fisher"]) // Fisher discriminant (same as LD)
      factory->BookMethod(dataloader, TMVA::Types::kFisher, "Fisher",
         "H:!V:Fisher:VarTransform=None:CreateMVAPdfs:PDFInterpolMVAPdf=Spline2:NbinsMVAPdf=50:NsmoothMVAPdf=10");
   if (Use["FisherG"]) // Fisher with Gauss-transformed input variables
      factory->BookMethod(dataloader, TMVA::Types::kFisher, "FisherG", "H:!V:VarTransform=Gauss");
   if (Use["BoostedFisher"]) // Composite classifier: ensemble (tree) of boosted Fisher classifiers
      factory->BookMethod(dataloader, TMVA::Types::kFisher, "BoostedFisher",
         "H:!V:Boost_Num=20:Boost_Transform=log:Boost_Type=AdaBoost:Boost_AdaBoostBeta=0.2:!Boost_DetailedMonitoring");
   // Neural networks.
   // BUG FIX: Use["MLPBNN"] was enabled by default but no ANN method was ever
   // booked, so the flag silently did nothing. Option strings follow the
   // standard TMVAClassification tutorial.
   if (Use["MLP"])
      factory->BookMethod(dataloader, TMVA::Types::kMLP, "MLP",
         "H:!V:NeuronType=tanh:VarTransform=N:NCycles=600:HiddenLayers=N+5:TestRate=5:!UseRegulator");
   if (Use["MLPBFGS"])
      factory->BookMethod(dataloader, TMVA::Types::kMLP, "MLPBFGS",
         "H:!V:NeuronType=tanh:VarTransform=N:NCycles=600:HiddenLayers=N+5:TestRate=5:TrainingMethod=BFGS:!UseRegulator");
   if (Use["MLPBNN"])
      factory->BookMethod(dataloader, TMVA::Types::kMLP, "MLPBNN",
         "H:!V:NeuronType=tanh:VarTransform=N:NCycles=60:HiddenLayers=N+5:TestRate=5:TrainingMethod=BFGS:UseRegulator");
   if (Use["CFMlpANN"])
      factory->BookMethod(dataloader, TMVA::Types::kCFMlpANN, "CFMlpANN",
         "!H:!V:NCycles=2000:HiddenLayers=N+1,N");
   if (Use["TMlpANN"])
      factory->BookMethod(dataloader, TMVA::Types::kTMlpANN, "TMlpANN",
         "!H:!V:NCycles=200:HiddenLayers=N+1,N:LearningMethod=BFGS:ValidationFraction=0.3");
   // Boosted Decision Trees
   if (Use["BDTG"]) // Gradient Boost
      factory->BookMethod(dataloader, TMVA::Types::kBDT, "BDTG",
         "!H:!V:NTrees=1000:MinNodeSize=2.5%:BoostType=Grad:Shrinkage=0.10:UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=2");
   if (Use["BDT"]) // Adaptive Boost
      factory->BookMethod(dataloader, TMVA::Types::kBDT, "BDT",
         "!H:!V:NTrees=850:MinNodeSize=2.5%:MaxDepth=3:BoostType=AdaBoost:AdaBoostBeta=0.5:UseBaggedBoost:BaggedSampleFraction=0.5:SeparationType=GiniIndex:nCuts=20");
   if (Use["BDTB"]) // Bagging
      factory->BookMethod(dataloader, TMVA::Types::kBDT, "BDTB",
         "!H:!V:NTrees=400:BoostType=Bagging:SeparationType=GiniIndex:nCuts=20");
   if (Use["BDTD"]) // Decorrelation + Adaptive Boost
      factory->BookMethod(dataloader, TMVA::Types::kBDT, "BDTD",
         "!H:!V:NTrees=400:MinNodeSize=5%:MaxDepth=3:BoostType=AdaBoost:SeparationType=GiniIndex:nCuts=20:VarTransform=Decorrelate");
   if (Use["BDTF"]) // Allow using Fisher discriminant in node splitting for (strong) linearly correlated variables
      factory->BookMethod(dataloader, TMVA::Types::kBDT, "BDTF",
         "!H:!V:NTrees=50:MinNodeSize=2.5%:UseFisherCuts:MaxDepth=3:BoostType=AdaBoost:AdaBoostBeta=0.5:SeparationType=GiniIndex:nCuts=20");

   // Train MVAs using the set of training events
   factory->TrainAllMethods();
   // Evaluate all MVAs using the set of test events
   factory->TestAllMethods();
   // Evaluate and compare performance of all configured MVAs
   factory->EvaluateAllMethods();

   // Save the output
   outputFile->Close();
   std::cout << "==> Wrote root file: " << outputFile->GetName() << std::endl;
   std::cout << "==> TMVAtest is done!" << std::endl;
   delete factory;
   delete dataloader;

   // Launch the GUI for the root macros
   if (!gROOT->IsBatch()) TMVA::TMVAGui(outfileName);
   return 0;
}
int main(int argc, char** argv)
{
// Select methods (don't look at this code - not of interest)
TString methodList;
for (int i = 1; i < argc; i++) {
TString regMethod(argv[i]);
if (regMethod == "-b" || regMethod == "--batch") continue;
if (!methodList.IsNull()) methodList += TString(",");
methodList += regMethod;
}
return TMVAtest(methodList);
}
Hi, @jalopezg ;
This is my source code, and I used the Iris flower data set for the data file (http://en.wikipedia.org/wiki/Iris_flower_data_set). I added the functionality of reading data from TXT files on top of TMVAClassification, and selected some methods I could use.
Thank you for your help!
Cheers,
Kevin