#include #include #include "TFile.h" #include "TTree.h" #include "TRandom3.h" #include "TMVA/Factory.h" #include "TMVA/DataLoader.h" #include "TMVA/Tools.h" using namespace std; int main() { //-------------------------------------------- // 1 初始化 TMVA //-------------------------------------------- TMVA::Tools::Instance(); TFile* outputFile = TFile::Open("tmva_output.root","RECREATE"); TMVA::Factory *factory = new TMVA::Factory( "NeutronSpectrum", outputFile, "!V:Color:Transformations=N:AnalysisType=Regression" ); TMVA::DataLoader *loader = new TMVA::DataLoader("dataset"); //-------------------------------------------- // 2 定义输入变量 (10个探测器) //-------------------------------------------- for(int i=0;i<10;i++) { loader->AddVariable(Form("det%d",i),'F'); } //-------------------------------------------- // 3 定义输出变量 (示例5个) //-------------------------------------------- for(int i=0;i<5;i++) { loader->AddTarget(Form("spec%d",i)); } //-------------------------------------------- // 4 创建训练数据 //-------------------------------------------- TTree *tree = new TTree("data","training data"); float det[10]; float spec[5]; for(int i=0;i<10;i++) tree->Branch(Form("det%d",i),&det[i]); for(int i=0;i<5;i++) tree->Branch(Form("spec%d",i),&spec[i]); TRandom3 rnd(0); for(int n=0;n<10000;n++) { for(int i=0;i<10;i++) det[i]=rnd.Uniform(0,100); // 模拟一个简单关系 for(int i=0;i<5;i++) spec[i]=(det[0]+det[1]+det[2])*0.1 + rnd.Gaus(0,1); tree->Fill(); } //-------------------------------------------- // 5 加载数据 //-------------------------------------------- loader->AddRegressionTree(tree,1.0); loader->PrepareTrainingAndTestTree( "", "nTrain_Regression=8000:" "nTest_Regression=2000:" "SplitMode=Random:" "NormMode=NumEvents:" "!V" ); //-------------------------------------------- // 6 定义 DNN 网络结构 //-------------------------------------------- TString layout = // "Layout=RELU|128,RELU|64,LINEAR"; "Layout=DENSE|128|RELU,DENSE|64|RELU,DENSE|5|LINEAR:"; TString training = "TrainingStrategy=" "LearningRate=1e-3," "Momentum=0.9," "Repetitions=1," "ConvergenceSteps=20," "BatchSize=64," "TestRepetitions=1," "WeightDecay=1e-4," "Regularization=None," "DropConfig=0.0+0.0+0.0"; TString dnnOptions = "!H:" "V:" "ErrorStrategy=SUMOFSQUARES:" "VarTransform=N:" + layout + ":" + training + ":Architecture=CPU"; factory->BookMethod( loader, TMVA::Types::kDNN, "DNN", dnnOptions ); //-------------------------------------------- // 7 训练 //-------------------------------------------- factory->TrainAllMethods(); //-------------------------------------------- // 8 测试 //-------------------------------------------- factory->TestAllMethods(); //-------------------------------------------- // 9 评估 //-------------------------------------------- factory->EvaluateAllMethods(); //-------------------------------------------- outputFile->Close(); cout<<"Training finished"<