/*
 * iterativeBayesTestThermal.C
 *
 * Created on: Apr 16, 2021
 * Author: John Russell
 *
 * This script tests how quickly iteratively updating the priors of a naive Bayes classifier
 * converges to removing all above-horizontal thermal background samples.
 */

#include "AnitaDataset.h"
#include "AnitaTMVA.h"

#include "TChain.h"
#include "TH1D.h"
#include "TMath.h"
#include "TString.h"

#include <iostream>

/*
 * Naive Bayes decision used inside the TTree selection strings built below.
 *
 * Each lh* argument is the classifier response for one class and each prior* argument scales it.
 * Returns true when the largest prior-weighted signal response (WAIS Hpol, WAIS Vpol, HiCal2A,
 * HiCal2B, iceMC) strictly exceeds the prior-weighted above-horizontal thermal response.
 * Keep this a plain global function so the selection strings can call it by name.
 */
bool isMax(double lhWAISHPol, double lhWAISVPol, double lhHiCal2A, double lhHiCal2B, double lhIceMC, double lhAboveHorizontalThermal,
           double priorWAISHPol = 1, double priorWAISVPol = 1, double priorHiCal2A = 1, double priorHiCal2B = 1, double priorIceMC = 1,
           double priorAboveHorizontalThermal = 1) {

  const double sigArr[5] = {lhWAISHPol * priorWAISHPol, lhWAISVPol * priorWAISVPol, lhHiCal2A * priorHiCal2A,
                            lhHiCal2B * priorHiCal2B, lhIceMC * priorIceMC};
  const double maxSig = TMath::MaxElement(5, sigArr);
  const double maxBkg = priorAboveHorizontalThermal * lhAboveHorizontalThermal;

  return maxSig > maxBkg;
}

// Sizes of every sample after a selection. Event samples are entry counts; iceMC is a sum of mc.weight.
struct ThermalTestSampleSizes {
  Long64_t WAISHPol = 0, WAISVPol = 0, HiCal2A = 0, HiCal2B = 0;
  double iceMCWeight = 0;
  Long64_t RCP = 0, LCP = 0, payloadBlast = 0, sun = 0, aboveHorizontalThermal = 0, surviving = 0;
};

// Prior weight per class. The defaults of 1 match the default priors of isMax.
struct ThermalTestPriors {
  double WAISHPol = 1, WAISVPol = 1, HiCal2A = 1, HiCal2B = 1, iceMC = 1;
  double RCP = 1, LCP = 1, payloadBlast = 1, sun = 1, aboveHorizontalThermal = 1;
};

// Fill a response chain and its event-summary chain from the same subdirectory, then friend them
// so that summary branches (e.g. mc.weight) are reachable from the response chain.
static void linkThermalTestSample(TChain & response, TChain & summary, const char * subdir) {

  response.Add(TString::Format("~/jwrussWork/responseValues/%s/*", subdir).Data());
  summary.Add(TString::Format("~/jwrussWork/eventSummaries/%s/*", subdir).Data());
  response.AddFriend(& summary);
}

// Number of entries of `chain` passing `cond`; an empty `cond` selects every entry.
static Long64_t countPassing(TChain & chain, const TString & cond) {

  return cond.IsNull() ? chain.GetEntries() : chain.GetEntries(cond.Data());
}

// Sum of mc.weight over the entries of `chain` passing `cond` (empty `cond` selects every entry).
// `hist` is a scratch histogram that gets overwritten.
static double sumIceMCWeight(TChain & chain, TH1D & hist, const TString & cond) {

  const TString varexp = TString::Format("Entry$ >> %s", hist.GetName());
  const TString weight = cond.IsNull() ? TString("mc.weight") : "mc.weight * (" + cond + ")";
  chain.Draw(varexp.Data(), weight.Data(), "goff");

  // Entry$ only carries the weights, and most entry numbers fall outside the axis range,
  // so include the underflow and overflow bins; otherwise those weights are silently dropped.
  return hist.Integral(0, hist.GetNbinsX() + 1);
}

// Fraction of a sample still kept; 1 (the default prior of isMax) when the initial sample is empty,
// so that no NaN ends up in a selection string.
static double keptFraction(double remaining, double initial) {

  return initial > 0 ? remaining / initial : 1;
}

// Fraction of a sample already removed; 1 (the default prior of isMax) when the initial sample is empty.
static double removedFraction(double remaining, double initial) {

  return initial > 0 ? (initial - remaining) / initial : 1;
}

