'''
Script to generate an interpolated signal shape at a mass point lying between two known (already fitted) mass points. Requires that the fitted signal shapes (RooWorkspace files) for the two bracketing masses have already been generated using FitShapes.py

Based on: https://root.cern/doc/v634/rf616__morphing_8py.html
and: https://root-forum.cern.ch/t/interpolation-between-two-2d-pdfs-using-roomomentmorphfuncnd-fails-because-integrating-a-rooabsreallvalue-is-not-allowed/63472
and: https://root-forum.cern.ch/t/using-roomomentmorphfuncnd/58889/6
'''
import ROOT
from ROOT import TFile, RooRealVar, RooArgList, RooDataHist, RooHistPdf, RooFit, RooArgSet, RooMomentMorphFuncND, RooBinning, RooWrapperPdf
import pandas as pd

# Observable ranges: X = mPhi in [xMin, xMax], Y = mT in [yMin, yMax].
# nX / nY are the number of bins of the morphing grid along each axis
# (they are passed to RooBinning together with the t*Min/t*Max bounds below).
xMin, xMax, nX = 60, 560, 50
yMin, yMax, nY = 800, 3500, 27

# Ranges of the morphing (interpolation) parameters i_mPhi and i_mT; these also
# bound the RooBinning grids on which the reference shapes are placed.
# NOTE(review): each bound sits 1 inside its observable range except txMax,
# which is 599 (> xMax = 560) -- possibly intended to be 559; confirm.
txMin, txMax = 61, 599
tyMin, tyMax = 801, 3499

# Number of samples to fill the histos (not referenced elsewhere in the visible script)
n_samples = 1000

# Get the location in a RooBinning object based on input values 
def getBin(binning, val):
    if (val >= binning.highBound()):
        return binning.numBins()
    else:
        return binning.binNumber(val)

def _open_model(path):
    '''
    Open the ROOT file at ``path`` and fetch the pdf 'model' from its RooWorkspace 'w'.

    Returns (tfile, workspace, pdf). The file and workspace are returned together
    with the pdf so the caller can hold Python references to them while the pdf
    is in use.

    Raises FileNotFoundError if the file cannot be opened, and KeyError if the
    workspace 'w' or the pdf 'model' is missing.
    '''
    tfile = ROOT.TFile.Open(path)
    # On failure TFile.Open returns a null pointer, which is falsy in PyROOT.
    if not tfile or tfile.IsZombie():
        raise FileNotFoundError(f'Could not open signal workspace file {path}')
    ws = tfile.Get('w')
    if not ws:
        raise KeyError(f"RooWorkspace 'w' not found in {path}")
    pdf = ws.pdf('model')
    if not pdf:
        raise KeyError(f"pdf 'model' not found in workspace 'w' of {path}")
    return tfile, ws, pdf


def _warn_if_outside(label, val, lo, hi):
    '''Print a warning when ``val`` lies outside the morphing-parameter range [lo, hi].'''
    if not lo <= val <= hi:
        print(f'WARNING: {label} = {val} is outside the morphing range [{lo},{hi}]; '
              'the morphed shape may not correspond to this value.')


