import ROOT
import numpy as np

# Suppress RooFit messages below WARNING severity.
ROOT.RooMsgService.instance().setGlobalKillBelow(ROOT.RooFit.WARNING)

# Canvas that every plot in this script is drawn on.
canvas = ROOT.TCanvas("canvas", "2D Gaussian Decay", 800, 600)

# Axis limits of the 2D histogram: x is energy, y is time.
energy_min, energy_max = 0, 1200
time_min, time_max = 0, 10

# Bin counts along the energy and time axes.
n_bins_energy, n_bins_time = 12000, 10

# Energy-vs-time histogram that receives the synthetic model values.
hist2d = ROOT.TH2F(
    "hist2d",
    "2D Histogram of Gaussians with Exponential Decay",
    n_bins_energy, energy_min, energy_max,
    n_bins_time, time_min, time_max,
)

# One entry per Gaussian peak on the energy axis.
gaussian_params = [
    dict(amplitude=200, mean=100, sigma=10),    # first peak
    dict(amplitude=150, mean=500, sigma=12),    # second peak
    dict(amplitude=100, mean=1150, sigma=8),    # third peak
]

# Exponential decay constant applied along the time axis.
decay_constant = 1.0

# Coefficients of the linear background a * energy + b * time + c.
# (a = -0.02 is kept here only as a commented-out alternative slope.)
a = -0.002  # slope in energy (x)
b = 0       # slope in time (y)
c = 5       # constant offset

# Fill every bin with the model evaluated at the bin centre: the sum of the
# Gaussian peaks, damped exponentially in time, plus the linear background.
# Evaluating at the lower bin edge instead would store the model shifted by
# half a bin (in both energy and time) relative to the histogram axes.
energy_step = float(energy_max - energy_min) / n_bins_energy
time_step = float(time_max - time_min) / n_bins_time

# The decay factor depends only on the time bin, so compute it once per bin
# rather than once per (energy bin, time bin, peak) combination.
time_centers = [time_min + (j + 0.5) * time_step for j in range(n_bins_time)]
decay_factors = [np.exp(-t_center / decay_constant) for t_center in time_centers]

for i in range(n_bins_energy):
    e_center = energy_min + (i + 0.5) * energy_step

    # The summed peak profile depends only on energy, so it is shared by all
    # time bins of this energy column.
    peaks = sum(
        params["amplitude"]
        * np.exp(-0.5 * ((e_center - params["mean"]) ** 2) / (params["sigma"] ** 2))
        for params in gaussian_params
    )

    for j, (t_center, decay) in enumerate(zip(time_centers, decay_factors)):
        background = a * e_center + b * t_center + c
        # Histogram bin indices are offset by one from the loop indices.
        hist2d.SetBinContent(i + 1, j + 1, peaks * decay + background)

# Render the filled histogram with the "COLZ" draw option and refresh the canvas.
hist2d.Draw("COLZ")
canvas.Update()

# Keep the canvas on screen until the user presses Enter.  Python 3 reports a
# closed or exhausted stdin as EOFError (SyntaxError can only come from
# Python 2's evaluating input()), so treat either one as "continue".
try:
    input()
except (EOFError, SyntaxError):
    pass

# RooFit observables spanning the full histogram axes.
energy_full = ROOT.RooRealVar("energy", "Energy (keV)", energy_min, energy_max)
time_full = ROOT.RooRealVar("time", "time", time_min, time_max)

# Import the 2D histogram into a binned RooFit dataset over (energy, time).
roo_datahist = ROOT.RooDataHist(
    "roo_datahist",
    "RooDataHist from TH1",
    ROOT.RooArgList(energy_full, time_full),
    ROOT.RooFit.Import(hist2d),
)

# Plot the dataset against each observable: energy in pad 1, time in pad 2.
canvas.Clear()
canvas.Divide(2)

canvas.cd(1)
energy_frame = energy_full.frame()
roo_datahist.plotOn(energy_frame)
energy_frame.Draw()

canvas.cd(2)
time_frame = time_full.frame()
roo_datahist.plotOn(time_frame)
time_frame.Draw()

canvas.Modified()
canvas.Update()

