#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Nov 15 18:00:01 2022

"""

import ROOT
import time
start_time = time.time()

#---------------------------------------------------------------------------------

# Physical constants, in GeV (m_z and m_e below are GeV values).
# They are kept as strings because they are handed verbatim to
# RDataFrame.Define, which treats them as C++ expressions.
m_z = '91.1876'
m_e = '0.00051099895000'
# Muon mass is 105.6583755 MeV = 0.1056583755 GeV; the previous value was
# a factor of 1000 too small (smaller than the electron mass).
m_mu = '0.1056583755'

#---------------------------------------------------------------------------------

# C++ helper made available to RDataFrame expressions: builds the pT-ordered
# four-vectors of generator-level electrons and/or muons.
ROOT.gInterpreter.Declare('''
using Vec_t = const ROOT::RVec<float>&;
using namespace ROOT::Math;

// Select status-1 electrons (|eta| < 2.47) and/or muons (|eta| < 2.40) with
// pT > 10 whose first mother has |PID| 24, 23, 9900012 or 15, and return
// their (Px, Py, Pz, E) four-vectors ordered by descending pT.
//
// variable: "TrueLep"  -> electrons and muons
//           "TrueEle"  -> electrons only
//           "TrueMuon" -> muons only
//           anything else yields an empty vector.
// Particle_Charge is accepted only to keep the signature parallel to
// TrueLep_Charges; it is not used here.
ROOT::VecOps::RVec<PxPyPzEVector> TrueLep_Vecs(string variable, Vec_t Particle_PID, Vec_t Particle_E, Vec_t Particle_Px, Vec_t Particle_Py, Vec_t Particle_Pz, Vec_t Particle_Charge, Vec_t Particle_Status, Vec_t Particle_M1, Vec_t Particle_PT, Vec_t Particle_Eta) {

    const bool wantEle  = (variable == "TrueLep" || variable == "TrueEle");
    const bool wantMuon = (variable == "TrueLep" || variable == "TrueMuon");

    ROOT::VecOps::RVec<PxPyPzEVector> TrueLep_Vec;
    ROOT::VecOps::RVec<float> TrueLepPT_Vec;
    const Long64_t nne = Particle_PID.size();

    for (Long64_t j = 0; j < nne; j++) {
        // final-state particles above the pT threshold only
        if (Particle_Status[j] != 1 || !(Particle_PT[j] > 10.0)) continue;

        const bool isEle  = fabs(Particle_PID[j]) == 11 && fabs(Particle_Eta[j]) < 2.47;
        const bool isMuon = fabs(Particle_PID[j]) == 13 && fabs(Particle_Eta[j]) < 2.40;
        if (!((isEle && wantEle) || (isMuon && wantMuon))) continue;

        // A particle without a recorded mother carries a negative index
        // (Delphes convention: -1); never use it to index the PID array.
        const Long64_t mother = static_cast<Long64_t>(Particle_M1[j]);
        if (mother < 0 || mother >= nne) continue;

        const auto motherPid = fabs(Particle_PID[mother]);
        if (motherPid != 24 && motherPid != 23 && motherPid != 9900012 && motherPid != 15) continue;

        TrueLep_Vec.emplace_back(Particle_Px[j], Particle_Py[j], Particle_Pz[j], Particle_E[j]);
        TrueLepPT_Vec.emplace_back(Particle_PT[j]);
    }

    // Argsort is ascending, so reverse the indices for descending pT
    return Take(TrueLep_Vec, Reverse(Argsort(TrueLepPT_Vec)));
}
''')

# C++ helper made available to RDataFrame expressions: builds the pT-ordered
# charges of generator-level electrons or muons.
ROOT.gInterpreter.Declare('''
using Vec_t = const ROOT::RVec<float>&;
using namespace ROOT::Math;

// Select status-1 electrons (|eta| < 2.47) or muons (|eta| < 2.40) with
// pT > 10 whose first mother has |PID| 24, 23, 9900012 or 15, and return
// their charges ordered by descending pT.
//
// variable: "TrueEleCharge"  -> electron charges
//           "TrueMuonCharge" -> muon charges
//           anything else yields an empty vector.
//
// The selection must match TrueLep_Vecs exactly so that the charge vectors
// line up index-by-index with the TrueEle / TrueMuon four-vectors. The
// mother PID list previously had 9910012 here versus 9900012 there, which
// could make the charge vectors shorter than the lepton vectors.
// NOTE(review): 9900012 is assumed to be the intended mother ID - confirm
// against the signal model in use.
ROOT::VecOps::RVec<int> TrueLep_Charges(string variable, Vec_t Particle_PID, Vec_t Particle_E, Vec_t Particle_Px, Vec_t Particle_Py, Vec_t Particle_Pz, Vec_t Particle_Charge, Vec_t Particle_Status, Vec_t Particle_M1, Vec_t Particle_PT, Vec_t Particle_Eta) {

    const bool wantEle  = (variable == "TrueEleCharge");
    const bool wantMuon = (variable == "TrueMuonCharge");

    ROOT::VecOps::RVec<int> TrueLepCharge_Vec;
    ROOT::VecOps::RVec<float> TrueLepPT_Vec;
    const Long64_t nne = Particle_PID.size();

    for (Long64_t j = 0; j < nne; j++) {
        // final-state particles above the pT threshold only
        if (Particle_Status[j] != 1 || !(Particle_PT[j] > 10.0)) continue;

        const bool isEle  = fabs(Particle_PID[j]) == 11 && fabs(Particle_Eta[j]) < 2.47;
        const bool isMuon = fabs(Particle_PID[j]) == 13 && fabs(Particle_Eta[j]) < 2.40;
        if (!((isEle && wantEle) || (isMuon && wantMuon))) continue;

        // A particle without a recorded mother carries a negative index
        // (Delphes convention: -1); never use it to index the PID array.
        const Long64_t mother = static_cast<Long64_t>(Particle_M1[j]);
        if (mother < 0 || mother >= nne) continue;

        const auto motherPid = fabs(Particle_PID[mother]);
        if (motherPid != 24 && motherPid != 23 && motherPid != 9900012 && motherPid != 15) continue;

        TrueLepCharge_Vec.emplace_back(Particle_Charge[j]);
        TrueLepPT_Vec.emplace_back(Particle_PT[j]);
    }

    // Argsort is ascending, so reverse the indices for descending pT
    return Take(TrueLepCharge_Vec, Reverse(Argsort(TrueLepPT_Vec)));
}
''')


#---------------------------------------------------------------------------------

#define input root file
# Open the "Delphes" tree of the input ROOT file as an RDataFrame
fileName = "exampleFile.root"
df = ROOT.RDataFrame("Delphes", fileName)

#---------------------------------------------------------------------------------

# Extend the dataframe with additional columns

# Constant columns: lepton masses and Z boson mass
for _column, _value in (('Electron_Mass', m_e),
                        ('Muon_Mass', m_mu),
                        ('Z_Mass', m_z)):
    df = df.Define(_column, _value)

# Every lepton helper receives the same list of Particle branches
_particle_args = ', '.join(
    'Particle.' + _leaf
    for _leaf in ('PID', 'E', 'Px', 'Py', 'Pz',
                  'Charge', 'Status', 'M1', 'PT', 'Eta')
)

# (new column, C++ helper that fills it); the column name doubles as the
# selector string passed as the helper's first argument
_lepton_columns = (
    ('TrueLep', 'TrueLep_Vecs'),
    ('TrueEle', 'TrueLep_Vecs'),
    ('TrueMuon', 'TrueLep_Vecs'),
    ('TrueEleCharge', 'TrueLep_Charges'),
    ('TrueMuonCharge', 'TrueLep_Charges'),
)
for _column, _helper in _lepton_columns:
    df = df.Define(_column, '{}("{}", {})'.format(_helper, _column, _particle_args))

#---------------------------------------------------------------------------------

# Cut flow. Each cut is applied on top of the previous ones and the number of
# surviving events is printed after every step.

# at least two selected generator-level leptons
cut_2l = '(TrueEle.size()+TrueMuon.size()>=2)'

# leading / sub-leading lepton pT thresholds
cut_pT = 'TrueLep[0].Pt()>15 && TrueLep[1].Pt()>10'

# charge requirement per flavour channel (ee, e-mu, mu-mu)
# NOTE(review): the e-mu branch only tests abs(TrueMuonCharge[0])==1, which
# holds for any muon, so no opposite-sign requirement is applied there,
# unlike the ee and mu-mu branches - confirm this is intended.
cut_charges = (
    '(TrueEle.size()==2 && abs(TrueEleCharge[0]+TrueEleCharge[1])==0) || '
    '(TrueEle.size()==1 && abs(TrueMuonCharge[0])==1) || '
    '(TrueEle.size()==0 && abs(TrueMuonCharge[0]+TrueMuonCharge[1])==0)'
)

# invariant mass of the two leading leptons
cut_mll = '(TrueLep[0]+TrueLep[1]).M()>12'

# veto events containing a b-tagged jet with pT > 10
cut_btag = 'Jet.PT[Jet.PT>10 && Jet.BTag==1].size()==0'

# Book every Count() before reading any of them: RDataFrame results are lazy,
# so all booked counts are filled in a single event loop on the first
# GetValue() call, instead of re-reading the input once per cut.
cutflow_counts = []
for cut in (cut_2l, cut_pT, cut_charges, cut_mll, cut_btag):
    df = df.Filter(cut)
    cutflow_counts.append(df.Count())

for booked_count in cutflow_counts:
    print(booked_count.GetValue())

print("total execution time: ", time.time()-start_time)