def interpolate(mT, year, m, m1, m2, outfile='TEST.pdf'):
    '''
    Create a signal shape for phi mass "m" between existing phi masses "m1" and "m2"
    (m1 < m < m2) by linear moment morphing, and plot it with the two reference shapes.

    Parameters
    ----------
    mT : str or number
        T' mass. Used in the input file names and as the i_mT morphing coordinate.
    year : str
        Run II year. Used in the input file names.
    m : str or number
        Phi mass at which the interpolated shape is evaluated (i_mPhi coordinate).
    m1, m2 : str or number
        Phi masses of the existing signals. Their shapes are read from pdf 'model'
        of RooWorkspace 'w' in './{mT}-{m1}_{year}_ws.root' and './{mT}-{m2}_{year}_ws.root'.
    outfile : str
        Path of the output plot (default 'TEST.pdf').

    Raises
    ------
    FileNotFoundError
        If an input ROOT file cannot be opened.
    KeyError
        If workspace 'w' or pdf 'model' is missing from an input file.
    ValueError
        If a mass cannot be converted to float.
    '''
    print('------------------------------------------------------------------------------------------')
    print(f'Interpolating signal of mPhi = {m}, between existing signals at mPhi = [{m1},{m2}]')
    print('------------------------------------------------------------------------------------------')
    m_val, m1_val, m2_val, mT_val = float(m), float(m1), float(m2), float(mT)

    # RooRealVar.setVal clips values into the variable's range, and the grid
    # placement uses RooBinnings over the same ranges, so out-of-range points
    # would silently be evaluated elsewhere. Warn (rather than fail) so existing
    # workflows keep running but the mismatch is visible.
    for label, val in (('mPhi', m_val), ('m1', m1_val), ('m2', m2_val)):
        _warn_if_outside(label, val, txMin, txMax)
    _warn_if_outside('mT', mT_val, tyMin, tyMax)
    if not min(m1_val, m2_val) < m_val < max(m1_val, m2_val):
        print(f'WARNING: mPhi = {m} is not strictly between m1 = {m1} and m2 = {m2}')

    # Observables
    x = RooRealVar('mPhi', 'mPhi', xMin, xMax)
    y = RooRealVar('mT', 'mT', yMin, yMax)
    # Morphing parameters: the point at which the interpolated shape is evaluated
    ix = RooRealVar('i_mPhi', 'i_mPhi', txMin, txMax)
    iy = RooRealVar('i_mT', 'i_mT', tyMin, tyMax)
    # Grid on which the reference shapes are placed
    mbx = RooBinning(nX, txMin, txMax)
    mby = RooBinning(nY, tyMin, tyMax)
    grid = RooMomentMorphFuncND.Grid2(mbx, mby)

    # Reference shapes at m1 and m2. The file/workspace handles are kept in
    # locals so they remain referenced while pdf1/pdf2 are in use.
    f1, ws1, pdf1 = _open_model(f'./{mT}-{m1}_{year}_ws.root')
    f2, ws2, pdf2 = _open_model(f'./{mT}-{m2}_{year}_ws.root')
    pdf1.SetTitle('model_m1')
    pdf2.SetTitle('model_m2')
    # Add the PDFs to the grid at their respective locations
    grid.addPdf(pdf1, getBin(mbx, m1_val), getBin(mby, mT_val))
    grid.addPdf(pdf2, getBin(mbx, m2_val), getBin(mby, mT_val))

    # Linear moment morphing in (i_mPhi, i_mT) of a function of (mPhi, mT)
    morph_func = ROOT.RooMomentMorphFuncND("morph_func", "morph_func", RooArgList(ix, iy), RooArgList(x, y), grid, ROOT.RooMomentMorphFuncND.Linear)
    # setPdfMode() (flag defaults to True) normalizes the morphed function so it
    # can be used as a pdf; setPdfMode(False) would skip this, avoiding warning
    # messages and gaining speed.
    morph_func.setPdfMode()
    # Wrap the function as a pdf; True marks it as self-normalized.
    morph = ROOT.RooWrapperPdf("morph", "morph", morph_func, True)

    # Evaluate the morph at the requested point (values outside ix/iy ranges are clipped)
    ix.setVal(m_val)
    iy.setVal(mT_val)

    # Project every shape onto both observables: blue = m1, red = m2, green = morph
    framex = x.frame()
    framey = y.frame()
    curves = ((pdf1, m1, ROOT.kBlue), (pdf2, m2, ROOT.kRed), (morph, m, ROOT.kGreen))
    for pdf, mass, colour in curves:
        for frame in (framex, framey):
            pdf.plotOn(frame, ROOT.RooFit.Name(f'{mass}-{mT}'), ROOT.RooFit.LineColor(colour))

    # Debug the PDFs/functions being evaluated
    pdf1.Print("T")
    pdf2.Print("T")

    legend = ROOT.TLegend(0.65, 0.73, 0.86, 0.87)
    for _, mass, _ in curves:
        legend.AddEntry(framex.findObject(f'{mass}-{mT}'), f'({mass},{mT})', 'L')

    c = ROOT.TCanvas('c', 'c', 1000, 800)
    c.Divide(2, 1)
    c.cd(1)
    framex.Draw()
    legend.Draw()
    c.cd(2)
    framey.Draw()
    legend.Draw()
    c.Print(outfile)

if __name__ == '__main__':
    # Command-line entry point: every option is a required string that is
    # forwarded unchanged to interpolate().
    from argparse import ArgumentParser

    # (flag, destination attribute, help text)
    cli_options = (
        ('-m', 'm', 'Requested interpolated signal in form mPhi'),
        ('-m1', 'm1', 'Existing signal of mPhi m1 in form mPhi'),
        ('-m2', 'm2', 'Existing signal of mPhi m2 in form mPhi'),
        ('-mT', 'mT', 'Tprime mass'),
        ('-y', 'year', 'Run II year'),
    )
    cli = ArgumentParser()
    for flag, dest, help_text in cli_options:
        cli.add_argument(flag, type=str, dest=dest, action='store',
                         required=True, help=help_text)
    opts = cli.parse_args()

    interpolate(opts.mT, opts.year, opts.m, opts.m1, opts.m2)