import ROOT
from ROOT import TF1
import sys

# Configure ROOT: batch mode (no GUI windows) and no statistics box on plots.
ROOT.gROOT.SetBatch(True)
ROOT.gStyle.SetOptStat(0)

# Open the input file and fetch the histograms. Fail early with a clear
# message: a null file/histogram would otherwise only surface later as an
# obscure null-pointer error (or a null TH1 being handed to RooDataHist).
# NOTE: `file` shadows the Python builtin; the name is kept because later
# sections of this script read from it.
file = ROOT.TFile.Open("test.root")
if not file or file.IsZombie():
    sys.exit("ERROR: cannot open input file 'test.root'")
hMass = file.Get("hMass")
if not hMass:
    sys.exit("ERROR: histogram 'hMass' not found in 'test.root'")
hVn = file.Get("hVn")

# Observables on [1.7, 2.0]; both define the same "fitRange" sub-range
# [1.71, 1.95].
mass = ROOT.RooRealVar("mass", "Invariant Mass", 1.7, 2.0)
mass.setRange("fitRange", 1.71, 1.95)
vnvsmass = ROOT.RooRealVar("vnvsmass", "Vn Vs Mass", 1.7, 2.0)
vnvsmass.setRange("fitRange", 1.71, 1.95)

# Binned dataset in `mass` built from the mass histogram.
data_hist = ROOT.RooDataHist("data_hist", "Dataset from histogram", ROOT.RooArgList(mass), hMass)

# Signal: double-sided Crystal Ball (DSCB). The mean and width float in the
# fit; the tail parameters are initialised from histogram "hCBPars" and then
# frozen.
hCB_pars = file.Get("hCBPars")
sgn_mean = ROOT.RooRealVar("sgn_mean", "Signal Mean", 1.86, 1.85, 1.88)
sgn_width = ROOT.RooRealVar("sgn_width", "Signal Width", 0.013, 0.01, 0.02)
# Both tails read the same bins (bin 4 for alpha, bin 5 for n), i.e. the tails
# are symmetric. NOTE(review): confirm hCBPars has no separate right-tail bins.
alpha_left = ROOT.RooRealVar("alpha_left", "Left Tail Shape", hCB_pars.GetBinContent(4), 0.001, 10.0)
alpha_right = ROOT.RooRealVar("alpha_right", "Right Tail Shape", hCB_pars.GetBinContent(4), 0.001, 10.0)
n_left = ROOT.RooRealVar("n_left", "Left Exponent", hCB_pars.GetBinContent(5), 1.1, 10.0)
n_right = ROOT.RooRealVar("n_right", "Right Exponent", hCB_pars.GetBinContent(5), 1.1, 10.0)
# RooCrystalBall's double-sided constructor with a common width takes
# (x, x0, sigmaLR, alphaL, nL, alphaR, nR): each side's (alpha, n) pair must be
# passed together. Passing (alphaL, alphaR, nL, nR) would silently feed the
# alpha value into nL and the n value into alphaR.
ds_cb = ROOT.RooCrystalBall("ds_cb", "Double-Sided Crystal Ball", mass, sgn_mean, sgn_width,
                            alpha_left, n_left, alpha_right, n_right)
# Freeze the tails at their pre-computed values.
alpha_left.setConstant(True)
alpha_right.setConstant(True)
n_left.setConstant(True)
n_right.setConstant(True)

# Polynomial background in mass. RooPolynomial fixes the constant term to 1,
# so the shape is 1 + c1*m + c2*m^2 + c3*m^3.
c1_bkg = ROOT.RooRealVar("c1_bkg", "c1_bkg", -1, -3, 3)  # linear term
c2_bkg = ROOT.RooRealVar("c2_bkg", "c2_bkg", 0, -1, 1)   # quadratic term
c3_bkg = ROOT.RooRealVar("c3_bkg", "c3_bkg", 0, -1, 1)   # cubic term
bkg_coefs = ROOT.RooArgList(c1_bkg, c2_bkg, c3_bkg)
background = ROOT.RooPolynomial(
    "background",
    "Polynomial Background",
    mass,
    bkg_coefs,
)