// Priors for the next cut: signal classes use the fraction kept, background classes the fraction removed.
static ThermalTestPriors updateThermalTestPriors(const ThermalTestSampleSizes & now, const ThermalTestSampleSizes & initial) {

  ThermalTestPriors p;

  p.WAISHPol = keptFraction(now.WAISHPol, initial.WAISHPol);
  p.WAISVPol = keptFraction(now.WAISVPol, initial.WAISVPol);
  p.HiCal2A = keptFraction(now.HiCal2A, initial.HiCal2A);
  p.HiCal2B = keptFraction(now.HiCal2B, initial.HiCal2B);
  p.iceMC = keptFraction(now.iceMCWeight, initial.iceMCWeight);

  p.RCP = removedFraction(now.RCP, initial.RCP);
  p.LCP = removedFraction(now.LCP, initial.LCP);
  p.payloadBlast = removedFraction(now.payloadBlast, initial.payloadBlast);
  p.sun = removedFraction(now.sun, initial.sun);
  p.aboveHorizontalThermal = removedFraction(now.aboveHorizontalThermal, initial.aboveHorizontalThermal);

  return p;
}

// One isMax(...) selection term reading the responses stored in `objName`, weighted by the priors in `p`.
static TString isMaxSelection(const char * objName, const ThermalTestPriors & p) {

  // All six object names come first, then the six priors, matching the format string order.
  return TString::Format("isMax(%s.WAISHPol, %s.WAISVPol, %s.HiCal2A, %s.HiCal2B, %s.iceMC, %s.aboveHorizontalThermal, %g, %g, %g, %g, %g, %g)",
                         objName, objName, objName, objName, objName, objName,
                         p.WAISHPol, p.WAISVPol, p.HiCal2A, p.HiCal2B, p.iceMC, p.aboveHorizontalThermal);
}

// Print one line of per-sample sizes, beginning with `lead`, followed by a blank line.
static void printThermalTestSizes(const char * lead, const ThermalTestSampleSizes & n) {

  std::cout << lead << n.WAISHPol;
  std::cout << ", WAIS Vpol: " << n.WAISVPol;
  std::cout << ", HiCal2A: " << n.HiCal2A;
  std::cout << ", HiCal2B: " << n.HiCal2B;
  std::cout << ", weight of iceMC: " << n.iceMCWeight;
  std::cout << ", number of RCP: " << n.RCP;
  std::cout << ", LCP: " << n.LCP;
  std::cout << ", payload blast: " << n.payloadBlast;
  std::cout << ", sun: " << n.sun;
  std::cout << ", above horizontal thermal: " << n.aboveHorizontalThermal;
  std::cout << ", surviving sample : " << n.surviving << std::endl;
  std::cout << std::endl;
}

// Print one line with every prior, followed by a blank line.
static void printThermalTestPriors(const ThermalTestPriors & p) {

  std::cout << "Posterior probability for WAIS Hpol: " << p.WAISHPol;
  std::cout << ", WAIS Vpol: " << p.WAISVPol;
  std::cout << ", HiCal2A: " << p.HiCal2A;
  std::cout << ", HiCal2B: " << p.HiCal2B;
  std::cout << ", iceMC: " << p.iceMC;
  std::cout << ", RCP: " << p.RCP;
  std::cout << ", LCP: " << p.LCP;
  std::cout << ", payload blast: " << p.payloadBlast;
  std::cout << ", sun: " << p.sun;
  std::cout << ", above horizontal thermal: " << p.aboveHorizontalThermal << std::endl;
  std::cout << std::endl;
}

/*
 * Iteratively re-weight the naive Bayes decision and report how each sample shrinks.
 *
 * responseType: prefix of the response object read from each "responseTree"; the object used
 *               in the selections is "<responseType>CDF" (e.g. "StokesHybridPlusCDF").
 * numIter:      number of iterations to run and print.
 *
 * Iteration 0 prints the initial sample sizes. Iteration i applies i cumulative isMax cuts:
 * the first with every prior equal to 1, each later one with the priors derived from the
 * sizes surviving the previous cuts. Signal priors are the fraction of each signal sample
 * still kept; background priors are the fraction of each background sample already removed.
 */
