# imports
import ROOT as R
import pandas as pd
import numpy as np

# Workspace that owns every variable, function and pdf of the fit model.
ws = R.RooWorkspace("ws", "ws")

# Factory expressions for the fit variables and the resolution model.
# They are executed strictly in this order, because later expressions refer
# by name to objects created by earlier ones.
_RESOLUTION_SPECS = (
    # --- fit variables: name[min, max]
    "deltat[-15, 15]",
    "deltaterr[0.02, 3]",
    # --- resolution parameters: name[initial value, min, max]
    "muM1[-0.14, -1, 1]",
    "muT1[-0.38, -10, 10]",
    "sigmaM1[1.01, 0.5, 2]",
    "sigmaT1[1.99, 1, 10]",
    "c1[4.56, 1, 20]",
    "fTR[0.44, 0, 1]",
    "fracExp[0.17, 0, 0.5]",
    # --- resolution components; deltaterr fills every trailing argument
    #     slot that follows the shape parameters
    # main gaussian
    "RooGaussModel::res_gauss_main(deltat, muM1, sigmaM1, deltaterr, deltaterr)",
    # tail gaussian (shares muT1 / sigmaT1 with the two exponential components)
    "RooGaussModel::res_gauss_tail(deltat, muT1, sigmaT1, deltaterr, deltaterr)",
    # right exp
    "RooGExpModel::res_exp_right(deltat, muT1, sigmaT1, c1, deltaterr, deltaterr, deltaterr, False, RooGExpModel::Flipped)",
    # left exp
    "RooGExpModel::res_exp_left(deltat, muT1, sigmaT1, c1, deltaterr, deltaterr, deltaterr, False, RooGExpModel::Normal)",
    # --- nested sums: exponentials -> tail -> full resolution function
    "RooAddModel::res_exp({res_exp_right,res_exp_left},{fTR})",
    "RooAddModel::res_tail({res_exp,res_gauss_tail},{fracExp})",
    # frac_tail is created inline by this expression; res_core is the
    # complete resolution function used by the decay pdfs
    "RooAddModel::res_core({res_tail,res_gauss_main},{frac_tail[0.29, 0, 1]})",
)
for _spec in _RESOLUTION_SPECS:
    ws.factory(_spec)

# Freeze all resolution parameters at their initial values so that they do
# not float in the fit.
_FIXED_RESOLUTION_PARAMS = (
    "muM1",
    "muT1",
    "sigmaM1",
    "sigmaT1",
    "c1",
    "fTR",
    "fracExp",
    "frac_tail",
)
for _param_name in _FIXED_RESOLUTION_PARAMS:
    ws.var(_param_name).setConstant(True)

# Flavour-tagging (FT) parameters, CPV parameters and the coefficient
# functions handed to RooBDecay. Everything is created through the workspace
# factory, in order, since the expressions reference earlier objects by name.
_TAGGING_AND_CPV_SPECS = (
    # --- FT parameters
    "r[0, 1]",  # FT dilution
    "w_slp[-0.511]",
    "w_intcpt[0.520]",
    # w is a linear function of r: slope * r + intercept
    "expr::w('w_slp*r + w_intcpt', w_slp, w_intcpt, r )",
    "dw[0.015]",
    "mu[0.000]",
    # --- CPV parameters: name[initial value, min, max]
    "CCP[0.1, -1, 1]",
    "SCP[0.2, -1, 1]",
    # --- cosh coefficients (the "bar" variant is used for the bartag pdf)
    "expr::coshcb('1-dw + mu*(1-2*w)', dw, mu, w)",
    "expr::coshcbbar('1+dw - mu*(1-2*w)', dw, mu, w)",
    # --- sin coefficients, proportional to SCP
    "expr::sincb('(1-2*w + mu*(1-dw))*SCP', dw, mu, w, SCP)",
    "expr::sincbbar('(-1+2*w + mu*(1+dw))*SCP', dw, mu, w, SCP)",
    # --- cos coefficients, proportional to CCP
    "expr::coscb('(1-2*w - mu*(1-dw))*CCP', dw, mu, w, CCP)",
    "expr::coscbbar('(-1+2*w - mu*(1+dw))*CCP', dw, mu, w, CCP)",
)
for _spec in _TAGGING_AND_CPV_SPECS:
    ws.factory(_spec)

# Decay-time pdfs, one per tag flavour.
# RooBDecay argument slots, in order: time variable, lifetime, delta-gamma,
# cosh coefficient, sinh coefficient, cos coefficient, sin coefficient,
# delta-m, resolution model, decay type.
# The first expression also creates B_tau, zero and bmix_dm inline; the
# second one reuses them by name, so the order of the two calls matters.
_BTAG_DECAY_SPEC = """
    RooBDecay::sc_pdf_btag(
        deltat,
        B_tau[1.534],
        zero[0],
        coshcb,
        zero,
        coscb,
        sincb,
        bmix_dm[0.507],
        res_core,
        DoubleSided
    )
"""

_BARTAG_DECAY_SPEC = """
    RooBDecay::sc_pdf_bartag(
        deltat,
        B_tau,
        zero,
        coshcbbar,
        zero,
        coscbbar,
        sincbbar,
        bmix_dm,
        res_core,
        DoubleSided
    )
"""

for _decay_spec in (_BTAG_DECAY_SPEC, _BARTAG_DECAY_SPEC):
    ws.factory(_decay_spec)

# Dump the workspace contents for inspection.
ws.Print()

# Read the "data" tree of mwe_data.root into a pandas DataFrame.
data_df = pd.DataFrame(R.RDataFrame("data", "mwe_data.root").AsNumpy())

# Variables imported from the DataFrame into each RooDataSet.
obs_set = R.RooArgSet(ws.var("deltat"), ws.var("deltaterr"), ws.var("r"))


def _dataset_for(selection):
    """Return a RooDataSet over ``obs_set`` from the rows matching *selection*."""
    selected_rows = data_df.query(selection)
    return R.RooDataSet.from_pandas(selected_rows, obs_set)


def _conditional_nll(pdf_name, dataset):
    """Return the NLL of workspace pdf *pdf_name* on *dataset*.

    ``r`` and ``deltaterr`` are passed as conditional observables.
    """
    conditional = R.RooArgSet(ws.var("r"), ws.var("deltaterr"))
    pdf = ws.pdf(pdf_name)
    return pdf.createNLL(dataset, ConditionalObservables=conditional)


# One dataset per tag flavour, split on the tagflav column.
data_btag = _dataset_for("tagflav==1")
data_bartag = _dataset_for("tagflav==-1")

# One NLL per tag flavour, then their sum.
nll_btag = _conditional_nll("sc_pdf_btag", data_btag)
nll_bartag = _conditional_nll("sc_pdf_bartag", data_bartag)

_nll_terms = R.RooArgList(nll_btag, nll_bartag)
nll_sum = R.RooAddition("nll_sum", "nll_sum", _nll_terms)

# Add the summed NLL as a function of SCP to a frame (ShiftToZero option on).
_scp = ws.var("SCP")
frame = _scp.frame(R.RooFit.Title("NLL vs SCP"))
nll_sum.plotOn(frame, ShiftToZero=True)