#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Nov 15 18:00:01 2022

"""

import ROOT
import time
# Wall-clock reference for the "total execution time" reported at the end.
start_time = time.time()

#---------------------------------------------------------------------------------

#constants
# Particle masses in GeV. They are kept as strings because they are handed
# verbatim to RDataFrame.Define as column expressions.
m_z = '91.1876'
m_e = '0.00051099895000'
# Muon mass is 105.6583755 MeV = 0.1056583755 GeV. The previous value
# ('0.0001056583755') was a factor 1000 too small relative to m_e and m_z.
m_mu = '0.1056583755'

#---------------------------------------------------------------------------------

# Compile and load the C++ helper macro through ACLiC
# ("k": keep the shared library, "O": build with optimisation).
# Presumably this provides TrueLep_Vecs / TrueLep_Charges used in the column
# expressions further down - confirm against helpers.cpp.
ROOT.gSystem.CompileMacro("helpers.cpp", "kO")

#---------------------------------------------------------------------------------

# Input dataset: the "Delphes" tree, with out.root listed 20 times so the
# same file is processed repeatedly.
df = ROOT.RDataFrame("Delphes", ["out.root" for _ in range(20)])

#---------------------------------------------------------------------------------

#Extend branches

# Constant-valued mass columns (leptons first, then the Z boson), defined in
# the same order as before.
for _mass_column, _mass_value in (
    ('Electron_Mass', m_e),
    ('Muon_Mass', m_mu),
    ('Z_Mass', m_z),
):
    df = df.Define(_mass_column, _mass_value)

# Lepton columns. Every helper call receives the column name as a string
# selector followed by the same list of Particle branches, so the expressions
# are generated from a small table instead of being spelled out five times.
_PARTICLE_ARGS = (
    'Particle.PID, Particle.E, Particle.Px, Particle.Py, Particle.Pz, '
    'Particle.Charge, Particle.Status, Particle.M1, Particle.PT, Particle.Eta'
)
_LEPTON_COLUMNS = (
    ('TrueLep', 'TrueLep_Vecs'),
    ('TrueEle', 'TrueLep_Vecs'),
    ('TrueMuon', 'TrueLep_Vecs'),
    ('TrueEleCharge', 'TrueLep_Charges'),
    ('TrueMuonCharge', 'TrueLep_Charges'),
)
for _lep_column, _lep_helper in _LEPTON_COLUMNS:
    df = df.Define(
        _lep_column,
        '{}("{}", {})'.format(_lep_helper, _lep_column, _PARTICLE_ARGS),
    )

#---------------------------------------------------------------------------------

# Cut flow: each Filter narrows the previous node, and a Count() is booked
# after every cut. The counts are only computed once a result is requested
# (GetValue), so all five are filled in a single pass over the data.

# 1) at least two selected leptons (electrons + muons)
cut_2l = '(TrueEle.size()+TrueMuon.size()>=2)'
df = df.Filter(cut_2l)
c1 = df.Count()

# 2) pT thresholds on the first two entries of TrueLep
#    (assumes TrueLep is ordered leading/sub-leading - TODO confirm in helpers.cpp)
cut_pT = 'TrueLep[0].Pt()>15 && TrueLep[1].Pt()>10'
df = df.Filter(cut_pT)
c2 = df.Count()

# 3) opposite-charge pair in the ee, e-mu and mu-mu channels.
#    The e-mu branch used to test abs(TrueMuonCharge[0])==1, which holds for
#    any muon and never compared the two lepton charges; it now requires the
#    electron and muon charges to cancel, like the other two branches.
#    The index accesses are safe after cut_2l: one electron implies at least
#    one muon, and zero electrons implies at least two muons.
cut_charges = (
    '(TrueEle.size()==2 && abs(TrueEleCharge[0]+TrueEleCharge[1])==0) || '
    '(TrueEle.size()==1 && abs(TrueEleCharge[0]+TrueMuonCharge[0])==0) || '
    '(TrueEle.size()==0 && abs(TrueMuonCharge[0]+TrueMuonCharge[1])==0)'
)
df = df.Filter(cut_charges)
c3 = df.Count()

# 4) invariant mass of the first two leptons above 12
cut_mll = '(TrueLep[0]+TrueLep[1]).M()>12'
df = df.Filter(cut_mll)
c4 = df.Count()

# 5) b-jet veto: no jet with PT>10 that is b-tagged
cut_btag = 'Jet.PT[Jet.PT>10 && Jet.BTag==1].size()==0'
df = df.Filter(cut_btag)
c5 = df.Count()

# Requesting the first booked count triggers the event loop, so the timer is
# started immediately before the values are read.
start_event_loop = time.time()
cutflow_counts = [booked.GetValue() for booked in (c1, c2, c3, c4, c5)]
print(cutflow_counts)
end = time.time()

# Timing summary: whole script versus the part spent reading the counts,
# followed by the number of event loops RDataFrame reports having run.
print("total execution time: ", end - start_time)
print("total event loop time: ", end - start_event_loop)
print("total event loops ", df.GetNRuns())
