import ROOT
from ROOT import TFile, RooRealVar, RooArgList, RooDataHist, RooHistPdf, RooFit, RooArgSet, RooMomentMorphFuncND, RooBinning, RooWrapperPdf

# X observable is mPhi: lower edge, upper edge, number of bins
xMin, xMax, nX = 60, 560, 50

# Y observable is mT: lower edge, upper edge, number of bins
yMin, yMax, nY = 800, 3500, 27

def create_RDH_PDF(name, h, xvar=None, yvar=None):
    """Wrap a 2D histogram in a RooDataHist and a RooHistPdf over (mPhi, mT).

    Args:
        name: suffix that makes the returned RooFit object names unique
            ('RDH_<name>' and 'RHP_<name>').
        h: 2D histogram to import; its x axis is mapped to mPhi and its
            y axis to mT (order of the observable list below).
        xvar, yvar: optional existing RooRealVars to use as the mPhi / mT
            observables. When omitted, new variables spanning
            [xMin, xMax] / [yMin, yMax] are created.

    Returns:
        (RDH, RHP): the RooDataHist and the RooHistPdf built on it. The
        RooHistPdf refers to RDH rather than copying it, so the caller must
        keep RDH alive for as long as RHP is used.
    """
    # The RooHistPdf only holds proxies to its observables. Variables created
    # here would be deleted by Python when this function returns and leave
    # those proxies dangling, so Python ownership is released for them.
    if xvar is None:
        xvar = RooRealVar('mPhi', 'mPhi', xMin, xMax)
        ROOT.SetOwnership(xvar, False)
    if yvar is None:
        yvar = RooRealVar('mT', 'mT', yMin, yMax)
        ROOT.SetOwnership(yvar, False)

    RAL_myvars = RooArgList(xvar, yvar)
    RAS_myvars = RooArgSet(RAL_myvars)
    rdh_name = 'RDH_' + str(name)
    rhp_name = 'RHP_' + str(name)
    RDH = RooDataHist(rdh_name, rdh_name, RAL_myvars, h)
    RHP = RooHistPdf(rhp_name, rhp_name, RAS_myvars, RDH)

    # DEBUG: compare the normalization of the input histogram, the data hist
    # and the pdf.
    print('TH1:        ' + str(h.Integral()))
    print('RDH:        ' + str(RDH.sum(False)))
    print('RDH (true): ' + str(RDH.sum(True)))
    # analyticalIntegral(code) expects a code obtained from
    # getAnalyticalIntegral(). A hard-coded 0 does not request integration
    # over the observables, so ask RooFit for the integral directly.
    print('RHP:        ' + str(RHP.createIntegral(RAS_myvars).getVal()))

    return RDH, RHP

def getBin(binning, val):
    """Map *val* to a bin index of *binning*.

    Values strictly above the binning's upper bound map to
    ``binning.numBins()``. Every other value is delegated to
    ``binning.binNumber(val)``.
    """
    above_range = val > binning.highBound()
    return binning.numBins() if above_range else binning.binNumber(val)

def Interpolate():
    """Morph between two 2D (mPhi, mT) histogram templates and plot the result.

    Loads 'MHvsMTH_SR_pass__nominal' from the files for the (mT, mPhi) mass
    points (1800, 125) and (1800, 175). Each template is placed on a
    RooMomentMorphFuncND grid at its (mPhi, mT) location. The morph is then
    evaluated at (mPhi, mT) = (150, 1800), and its one-dimensional
    projections on mPhi and mT are written to 'out.pdf'.

    Raises:
        IOError: if an input file cannot be opened or lacks the histogram.
    """
    # Observables
    x = RooRealVar('mPhi', 'mPhi', xMin, xMax)
    y = RooRealVar('mT', 'mT', yMin, yMax)
    print(x, type(x))
    # Morphing parameters: ix runs over mPhi, iy over mT
    ix = RooRealVar('i_mPhi', 'i_mPhi', xMin, xMax)
    iy = RooRealVar('i_mT', 'i_mT', yMin, yMax)
    # Reference grid for the morphing parameters
    RBx = RooBinning(nX, xMin, xMax)
    RBy = RooBinning(nY, yMin, yMax)
    grid = RooMomentMorphFuncND.Grid2(RBx, RBy)

    # (mT, mPhi) mass points to interpolate between, in the order the file
    # names encode them.
    mass_points = [(1800, 125), (1800, 175)]
    hist_key = 'MHvsMTH_SR_pass__nominal'
    # The grid and the pdfs refer to these ROOT objects without owning them,
    # so Python references are kept here for the lifetime of the morph.
    keep_alive = []
    for idx, (mT, mPhi) in enumerate(mass_points, start=1):
        fname = 'THselection_TprimeB-%d-%d_18.root' % (mT, mPhi)
        f = TFile.Open(fname)
        if not f or f.IsZombie():
            raise IOError('Could not open ' + fname)
        h = f.Get(hist_key)
        if not h:
            raise IOError(hist_key + ' not found in ' + fname)
        RDH, RHP = create_RDH_PDF(str(idx), h)
        # Grid axes follow the observables: first mPhi (RBx), then mT (RBy).
        # NOTE(review): grid nodes are bin indices of RBx/RBy, so mPhi values
        # are snapped to the 10-unit mPhi binning. Confirm that the resulting
        # reference values match the true mass points.
        grid.addPdf(RHP, getBin(RBx, mPhi), getBin(RBy, mT))
        keep_alive.append((f, h, RDH, RHP))
    # NOTE(review): only two grid nodes (same mT) are filled; confirm that
    # RooMomentMorphFuncND accepts an incomplete 2D grid.

    # Morph
    morph = RooMomentMorphFuncND('morph', 'morph', RooArgList(ix, iy), RooArgList(x, y), grid, RooMomentMorphFuncND.Linear)
    morph.setPdfMode()
    pdf = RooWrapperPdf('morph_pdf', 'morph_pdf', morph, True)
    # Frames for plotting
    framex = x.frame()
    framey = y.frame()
    # Point at which to evaluate the morph: mPhi = 150, mT = 1800
    ix.setVal(150)
    iy.setVal(1800)
    # Plot the projections on each observable side by side
    pdf.plotOn(framex, RooFit.LineColor(ROOT.kBlue))
    pdf.plotOn(framey, RooFit.LineColor(ROOT.kBlue))
    c = ROOT.TCanvas('c', 'c')
    c.Divide(2, 1)
    c.cd(1)
    framex.Draw()
    c.cd(2)
    framey.Draw()
    c.Print('out.pdf')


# Run the interpolation only when executed as a script, not on import.
if __name__ == '__main__':
    Interpolate()