# Share of the template inside the (signal + template) part, hard-coded.
template_fraction = 0.2042

# Additional background component: a fixed template histogram ("hTemplates")
# turned into a binned PDF over mass.
hTemplate = file.Get("hTemplates")
template_data_hist = ROOT.RooDataHist(
    "template_data_hist", "Template from histogram", ROOT.RooArgList(mass), hTemplate
)
template_pdf = ROOT.RooHistPdf(
    "template_pdf", "Template PDF", ROOT.RooArgSet(mass), template_data_hist
)

# Mixing coefficients for the composite model:
#  - bkg_weight floats in [0, 1].
#    NOTE(review): its title says "S/B", but it is used as a fraction.
#  - sgn_wrt_templ is frozen at 1 - template_fraction.
#    NOTE(review): its title says "corr. bkg. fraction", but the value is the
#    complement of the template share.
bkg_weight = ROOT.RooRealVar("bkg_weight", "S/B inclusive", 0.8, 0., 1)
sgn_wrt_templ = ROOT.RooRealVar(
    "sgn_wrt_templ", "corr. bkg. fraction", 1 - template_fraction, 0., 1
)
sgn_wrt_templ.setConstant(True)

# Composite mass model with recursive fractions (last constructor argument):
#   bkg_weight * background
#   + (1 - bkg_weight) * [sgn_wrt_templ * ds_cb + (1 - sgn_wrt_templ) * template_pdf]
mass_components = ROOT.RooArgList(background, ds_cb, template_pdf)
mass_coefs = ROOT.RooArgList(bkg_weight, sgn_wrt_templ)
use_recursive_fractions = True
total_pdf_mass = ROOT.RooAddPdf(
    "total_pdf_mass",
    "Background + (Signal + Template)",
    mass_components,
    mass_coefs,
    use_recursive_fractions,
)

# Fit only inside "fitRange", keep the RooFitResult and print it.
fit_result = total_pdf_mass.fitTo(
    data_hist,
    ROOT.RooFit.Save(),
    ROOT.RooFit.Range("fitRange"),
)
fit_result.printMultiline(ROOT.std.cout, 3, True)

# Plot the data and the fitted model (total + each component) restricted to
# the fit range.
frame = mass.frame()
# RooAbsData::plotOn selects a named sub-range with CutRange(); Range() is a
# PDF-plotting option and is not a data-plotting option.
data_hist.plotOn(frame, ROOT.RooFit.CutRange("fitRange"))
total_pdf_mass.plotOn(frame, ROOT.RooFit.Range("fitRange"))  # total model
total_pdf_mass.plotOn(frame, ROOT.RooFit.Components("ds_cb"), ROOT.RooFit.LineColor(ROOT.kRed),
                      ROOT.RooFit.Range("fitRange"))  # signal (double-sided Crystal Ball)
total_pdf_mass.plotOn(frame, ROOT.RooFit.Components("background"), ROOT.RooFit.LineColor(ROOT.kGreen),
                      ROOT.RooFit.LineStyle(ROOT.kDashed), ROOT.RooFit.Range("fitRange"))  # polynomial background
total_pdf_mass.plotOn(frame, ROOT.RooFit.Components("template_pdf"), ROOT.RooFit.LineColor(ROOT.kBlue),
                      ROOT.RooFit.LineStyle(ROOT.kDotted), ROOT.RooFit.Range("fitRange"))  # template

# Draw the plot and save it as an image.
canvas = ROOT.TCanvas("canvas", "Fit Result", 800, 600)
frame.Draw()
canvas.SaveAs("test_fit_result_mass.png")

# Also store the canvas in a ROOT file.
# NOTE(review): hard-coded, user-specific absolute output path.
outfile = ROOT.TFile("/home/mdicosta/FlowDplus/FinalResults/crystalball/roofit/outfile.root", "recreate")
outfile.cd()  # write into outfile explicitly instead of relying on the current directory
canvas.Write()
outfile.Close()

