import ROOT
import argparse
import math
import numpy as np

# Directory holding the per-mass / per-charge FCP signal ntuples.
_SAMPLE_DIR = "/user/vannerom/FCP/BatchArea/CMSSW_10_2_3/src/AnalyzerArea/FCPTreeMaker/"

# File-name suffix for each value of fillEfficiency's `index` argument.
_SAMPLE_SUFFIXES = {
	0: "analysisSample.root",
	1: "muonHack.root",
	2: "elossDown.root",
	3: "elossUp.root",
}


def _match_loose_muon(event, i):
	"""Return the index of the loose muon closest to track ``i`` with deltaR < 0.05, or None.

	Comparisons are strict, so on equal deltaR the earlier muon is kept.
	"""
	trkEta = event.track_eta[i]
	trkPhi = event.track_phi[i]
	best = None
	bestDR = 0.05
	for imu in range(event.muon_pt.size()):
		if not event.muon_isLoose[imu]:
			continue
		dR = getDeltaR(trkEta, trkPhi, event.muon_eta[imu], event.muon_phi[imu])
		if dR < bestDR:
			best = imu
			bestDR = dR
	return best


def _iter_track_cuts(event, i):
	"""Yield, lazily and in cut-flow order, whether track ``i`` passes each criterion.

	Order (histogram bin index in brackets): pT [0], eta [1], dxy [2], dz [3],
	nHits [4], nPixelHits [5], alphaMax [6], loose-muon match [7],
	PF isolation [8], global muon [9], timing [10], L1 trigger [11], HLT [12].
	Because the generator is lazy, muon matching only happens once the
	tracker-level cuts have passed. After a failed muon match nothing further
	is yielded.
	"""
	yield event.track_pt[i] > 55
	yield abs(event.track_eta[i]) < 1.5
	yield abs(event.track_dxy[i]) < 0.1
	yield abs(event.track_dz[i]) < 0.5
	yield event.track_nHits[i] > 5
	yield event.track_nPixelHits[i] > 1
	yield event.track_alphaMax[i] < 2.8
	imu = _match_loose_muon(event, i)
	if imu is None:
		yield False
		return
	yield True
	yield event.muon_PFIso[imu] < 0.15
	yield bool(event.muon_isGlobal[imu])
	yield event.muon_timeAtIpInOut[imu] > 0 and event.muon_time_nDof[imu] > 7
	# Trigger bits are event-level branches; they close the per-track chain.
	yield bool(event.L1_SingleMu22 or event.L1_SingleMu25)
	yield bool(event.HLT_Mu50 or event.HLT_OldMu100 or event.HLT_TkMu100)


def _track_depth(event, i):
	"""Return how many leading cut-flow criteria track ``i`` passes in a row (0 to 13)."""
	depth = 0
	for passed in _iter_track_cuts(event, i):
		if not passed:
			break
		depth += 1
	return depth


def fillEfficiency(histo_passed, histo_total, nBins, mass, charge, index):
	"""Fill a weighted, per-event selection cut flow from one FCP sample.

	Each event of the tree ``demo/tree`` in the sample file is processed as follows:
	- Every one of the ``nBins`` bins of ``histo_total`` is filled once with the
	  event weight.
	- Bin k of ``histo_passed`` is filled once with the event weight if at least
	  one single track passes criteria 0..k of the cut flow (see
	  ``_iter_track_cuts``).

	Every track is examined, so a later track can reach deeper cuts than the
	first track that passed the pT cut.

	Args:
		histo_passed: ROOT histogram accumulating the weighted passing counts.
		histo_total: ROOT histogram accumulating the weighted total counts.
		nBins: number of cut-flow bins filled in ``histo_total``.
		mass: mass string used in the sample file name (``M<mass>``).
		charge: charge string used in the sample file name (``Q<charge>``).
		index: sample selector: 0 = analysisSample, 1 = muonHack,
			2 = elossDown, 3 = elossUp.

	Raises:
		ValueError: if ``index`` is not one of the known samples.
		IOError: if the file cannot be opened or has no ``demo/tree``.
	"""
	if index not in _SAMPLE_SUFFIXES:
		raise ValueError("Wrong file index: " + str(index))
	fileName = _SAMPLE_DIR + "M" + mass + "_Q" + charge + "_" + _SAMPLE_SUFFIXES[index]
	print("Filling file " + fileName)

	f = ROOT.TFile(fileName)
	try:
		if f.IsZombie():
			raise IOError("Cannot open file " + fileName)
		myTree = f.Get("demo/tree")
		if not myTree:
			raise IOError("No tree demo/tree in " + fileName)
		print("nEntries = " + str(myTree.GetEntries()))

		# Fill the caller's histograms directly. Cloning them and Add()-ing the
		# clone back would count any content they already hold twice.
		for event in myTree:
			weight = event.weight
			for ibin in range(nBins):
				histo_total.Fill(float(ibin), weight)

			# Deepest cut-flow level reached by any single track in this event.
			depth = 0
			for itrk in range(event.track_pt.size()):
				depth = max(depth, _track_depth(event, itrk))
			for ibin in range(depth):
				histo_passed.Fill(float(ibin), weight)
	finally:
		f.Close()

