import ROOT
from pathlib import Path
from argparse import ArgumentParser


# Command-line interface: one or more input files, plus optional tree name,
# output path and a boolean switch that enables the mean-based filtering.
parser = ArgumentParser(description='Make a snapshot of a ROOT file with new columns.')
# '+' instead of '*': an empty file list is rejected by argparse with a usage
# message instead of failing later when the dataframe is built.
parser.add_argument('src', nargs='+', type=str, help='Source ROOT file(s) to read.')
parser.add_argument('--tree', '-t', type=str, default='DDTree', help='Name of the TTree to process.')
parser.add_argument('--out', '-o', type=str, default=None, help='Output ROOT file to write.')
# This is a flag (store_true), not a filter expression; the help text says so.
parser.add_argument('--filter', '-f', action='store_true',
                    help='Keep only entries whose computed <column>_mean is > 0 '
                         'for every float/double vector column.')
args = parser.parse_args()


# Resolve all paths up front so the rest of the script works with absolute paths.
here = Path(__file__).parent.resolve()
src_path = [Path(f).resolve() for f in args.src]

if args.out is None:
    # Default output lives in ../data relative to this script; resolve it so
    # it is normalized the same way as a user-supplied --out path.
    file_path = (here / '..' / 'data' / 'snapshot.root').resolve()
else:
    file_path = Path(args.out).resolve()

# The snapshot recreates its output file by default, so writing onto one of
# the inputs would destroy the data being read. Refuse that combination.
if file_path in src_path:
    parser.error(f'output file {file_path} is also an input file')

file_path.parent.mkdir(parents=True, exist_ok=True)
    

ROOT.ROOT.EnableImplicitMT(10)  # Use 10 threads

# Build the dataframe over all input files. The file names are passed as a
# concrete list: a lazy `map` object is a one-shot iterator, which is fragile
# when the binding layer has to convert it to a C++ vector of strings.
rdf = ROOT.RDataFrame(args.tree, [str(path) for path in src_path])
ROOT.RDF.Experimental.AddProgressBar(rdf)

# Add derived columns. GetColumnNames() is called once, before the loop body
# runs, so the columns defined below are not themselves revisited.
for col in rdf.GetColumnNames():
    # Look the type up once per column instead of on every comparison.
    coltype = rdf.GetColumnType(col)
    vector_depth = coltype.count('RVec<') + coltype.count('vector<')

    if coltype.count('<') == 0:
        # Non-templated type. String-like columns cannot be multiplied, and
        # the failing expression would abort the whole script, so skip them.
        if 'string' in coltype or 'String' in coltype:
            continue
        rdf = rdf.Define(f'{col}_times2', f'{col} * 2')
    elif vector_depth == 1 and ('double' in coltype or 'float' in coltype):
        # Single-level vector of floating-point values: add per-entry mean and sum.
        # NOTE(review): the match is case-sensitive, so element types spelled
        # 'Float_t' / 'Double_t' are not picked up; confirm whether that is intended.
        rdf = rdf.Define(f'{col}_mean', f'ROOT::VecOps::Mean({col})')
        rdf = rdf.Define(f'{col}_sum', f'ROOT::VecOps::Sum({col})')

        # With --filter, only entries with a positive mean for this column survive.
        if args.filter:
            rdf = rdf.Filter(f'{col}_mean > 0')

def _vector_nesting(type_name):
    """Return how many 'RVec<' and 'vector<' markers appear in a type name."""
    return type_name.count('RVec<') + type_name.count('vector<')


# Columns to write out: every column whose type name has at most one
# vector marker (scalars and single-level vectors).
columns = [
    name
    for name in rdf.GetColumnNames()
    if _vector_nesting(rdf.GetColumnType(name)) <= 1
]


# Options for writing the snapshot.
snap_config = ROOT.RDF.RSnapshotOptions()
# NOTE(review): this comment used to read "Disable conversion", but the flag is
# set to True, which (per the RSnapshotOptions docs) keeps the conversion on.
# Set it to False if disabling the conversion was the actual intent.
snap_config.fVector2RVec = True  # Keep conversion of std::vector columns to RVec enabled
# Auto-flush setting for the output tree; presumably "every 1000 entries" - TODO confirm.
snap_config.fAutoFlush = 1000


# Write the selected columns to the output file, under the same tree name as the input.
rdf.Snapshot(args.tree, str(file_path), columns, snap_config)