import ROOT
import numpy as np

# Batch mode: no graphics windows are opened; canvases are only saved to file.
ROOT.gROOT.SetBatch(True)

# Fit observable and the variable that carries the per-event weights.
# RooRealVar arguments used here: (name, title, initial value, min, max).
x = ROOT.RooRealVar("x", "x", 70.0, 70.0, 120.0)
w = ROOT.RooRealVar("w", "w", 1.0, 0.0, 10.0)

# Peak position and core width of the Crystal Ball model.
# NOTE(review): registered in RooFit as "sigmaL", but `cb` passes it as both
# the left and the right width, so the label shown by paramOn is misleading.
x0 = ROOT.RooRealVar("x0", "x0", 90.0, 70.0, 120.0)
sigma = ROOT.RooRealVar("sigmaL", "sigmaL", 5.0, 1.0, 100.0)

# Left/right tail parameters of the Crystal Ball (alpha and n per side).
alphaL, alphaR = (
    ROOT.RooRealVar("alphaL", "alphaL", 1.05, 0.1, 10.0),
    ROOT.RooRealVar("alphaR", "alphaR", 1.3, 0.1, 10.0),
)
nL, nR = (
    ROOT.RooRealVar("nL", "nL", 5.0, 0.1, 10.0),
    ROOT.RooRealVar("nR", "nR", 2.0, 0.1, 10.0),
)


def reset_parameters():
    """Put the Crystal Ball fit parameters back at fixed starting values.

    Each shape parameter gets a hard-coded starting value and its error is
    cleared to zero. The starting values differ from the values the
    parameters were constructed with, so a fit must actually move them.
    """
    starting_values = (
        (x0, 80.0),
        (sigma, 3.0),
        (alphaL, 1.0),
        (alphaR, 1.0),
        (nL, 3.0),
        (nR, 3.0),
    )
    # Every value is assigned first and every error cleared afterwards.
    for param, start in starting_values:
        param.setVal(start)
    for param, _ in starting_values:
        param.setError(0.0)


# Double-sided Crystal Ball. The same `sigma` is passed as the left and the
# right core width, so the core is symmetric; the tails are independent.
cb = ROOT.RooCrystalBall(
    "cb", "cb", x, x0, sigma, sigma, alphaL, nL, alphaR, nR
)

n_events = 100000

# Toy sample drawn from `cb` with the parameter values set at construction.
data = cb.generate(x, n_events)

# Switch the parameters to the fit starting values before any fit runs.
reset_parameters()

x_arr = data.to_numpy()["x"]

# Seeded generator so the random per-event weights, and therefore the weighted
# fits and their plots, are reproducible from run to run.
rng = np.random.default_rng(12345)
# Size the weight arrays from the generated sample itself so they always match
# the number of x values handed to from_numpy.
weights_1 = rng.uniform(0.008, 2.5, len(x_arr))
weights_2 = rng.uniform(8e-7, 6e-3, len(x_arr))


def _make_weighted_dataset(name, values, weights):
    """Build a weighted RooDataSet over `x` from NumPy arrays.

    Args:
        name: used as both the RooFit name and title of the dataset.
        values: array of observable values for `x`.
        weights: array of per-event weights, same length as ``values``;
            stored in column ``"w"`` and used as the dataset weight.

    Returns:
        The RooDataSet created by ``RooDataSet.from_numpy``.
    """
    return ROOT.RooDataSet.from_numpy(
        {"x": values, "w": weights},
        [x, w],
        name=name,
        title=name,
        weight_name="w",
    )


data_1 = _make_weighted_dataset("data1", x_arr, weights_1)
data_2 = _make_weighted_dataset("data2", x_arr, weights_2)


def do_fit_and_plot(data, sum_w2_error, asymptotic_error, plot_filename):
    """Fit ``cb`` to ``data``, plot the result and save it as a PNG file.

    The model is fitted with RooFit's ``fitTo``. The data, the fitted model
    and a box with the fitted parameter values are drawn on a frame of ``x``
    and written to ``<plot_filename>.png``. Afterwards the fit parameters are
    restored with ``reset_parameters()``, also when fitting or plotting
    raises, so every call starts from the same initial values.

    Args:
        data: RooFit dataset to fit (weighted or unweighted).
        sum_w2_error: forwarded to ``fitTo`` as ``SumW2Error``.
        asymptotic_error: forwarded to ``fitTo`` as ``AsymptoticError``.
        plot_filename: frame title and, with ``.png`` appended, the output
            file name.

    Raises:
        ValueError: if both ``sum_w2_error`` and ``asymptotic_error`` are
            requested; they are alternative error corrections.
    """
    # Reject the contradictory combination before touching any fit state.
    if sum_w2_error and asymptotic_error:
        raise ValueError(
            "sum_w2_error and asymptotic_error are mutually exclusive"
        )

    try:
        cb.fitTo(
            data,
            SumW2Error=sum_w2_error,
            AsymptoticError=asymptotic_error,
            PrintLevel=-1,
        )

        c1 = ROOT.TCanvas()

        frame = x.frame(Title=plot_filename)
        data.plotOn(frame)
        cb.plotOn(frame)
        cb.paramOn(frame)

        frame.Draw()

        c1.Draw()

        c1.SaveAs(plot_filename + ".png")
    finally:
        # Always restore the starting values so a failed fit or plot cannot
        # leak its parameter state into the next call.
        reset_parameters()


# Fit every dataset three times: without error correction, with SumW2Error and
# with AsymptoticError. Each plot is saved as "<prefix><suffix>.png".
for _dataset, _prefix in (
    (data, "data"),
    (data_1, "data_1"),
    (data_2, "data_2"),
):
    for _use_sum_w2, _use_asymptotic, _suffix in (
        (False, False, ""),
        (True, False, "_sumw2"),
        (False, True, "_asym"),
    ):
        do_fit_and_plot(_dataset, _use_sum_w2, _use_asymptotic, _prefix + _suffix)