void iterativeBayesTestThermal(const char * responseType = "StokesHybridPlus", int numIter = 5) {

  // Summary chains are needed for the iceMC weights. Their addresses are handed to AddFriend,
  // so they are declared first in order to outlive the response chains that refer to them.
  TChain WAISHPolSumChain("sampleA4"), WAISVPolSumChain("sampleA4"), HiCal2ASumChain("sampleA4"), HiCal2BSumChain("sampleA4"), iceMCSumChain("sampleA4");
  TChain RCPSumChain("sampleA4"), LCPSumChain("sampleA4"), payloadBlastSumChain("sampleA4"), sunSumChain("sampleA4"), aboveHorizontalThermalSumChain("sampleA4");
  TChain survivingSumChain("sampleA4");

  // Chains of classifications.
  TChain WAISHPolChain("responseTree"), WAISVPolChain("responseTree"), HiCal2AChain("responseTree"), HiCal2BChain("responseTree"), iceMCChain("responseTree");
  TChain RCPChain("responseTree"), LCPChain("responseTree"), payloadBlastChain("responseTree"), sunChain("responseTree"), aboveHorizontalThermalChain("responseTree");
  TChain survivingChain("responseTree");

  linkThermalTestSample(WAISHPolChain, WAISHPolSumChain, "WAISHPol/purified");
  linkThermalTestSample(WAISVPolChain, WAISVPolSumChain, "WAISVPol/purified");
  linkThermalTestSample(HiCal2AChain, HiCal2ASumChain, "HiCal2A/purified");
  linkThermalTestSample(HiCal2BChain, HiCal2BSumChain, "HiCal2B/purified");
  linkThermalTestSample(iceMCChain, iceMCSumChain, "iceMC/purified");
  linkThermalTestSample(RCPChain, RCPSumChain, "other/RCP");
  linkThermalTestSample(LCPChain, LCPSumChain, "other/LCP");
  linkThermalTestSample(payloadBlastChain, payloadBlastSumChain, "other/payloadBlast");
  linkThermalTestSample(sunChain, sunSumChain, "other/sun");
  linkThermalTestSample(aboveHorizontalThermalChain, aboveHorizontalThermalSumChain, "aboveHorizontal/thermal");
  linkThermalTestSample(survivingChain, survivingSumChain, "belowHorizontal/surviving");

  // Name of the response object the selections read from.
  const TString objNameStr = TString::Format("%sCDF", responseType);
  const char * objName = objNameStr.Data();

  // Scratch histogram for summing iceMC weights; only its total (including overflow) is used.
  TH1D htempIceMC("htempIceMC", "htempIceMC", 1, 0, 1);

  // Size of every sample passing `cond` (empty `cond` selects everything).
  auto measure = [&](const TString & cond) {
    ThermalTestSampleSizes n;
    n.WAISHPol = countPassing(WAISHPolChain, cond);
    n.WAISVPol = countPassing(WAISVPolChain, cond);
    n.HiCal2A = countPassing(HiCal2AChain, cond);
    n.HiCal2B = countPassing(HiCal2BChain, cond);
    n.iceMCWeight = sumIceMCWeight(iceMCChain, htempIceMC, cond);
    n.RCP = countPassing(RCPChain, cond);
    n.LCP = countPassing(LCPChain, cond);
    n.payloadBlast = countPassing(payloadBlastChain, cond);
    n.sun = countPassing(sunChain, cond);
    n.aboveHorizontalThermal = countPassing(aboveHorizontalThermalChain, cond);
    n.surviving = countPassing(survivingChain, cond);
    return n;
  };

  // Determine initial sample sizes.
  const ThermalTestSampleSizes initial = measure(TString());

  std::cout << std::endl;
  std::cout << "Iteration: 0" << std::endl;
  std::cout << std::endl;
  printThermalTestSizes("Initial number of WAIS Hpol: ", initial);

  // Cumulative selection: each iteration appends one isMax term weighted by the current priors.
  // The first term uses the default priors of 1, i.e. the unweighted naive Bayes decision.
  TString cond;
  ThermalTestPriors priors;

  for (int i = 1; i <= numIter; ++i) {

    if (!cond.IsNull()) cond += " && ";
    cond += isMaxSelection(objName, priors);

    const ThermalTestSampleSizes remaining = measure(cond);
    priors = updateThermalTestPriors(remaining, initial);

    std::cout << "Iteration: " << i << std::endl;
    std::cout << std::endl;
    printThermalTestSizes("Remaining number of WAIS Hpol: ", remaining);
    printThermalTestPriors(priors);
  }
}