# Cut-flow bin labels, in the order the bins are filled by fillEfficiency.
_SELECTION_LABELS = (
	"p_{T}", "#eta", "d_{xy}", "d_{z}", "# hits", "# pixel hits", "#alpha_{max}",
	"Loose muon", "Isolation", "Global muon", "Timing", "L1", "HLT",
)

# Legend text for each supported charge point, keyed by the charge string
# used in the sample file names.
_LEGEND_CHARGES = {
	'1': '#it{e}',
	'0p9': '0.9 #it{e}',
	'0p8': '0.8 #it{e}',
	'2Over3': '2#it{e}/3',
	'0p5': '#it{e}/2',
	'1Over3': '#it{e}/3',
}


def _label_selection_axis(ax, nBins):
	"""Give ``ax`` nBins unit-wide bins on [0, nBins] labelled with the cut names."""
	ax.Set(nBins, 0, nBins)
	# Slice so that a shorter axis never gets SetBinLabel calls past its last bin.
	for ibin, label in enumerate(_SELECTION_LABELS[:nBins]):
		ax.SetBinLabel(ibin + 1, label)


def _safe_div(num, den):
	"""Return num / den, or 0.0 when den is zero (e.g. an efficiency of exactly 0)."""
	return num / den if den else 0.0


def DrawEfficiency(eff_noHack, eff, eff_uncert, nBins, xLabel, mass, charge):
	"""Plot the cut-flow efficiencies and their ratio, then save the canvas as PNG and PDF.

	Upper pad: ``eff_noHack`` (red, "Tracker sim."), ``eff`` (black,
	"Tracker+Muon sim.") and the ``eff_uncert`` band. The y axis is
	logarithmic for charge '1Over3'.
	Lower pad: the "Correction factor" eff / eff_noHack at each bin centre. Its
	errors are the relative errors of the three graphs added in quadrature.
	Points with a zero denominator are drawn at 0 instead of raising
	ZeroDivisionError.

	Args:
		eff_noHack, eff, eff_uncert: ROOT TGraphAsymmErrors with one point per
			cut-flow bin, point i evaluated at bin centre i + 0.5.
		nBins: number of cut-flow bins.
		xLabel: title of the lower x axis.
		mass, charge: strings used in the legend header and the output file names.

	Output files: plots_selectionEfficiency/selectionEfficiency_M<mass>_Q<charge>.png
	and .pdf. The directory is created if it does not exist.
	"""
	import os

	ROOT.gStyle.SetLegendBorderSize(0)
	ROOT.gStyle.SetOptStat(0)
	ROOT.gROOT.SetBatch(True)

	c = ROOT.TCanvas("c", "", 800, 950)
	c.cd()
	ROOT.gPad.Update()

	# ---------------- upper pad: efficiencies ----------------
	UpperPad = ROOT.TPad("UpperPad", "UpperPad", 0, 0.3, 1, 1)
	UpperPad.SetTopMargin(0.05)
	UpperPad.SetBottomMargin(0.02)
	UpperPad.SetRightMargin(0.05)
	UpperPad.SetLeftMargin(0.13)
	UpperPad.Draw()
	UpperPad.cd()
	ROOT.gPad.SetTicky(1)

	# TGraph drawn with "AP" only to provide the axis frame.
	g_canvas = ROOT.TGraph(nBins)
	g_canvas.Draw("AP")
	g_canvas.SetTitle("")
	yAxis = g_canvas.GetYaxis()
	yAxis.SetTitle("Selection efficiency")
	yAxis.SetTitleSize(30)
	yAxis.SetTitleFont(43)
	yAxis.SetTitleOffset(1.4)
	yAxis.SetLabelSize(25)
	yAxis.SetLabelFont(43)
	yAxis.SetNdivisions(6, 5, 0)
	if charge == '1Over3':
		g_canvas.SetMaximum(50)
		g_canvas.SetMinimum(2e-04)
		ROOT.gPad.SetLogy()
	else:
		g_canvas.SetMaximum(1.18)
		g_canvas.SetMinimum(0)
	c.Update()

	_label_selection_axis(g_canvas.GetXaxis(), nBins)
	# The bin labels are only shown on the lower pad.
	g_canvas.GetXaxis().SetLabelSize(0)

	eff_noHack.Draw("PZ")
	eff_noHack.SetLineWidth(2)
	eff_noHack.SetLineColor(ROOT.kRed)
	eff_noHack.SetMarkerColor(ROOT.kRed)

	eff.Draw("PZ")
	eff.SetLineWidth(2)
	eff.SetLineColor(ROOT.kBlack)
	eff.SetMarkerColor(ROOT.kBlack)

	eff_uncert.Draw("2")
	eff_uncert.SetLineColor(0)
	eff_uncert.SetFillStyle(3002)

	if charge not in _LEGEND_CHARGES:
		print("Not a valid charge point")
	legCharge = _LEGEND_CHARGES.get(charge, '')

	leg = ROOT.TLegend(0.5, 0.7, 0.9, 0.9)
	leg.SetHeader("FCP M = " + mass + " GeV, Q = " + legCharge)
	leg.SetTextSize(0.035)
	leg.AddEntry(eff_noHack, "Tracker sim.", "lep")
	leg.AddEntry(eff, "Tracker+Muon sim.", "lep")
	leg.AddEntry(eff_uncert, "Muon sim. uncertainty", "f")
	leg.Draw()

	c.cd()

	# ---------------- lower pad: correction factor ----------------
	LowerPad = ROOT.TPad("LowerPad", "LowerPad", 0, 0, 1, 0.3)
	LowerPad.Draw()
	LowerPad.cd()
	LowerPad.SetTopMargin(0.02)
	LowerPad.SetBottomMargin(0.4)
	LowerPad.SetRightMargin(0.05)
	LowerPad.SetLeftMargin(0.13)
	ROOT.gPad.SetTicky(1)

	# Bin centres 0.5, 1.5, ..., nBins - 0.5, valid for any number of bins.
	x = np.arange(nBins) + 0.5
	exl = 0.5 * np.ones(nBins)
	exh = 0.5 * np.ones(nBins)
	y = np.zeros(nBins)
	eyl = np.zeros(nBins)
	eyh = np.zeros(nBins)
	for i in range(nBins):
		nominal = eff.Eval(x[i])
		reference = eff_noHack.Eval(x[i])
		band = eff_uncert.Eval(x[i])
		y[i] = _safe_div(nominal, reference)
		# Relative uncertainties of the three inputs, added in quadrature.
		relHigh2 = (_safe_div(eff_uncert.GetErrorYhigh(i), band) ** 2
			+ _safe_div(eff.GetErrorYhigh(i), nominal) ** 2
			+ _safe_div(eff_noHack.GetErrorYhigh(i), reference) ** 2)
		eyh[i] = math.sqrt(relHigh2) * y[i]
		relLow2 = (_safe_div(eff_uncert.GetErrorYlow(i), band) ** 2
			+ _safe_div(eff.GetErrorYlow(i), nominal) ** 2
			+ _safe_div(eff_noHack.GetErrorYlow(i), reference) ** 2)
		eyl[i] = math.sqrt(relLow2) * y[i]
	g_ratio = ROOT.TGraphAsymmErrors(nBins, x, y, exl, exh, eyl, eyh)

	g_canvasLow = ROOT.TGraph(nBins)
	g_canvasLow.Draw("AP")
	g_canvasLow.SetTitle("")
	yAxisLow = g_canvasLow.GetYaxis()
	yAxisLow.SetTitle("Correction factor")
	yAxisLow.SetTitleSize(30)
	yAxisLow.SetTitleFont(43)
	yAxisLow.SetTitleOffset(1.7)
	yAxisLow.SetLabelSize(25)
	yAxisLow.SetLabelFont(43)
	yAxisLow.SetNdivisions(6, 5, 0)
	g_canvasLow.SetMaximum(1.2 * g_ratio.GetHistogram().GetMaximum())
	g_canvasLow.SetMinimum(0.8 * g_ratio.GetHistogram().GetMinimum())
	if charge == '1Over3':
		ROOT.gPad.SetLogy()

	xAxisLow = g_canvasLow.GetXaxis()
	_label_selection_axis(xAxisLow, nBins)
	xAxisLow.SetTitle(xLabel)
	xAxisLow.SetTitleSize(30)
	xAxisLow.SetTitleFont(43)
	xAxisLow.SetTitleOffset(6)
	xAxisLow.SetLabelOffset(0.0075)
	xAxisLow.SetLabelSize(25)
	xAxisLow.SetLabelFont(43)

	g_canvasLow.Draw("AP")
	c.Update()

	g_ratio.Draw("PZ")
	g_ratio.SetLineWidth(2)

	outDir = "plots_selectionEfficiency"
	if not os.path.isdir(outDir):
		os.makedirs(outDir)
	outBase = outDir + "/selectionEfficiency_M" + mass + "_Q" + charge
	c.SaveAs(outBase + ".png")
	c.SaveAs(outBase + ".pdf")

	c.Close()