# Per-component ratio PDFs over mass, each defined as component / total_pdf_mass.
# NOTE(review): the formula "@0 / @1" uses only the component shapes; the fitted
# coefficients (bkg_weight, sgn_wrt_templ) are not multiplied in. As PDFs these
# are presumably renormalised over mass when plotted, so the curves show the
# shape of each ratio rather than the absolute per-component fraction. Confirm
# this is the intent.
vn_pdf_bkg = ROOT.RooGenericPdf(
    "vn_pdf_bkg",
    "Background / (Background + Signal + Template)",
    "@0 / @1",
    ROOT.RooArgList(background, total_pdf_mass),
)
vn_pdf_sgn = ROOT.RooGenericPdf(
    "vn_pdf_sgn",
    "Signal / (Signal + Background + Template)",
    "@0 / @1",
    ROOT.RooArgList(ds_cb, total_pdf_mass),
)
vn_pdf_templ = ROOT.RooGenericPdf(
    "vn_pdf_templ",
    "Template / (Template + Background + Signal)",
    "@0 / @1",
    ROOT.RooArgList(template_pdf, total_pdf_mass),
)

# Plot each ratio PDF on a common frame, naming each curve so the legend can
# look it up.
frame_vn = mass.frame()
for vn_curve_pdf, vn_curve_name, vn_curve_color, vn_curve_style in (
    (vn_pdf_bkg, "vn_pdf_bkg", ROOT.kRed, ROOT.kDashed),
    (vn_pdf_sgn, "vn_pdf_sgn", ROOT.kBlue, ROOT.kSolid),
    (vn_pdf_templ, "vn_pdf_templ", ROOT.kGreen, ROOT.kDotted),
):
    vn_curve_pdf.plotOn(
        frame_vn,
        ROOT.RooFit.LineColor(vn_curve_color),
        ROOT.RooFit.LineStyle(vn_curve_style),
        ROOT.RooFit.Name(vn_curve_name),
    )

canvas_vn = ROOT.TCanvas("canvas_vn", "Vn PDFs", 800, 600)
frame_vn.Draw()

# Legend entries reference the named curves plotted above.
legend = ROOT.TLegend(0.6, 0.7, 0.88, 0.88)
for legend_key in ("vn_pdf_bkg", "vn_pdf_sgn", "vn_pdf_templ"):
    legend.AddEntry(frame_vn.findObject(legend_key), legend_key, "l")
legend.Draw()

canvas_vn.SaveAs("test_vn_pdfs.png")

# Total vn model: weighted sum of the three ratio PDFs.
# The coefficients are the same RooRealVars that were fitted as *recursive*
# fractions in total_pdf_mass, so they are interpreted recursively here too
# (last argument True). Read non-recursively, the implied third coefficient
# 1 - bkg_weight - sgn_wrt_templ goes negative (e.g. 1 - 0.8 - 0.7958 with the
# initial values).
vn_model = ROOT.RooAddPdf("vn_model", "vn_pdf_bkg + vn_pdf_sgn + vn_pdf_templ",
                          ROOT.RooArgList(vn_pdf_bkg, vn_pdf_sgn, vn_pdf_templ),  # PDFs
                          ROOT.RooArgList(bkg_weight, sgn_wrt_templ),
                          True)  # recursive fractions, as in total_pdf_mass
frame_vn_total = mass.frame()
vn_model.plotOn(frame_vn_total, ROOT.RooFit.LineColor(ROOT.kGreen), ROOT.RooFit.LineStyle(ROOT.kDotted),
                ROOT.RooFit.Name("vn_model"))
canvas_vn_total = ROOT.TCanvas("canvas_vn_total", "Vn PDFs", 800, 600)
frame_vn_total.Draw()
canvas_vn_total.SaveAs("test_vn_total_func.png")