

import ROOT as R
import numpy as np 
import os
import glob

# Use all CPU cores for RDataFrame event loops. os.cpu_count() may return None
# when the core count cannot be determined; 0 tells ROOT to pick the thread
# count itself, so fall back to that instead of passing None.
R.EnableImplicitMT(os.cpu_count() or 0)
# Batch mode: no graphics windows are opened.
R.gROOT.SetBatch(True)

## NOTE: this file contains only the find_ChargeMisID function; the helpers it calls (define_val, filter_fid) are not defined here.
def _equality_filter(column, values):
    """Return a ROOT filter expression that is true when *column* equals any of *values*.

    Each value is written as ``repr(float(v))``, the shortest decimal that
    round-trips to the same double. This keeps the comparison exact even when
    *values* come from a float32 branch: ``str()`` of a ``numpy.float32`` gives
    the shortest *float32* decimal, which no longer equals the branch value once
    C++ promotes the branch to double for the ``==`` comparison.
    """
    return " || ".join(f"{column} == {float(v)!r}" for v in values)


def find_ChargeMisID(files, compare=True, findChargeMisID=True,
                     outpath="/eos/user/s/seley/Faser/Alma9Validation/SingleMuon/ChargeMisID/"):
    """Select events charge-mis-identified by NEW tracking but correctly identified by OLD.

    Input files are grouped into a mu+ sample (``mcid`` contains "16") and a
    mu- sample (all other ``mcid``), each keyed by its ``tracking`` label. For
    each sample, the NEW and OLD ``"nt"`` trees are passed through
    ``define_val`` and ``filter_fid`` and restricted to ``NGoodTracks==1``.

    When ``findChargeMisID`` is true:

    * NEW events flagged ``ChargeMisID`` and OLD events flagged ``ChargeID`` are
      kept. Both columns are presumably defined by ``define_val``; confirm there.
    * Both selections are then restricted to events whose ``t_p_gev`` value
      (presumably track momentum in GeV) occurs in both selections.
    * The result is written with ``Snapshot`` to
      ``{outpath}AL9_misID_NEW{particle}.root`` and
      ``{outpath}AL9_correctID_only_OLD{particle}.root``.

    ``{outpath}TrackCharge.root`` is always recreated (and left empty by this
    function).

    Args:
        files: mapping ``mcid -> (file_path, particle, pdg, charge, tracking, symbol)``;
            ``tracking`` is expected to be "NEW" or "OLD". If two entries share
            a sample and tracking label, the later one wins.
        compare: not used by this function.
        findChargeMisID: if false, the selections are built but no Snapshot is written.
        outpath: output directory, including the trailing slash.

    Raises:
        OSError: if ``TrackCharge.root`` cannot be created.
        KeyError: if a sample has no "NEW" or "OLD" entry.
    """
    outfile = f"{outpath}TrackCharge.root"
    f_out_charge = R.TFile.Open(outfile, "RECREATE")
    # TFile.Open returns a null pointer (falsy in PyROOT) on failure.
    if not f_out_charge or f_out_charge.IsZombie():
        raise OSError(f"Could not create output file {outfile}")

    try:
        mc16_files = {}
        mc17_files = {}
        for mcid, (file_path, _particle, _pdg, _charge, tracking, _symbol) in files.items():
            target = mc16_files if "16" in mcid else mc17_files
            target[tracking] = file_path

        # particle label -> (files by tracking, pdg id, charge, TLatex symbol)
        samples = {
            "mupl": [mc16_files, -13, 1, "#mu^{+}"],
            "mumi": [mc17_files, 13, -1, "#mu^{-}"],
        }

        for particle, (sample_files, pdg, charge, _symbol) in samples.items():
            df_new = R.RDataFrame("nt", sample_files["NEW"])
            df_old = R.RDataFrame("nt", sample_files["OLD"])
            df_new = define_val(df_new, charge, pdg)
            df_old = define_val(df_old, charge, pdg)
            df_new = filter_fid(df_new).Filter("NGoodTracks==1")
            df_old = filter_fid(df_old).Filter("NGoodTracks==1")

            if findChargeMisID:
                df_new = df_new.Filter("ChargeMisID")
                df_old = df_old.Filter("ChargeID")

                # AsNumpy runs the event loop and returns one element per
                # selected event, so len() equals Count() without another loop.
                p_new = df_new.AsNumpy(["t_p_gev"])["t_p_gev"]
                p_old = df_old.AsNumpy(["t_p_gev"])["t_p_gev"]

                # Vectorised membership test instead of an O(n*m) Python loop.
                in_both = np.isin(p_new, p_old)
                n_matched = int(np.count_nonzero(in_both))

                if n_matched > 0:
                    print(n_matched)
                    print(f"Before filter (New) : {len(p_new)}")
                    print(f"Before filter (Old) : {len(p_old)}")

                    # Deduplicate (and sort, for a deterministic expression) to
                    # keep the JIT-compiled filter string as short as possible.
                    event_filter = _equality_filter("t_p_gev", np.unique(p_new[in_both]))
                    df_new = df_new.Filter(event_filter)
                    df_old = df_old.Filter(event_filter)

                    # Book the counts lazily; the eager Snapshots below run the
                    # event loops that fill them, avoiding two extra passes.
                    n_new_after = df_new.Count()
                    n_old_after = df_old.Count()

                    df_new.Snapshot("nt", f"{outpath}AL9_misID_NEW{particle}.root")
                    df_old.Snapshot("nt", f"{outpath}AL9_correctID_only_OLD{particle}.root")

                    print(f"After filter (New) : {n_new_after.GetValue()}")
                    print(f"After filter (Old) : {n_old_after.GetValue()}")

            print(f"Processed: {particle}")
    finally:
        f_out_charge.Close()
    return