def getDeltaR(eta1, phi1, eta2, phi2):
	"""Return the angular distance sqrt(dEta^2 + dPhi^2) between two directions.

	The azimuthal difference is folded once: when |phi1 - phi2| exceeds pi it
	is replaced by 2*pi - |phi1 - phi2|. This gives the shortest opening angle
	for phi values inside [-pi, pi].
	"""
	dPhi = abs(phi1 - phi2)
	if dPhi > math.pi:
		dPhi = 2 * math.pi - dPhi
	dEta = abs(eta1 - eta2)
	return math.sqrt(dEta ** 2 + dPhi ** 2)

if __name__ == '__main__':
	# Define input arguments. Both are required because they select the sample files.
	parser = argparse.ArgumentParser(description='')
	parser.add_argument('-m', '--mass', type=str, required=True, help='FCP mass')
	parser.add_argument('-q', '--charge', type=str, required=True, help='FCP charge')
	args = parser.parse_args()

	mass = args.mass
	charge = args.charge

	nBins = 13

	# One (passed, total) histogram pair per sample, in fillEfficiency's
	# `index` order. All histograms are booked before any input file is
	# opened, because ROOT attaches new histograms to the current directory.
	# Unique names avoid ROOT's "Replacing existing histogram" warnings.
	sampleTags = ("noHack", "nominal", "down", "up")
	hPassed = [ROOT.TH1F("h_" + tag + "_passed", "", nBins, 0, nBins) for tag in sampleTags]
	hTotal = [ROOT.TH1F("h_" + tag + "_total", "", nBins, 0, nBins) for tag in sampleTags]
	for index in range(len(sampleTags)):
		fillEfficiency(hPassed[index], hTotal[index], nBins, mass, charge, index)

	effs = []
	for passed, total in zip(hPassed, hTotal):
		g = ROOT.TGraphAsymmErrors(nBins)
		g.Divide(passed, total, "cl=0.683 b(1,1) mode")
		effs.append(g)
	eff_noHack, eff, eff_down, eff_up = effs

	# Bin centres 0.5, 1.5, ..., nBins - 0.5.
	x = np.arange(nBins) + 0.5
	exl = 0.5 * np.ones(nBins)
	exh = 0.5 * np.ones(nBins)
	# Muon-simulation uncertainty band: nominal efficiency, with the elossDown /
	# elossUp samples giving the asymmetric errors. An error is negative when a
	# variation lies on the "wrong" side of the nominal value.
	y = np.array([eff.Eval(xi) for xi in x])
	eyl = y - np.array([eff_down.Eval(xi) for xi in x])
	eyh = np.array([eff_up.Eval(xi) for xi in x]) - y

	eff_uncert = ROOT.TGraphAsymmErrors(nBins, x, y, exl, exh, eyl, eyh)

	# Per cut-flow level, print three values: the hack scale factor, then the
	# relative up and down variations. A zero denominator prints nan instead of
	# raising ZeroDivisionError.
	nan = float('nan')
	for i in range(nBins):
		nominal = eff.Eval(x[i])
		reference = eff_noHack.Eval(x[i])
		scale = nominal / reference if reference else nan
		relUp = (eff_up.Eval(x[i]) - nominal) / nominal if nominal else nan
		relDown = (nominal - eff_down.Eval(x[i])) / nominal if nominal else nan
		level = "HLT" if i == nBins - 1 else "L" + str(i)
		print("Scaling from the hack after " + level + " = " + str(scale) + "\t" + str(relUp) + "\t" + str(relDown))

	DrawEfficiency(eff_noHack, eff, eff_uncert, nBins, "Selection criterion", mass, charge)