# Pause until Enter is pressed.  Python 3 raises EOFError (not SyntaxError)
# when stdin is closed or exhausted, so accept both and carry on.
try:
    input("Press Enter to continue")
except (EOFError, SyntaxError):
    pass

# Named energy windows, one around each simulated peak.  Only "range_1" is
# used below; "range_2" and "range_3" are defined but not yet fitted here.
energy_full.setRange("range_1", 0, 200)
energy_full.setRange("range_2", 400, 600)
energy_full.setRange("range_3", 1000, 1200)

# Keep only the bins that fall inside the first energy window.
subhist_1 = roo_datahist.reduce(ROOT.RooFit.CutRange("range_1"))

# Fit observables: energy restricted to the first window, time over its full
# range ("range_1" is only defined on the energy variable, so the time limits
# are read without a range name).
energy = ROOT.RooRealVar("energy", "Energy (keV)", energy_full.getMin("range_1"), energy_full.getMax("range_1"))
time = ROOT.RooRealVar("time", "time", time_full.getMin(), time_full.getMax())

# Signal shape in energy: a single Gaussian peak.
mean1 = ROOT.RooRealVar("mean1", "mean1", 100, 0, 200)
sigma1 = ROOT.RooRealVar("sigma1", "sigma1", 10, 0.1, 20)
gauss1_pdf = ROOT.RooGaussian("gauss1", "gauss1", energy, mean1, sigma1)

# Yield limits must be able to hold the data actually selected, so derive
# them from the summed bin contents instead of a fixed cap.  The cap is never
# lowered below the previous hard-coded value of 12000.
total_weight = subhist_1.sumEntries()
yield_max = max(2.0 * total_weight, 12000)
yield_start = 0.5 * total_weight if total_weight > 0 else 100
gauss1_yield = ROOT.RooRealVar("gauss1_yield", "gauss1_yield", yield_start, 0, yield_max)
bg_yield = ROOT.RooRealVar("bg_yield", "bg_yield", yield_start, 0, yield_max)

# Linear background c + a * energy (polynomial starting at order 0).  A PDF is
# normalised, so scaling c and a together leaves the shape unchanged; c is
# held constant so that only the relative slope is fitted.
a = ROOT.RooRealVar("a", "a", -0.002, -1, 0)
c = ROOT.RooRealVar("c", "c", 5, 0, 100)
c.setConstant(True)
bg = ROOT.RooPolynomial("bg", "bg", energy, ROOT.RooArgList(c, a), 0)

# Exponential in time; the fitted coefficient is constrained to be negative.
decay_constant = ROOT.RooRealVar("decay_constant", "decay_constant", -1.0, -2.0, 0.0)
exp_decay = ROOT.RooExponential("exp_decay", "exp_decay", time, decay_constant)

# Signal = Gaussian(energy) x exponential(time); total = signal + background,
# each component scaled by its own yield.
pdf_sig = ROOT.RooProdPdf("pdf_sig", "pdf_sig", ROOT.RooArgList(gauss1_pdf, exp_decay))
pdf_1 = ROOT.RooAddPdf("pdf_energy_1", "pdf_energy_1", ROOT.RooArgList(pdf_sig, bg), ROOT.RooArgList(gauss1_yield, bg_yield))

# Fit the window and keep the result object (command-argument form, matching
# the RooFit.Import / RooFit.CutRange calls used elsewhere in this script).
fit_result = pdf_1.fitTo(subhist_1, ROOT.RooFit.Save())

# Two pads: data and fitted model along energy (pad 1) and time (pad 2).
canvas.Clear()
canvas.Divide(2)
energy_frame = energy.frame()
time_frame = time.frame()

# On each frame draw the selected data first, then overlay the fitted model.
for frame in (energy_frame, time_frame):
    subhist_1.plotOn(frame)
    pdf_1.plotOn(frame)

canvas.cd(1)
energy_frame.Draw()
canvas.cd(2)
time_frame.Draw()

canvas.Modified()
canvas.Update()

# Keep the final plots on screen until Enter is pressed.  Python 3 raises
# EOFError (not SyntaxError) when stdin is closed or exhausted, so accept
# both and exit normally.
try:
    input()
except (EOFError, SyntaxError):
    pass
