
import os
import sys
import ROOT as R
import pandas as pd
import numpy as np
import json
from tabulate import tabulate
import itertools
from pathlib import Path

# local imports.
# %load_ext autoreload
# %autoreload 2
sys.path.append("/home/belle2/vikasraj/b2kpi/KSPI0_TDCPV/belle/utils")

import fit_utils_2 as fu
from useful_funcs import read_yaml

# Silence RooFit's verbose output (helper from fit_utils_2).
fu.silence_root()

# Enable ROOT implicit multithreading with 10 threads; the RDataFrame event
# loops further down run multi-threaded because of this.
R.ROOT.EnableImplicitMT(10)

# Workspace "se_workspace" holding the pdfc_<i> pdfs imported below.  If the
# file does not exist an empty workspace is used, so the imports below only
# print "not found" warnings.
ws_path = "/home/belle2/vikasraj/b2kpi/KSPI0_TDCPV/belle/new_fits_3/results/workspaces/se_belle_mu.root"
if Path(ws_path).is_file():
    ws = R.TFile(ws_path).Get('se_workspace')
    if not ws:
        raise RuntimeError(f"'se_workspace' not found in {ws_path}")
    # ws.Print('t')
else:
    ws = R.RooWorkspace('se_workspace', 'se_workspace')

# Resolution workspace "dt_workspace"; the whole fit model is built inside it,
# so fail immediately (instead of with a None workspace much later) if it is
# unavailable.
reso_ws_path = "/home/belle2/vikasraj/b2kpi/KSPI0_TDCPV/belle/new_fits_3/results/workspaces/dt_belle.root"
if not Path(reso_ws_path).is_file():
    raise FileNotFoundError(f"Resolution workspace file not found: {reso_ws_path}")
reso_ws = R.TFile(reso_ws_path).Get('dt_workspace')
if not reso_ws:
    raise RuntimeError(f"'dt_workspace' not found in {reso_ws_path}")
# reso_ws.Print()


# Copy pdfc_0 ... pdfc_6 from `ws` into the resolution workspace.  With
# RecycleConflictNodes, sub-nodes whose names already exist in reso_ws are
# connected to the existing objects instead of being re-imported.
for i in range(7):
    name = f"pdfc_{i}"
    pdf = ws.pdf(name)
    if not pdf:
        print(f"{name} not found in ws.")
        continue

    reso_ws.Import(pdf, R.RooFit.RecycleConflictNodes(True))

    # Report what actually ended up in reso_ws under that name.
    obj = reso_ws.obj(name)
    message = (
        f"Successfully imported: {obj.ClassName()}::{obj.GetName()}"
        if obj
        else f"Failed to import: {name}"
    )
    print(message)


# print("\n### Verifying Saved Parameters in Workspace (resolution) ###")
# for var in  reso_ws.allVars():
#     ws_var = reso_ws.var(var.GetName())  # Get the variable from the workspace
#     print(f"{ws_var.GetName()} = {ws_var.getVal()} +/- {ws_var.getError()}")


# Parameters of reso_ws that are held fixed in this fit.  Entries commented
# out are deliberately left floating.
param_list = [
    "Bbbar_tau",
    "fTR",
    "muM1",
    "sigmaM1",
    "fracExp",
    "sigmaT1",
    "c1",
    "muT1",
    "frac_tail",
    # "sig_mbc_cb_mean",
    # "sig_de_js_mu",
    # "cont_gauss_frac",
    # "cont_gauss_mu",
    # "cont_gauss_sigma"
    "muM1_cont",
    "sigmaM1_cont",
    "sigmaM1_bbbar",
]

for param in param_list:
    var = reso_ws.var(param)
    if not var:
        # Missing parameters are reported but do not stop the script.
        print(f"Warning: Variable '{param}' not found in workspace")
        continue
    var.setConstant(True)

# Parameters (presumably continuum "qq" shape parameters) that float in the fit.
# The pdfs were already imported into reso_ws above, and RooWorkspace import
# clones the objects, so the flag must be changed on reso_ws: changing it on
# `ws` would not affect the fit model at all.
params_json_float = ["qq_dePoly_a0", "qq_mbcArg_c"]

for param in params_json_float:
    var = reso_ws.var(param)
    if var:
        var.setConstant(False)
    else:
        print(f"Warning: {param} not found in the workspace")


# Variables needed for the SC fits: r (tag quality, |qr|), the tag flavour
# category, and CCP / SCP, which scale the cos / sin coefficients built below.
for declaration in ("r[0, 1]", "tagflav[B0=1,B0bar=-1]", "CCP[-1, 1]", "SCP[-1, 1]"):
    reso_ws.factory(declaration)

# Edges of the 7 r bins; bin i spans [bin_edges[i], bin_edges[i+1]].
bin_edges = [0.0, 0.1, 0.25, 0.45, 0.6, 0.725, 0.875, 1.0]
bin_labels = list(range(len(bin_edges) - 1))

# One variable r<i> per bin whose range is that bin's edges.
for i, (lo, hi) in enumerate(zip(bin_edges[:-1], bin_edges[1:])):
    reso_ws.factory(f"r{i}[{lo}, {hi}]")

filepath = "/home/belle2/vikasraj/b2kpi/b2bii-tdcpv/ft/gnn_mc.json"

with open(filepath) as fp:
    init_dict = json.load(fp)

# Flavour-tagging calibration per r bin.  Each JSON entry "<name>_qr_<bin>"
# is a (value, uncertainty) pair; values and uncertainties are collected into
# parallel lists indexed by bin.
w, w_err = [], []
dw, dw_err = [], []
mu, mu_err = [], []
eff, eff_err = [], []

# (parameter name, value list, error list, lower limit, upper limit)
ft_columns = (
    ("w", w, w_err, 0, 1),
    ("dw", dw, dw_err, -1, 1),
    ("mu", mu, mu_err, -1, 1),
    ("eff", eff, eff_err, 0, 1),
)

for i in range(7):
    for name, values, errors, _lo, _hi in ft_columns:
        val, unc = init_dict[f"{name}_qr_{i}"]
        values.append(val)
        errors.append(unc)

# Workspace variables <name><i> initialised to the calibration value, with
# the calibration uncertainty attached as their error.
for i in range(7):
    for name, values, _errors, lo, hi in ft_columns:
        reso_ws.factory(f"{name}{i}[{values[i]:.6f}, {lo}, {hi}]")
    for name, _values, errors, _lo, _hi in ft_columns:
        reso_ws.var(f"{name}{i}").setError(errors[i])
    
# Table of the calibration constants (value ± uncertainty) for each r bin.
print(f"{'Bin':<4} {'w ± err':<20} {'dw ± err':<20} {'mu ± err':<20} {'eff ± err':<20}")
print("-" * 84)
for i in range(7):
    pairs = ((w[i], w_err[i]), (dw[i], dw_err[i]), (mu[i], mu_err[i]), (eff[i], eff_err[i]))
    cells = "   ".join(f"{val:.4f} ± {err:.4f}" for val, err in pairs)
    print(f"{i:<4} {cells}")
# Flavour-tag dependent RooBDecay coefficients for each r bin i, for B0 tags
# (<term>cb<i>) and B0bar tags (<term>cbbar<i>):
#   coshcb : cosh-term coefficient (from dw, mu, w)
#   sincb  : sin-term coefficient, scaled by SCP
#   coscb  : cos-term coefficient, scaled by CCP
# NOTE(review): mu enters the sin coefficients as "+ mu*(1 -/+ dw)" but the cos
# coefficients as "- mu*(1 -/+ dw)", and the time-integrated signal weight
# built later in this script uses "normt1 - normt2*CCP" (opposite overall
# sign for CCP).  Confirm the intended CCP sign convention.
for i in range(7):
    coefficient_specs = (
        (f"coshcb{i}", f"1 - dw{i} + mu{i}*(1 - 2*w{i})", f"dw{i}, mu{i}, w{i}"),
        (f"coshcbbar{i}", f"1 + dw{i} - mu{i}*(1 - 2*w{i})", f"dw{i}, mu{i}, w{i}"),
        (f"sincb{i}", f"(1 - 2*w{i} + mu{i}*(1 - dw{i}))*SCP", f"dw{i}, mu{i}, w{i}, SCP"),
        (f"sincbbar{i}", f"(-1 + 2*w{i} + mu{i}*(1 + dw{i}))*SCP", f"dw{i}, mu{i}, w{i}, SCP"),
        (f"coscb{i}", f"(1 - 2*w{i} - mu{i}*(1 - dw{i}))*CCP", f"dw{i}, mu{i}, w{i}, CCP"),
        (f"coscbbar{i}", f"(-1 + 2*w{i} - mu{i}*(1 + dw{i}))*CCP", f"dw{i}, mu{i}, w{i}, CCP"),
    )
    for name, formula, servers in coefficient_specs:
        reso_ws.factory(f"expr::{name}('{formula}', {servers})")
        print(reso_ws.obj(name))

# Time-dependent signal pdfs for B0-tagged (btag) and B0bar-tagged (bartag)
# events in each r bin: RooBDecay(deltat, tau, dgamma, f_cosh, f_sinh,
# f_cos, f_sin, dm, resolution, DoubleSided) with the coefficients built above.
#
# The constants are declared once and then referenced by name.  RooFit's
# factory refuses to re-declare an existing variable with "name[value]"
# syntax, and an error inside an expression aborts the whole expression.
# Existing definitions in reso_ws are kept as they are.
for name, value in (("B_tau", "1.534"), ("zero", "0"), ("bmix_dm", "0.507")):
    if not reso_ws.var(name):
        reso_ws.factory(f"{name}[{value}]")

# All r bins currently use the same resolution model.
res_model = "res_core"

for tag in ("btag", "bartag"):
    # btag uses coshcb<i>/coscb<i>/sincb<i>; bartag the *cbbar<i> variants.
    suffix = "" if tag == "btag" else "bar"
    for i in range(7):
        reso_ws.factory(f"""
            RooBDecay::sc_pdf_{tag}{i}(
                deltat,
                B_tau,
                zero,
                coshcb{suffix}{i},
                zero,
                coscb{suffix}{i},
                sincb{suffix}{i},
                bmix_dm,
                {res_model},
                DoubleSided
            )
        """)
# Paths (glob patterns as plain strings)
uds_pattern = "/group/belle2/users2022/vikasraj/rootfiles/belle/kspi0_final_merged/uds_*.root"
charm_pattern = "/group/belle2/users2022/vikasraj/rootfiles/belle/kspi0_final_merged/charm_*.root"

# Preselection windows in mod_mbc, deltaE and CSO_BDT_prime shared by all samples.
_kin_cut = (
    "5.24 < mod_mbc and  mod_mbc < 5.29 and "
    "-0.3 < deltaE and  deltaE < 0.3 and "
    "-6 < CSO_BDT_prime and  CSO_BDT_prime < 6"
)

# Load into RDataFrame
treename = "tree"

uds_rdf = R.RDataFrame(treename, uds_pattern)
charm_rdf = R.RDataFrame(treename, charm_pattern)

# Convert to Pandas
uds_df = pd.DataFrame(uds_rdf.AsNumpy())
charm_df = pd.DataFrame(charm_rdf.AsNumpy())

# Continuum (uds + charm), excluding truth-matched signal; 1/6 subsample.
qqbar_df = pd.concat([uds_df, charm_df]).query(
    f"isSignal_custom !=1 and {_kin_cut}"
).sample(frac=1/6, random_state=1532)

print(qqbar_df.shape[0])

# BBbar MC: read and convert once, then split into generic BBbar background,
# self-cross-feed (SCF) and signal, each as a 1/10 subsample.
input_base_path = "/group/belle2/users2022/vikasraj/rootfiles/belle/kspi0_final_merged/bbbar_scaled/*ed_*.root"

rdf = R.RDataFrame(treename, input_base_path)
_bbbar_all_df = pd.DataFrame(rdf.AsNumpy())

bbbar_df = _bbbar_all_df.query(
    f"isSignal_loose_ks != 1 and {_kin_cut}"
).sample(frac=1/10, random_state=1532)
print(bbbar_df.shape[0])

scf_df = _bbbar_all_df.query(
    f"isSignal_loose_ks == 1 and isSignal_custom !=1 and {_kin_cut}"
).sample(frac=1/10, random_state=1532)
print(scf_df.shape[0])

sig_df = _bbbar_all_df.query(
    f"isSignal_custom ==1 and {_kin_cut}"
).sample(frac=1/10, random_state=1532)
print(sig_df.shape[0])

# The selections above are copies; release the full BBbar frame.
del _bbbar_all_df

sig_tot_df = pd.concat([scf_df, sig_df])
sig_tot_num = sig_tot_df.shape[0]


# Full mixed sample: continuum + BBbar background + (signal + SCF).
df_tot = pd.concat([qqbar_df, bbbar_df, sig_tot_df])

# Alternative compositions for component-only checks:
# df_tot =  pd.concat([bbbar_df, sig_tot_df])
# df_tot =  pd.concat([qqbar_df, sig_tot_df])
# df_tot =  pd.concat([qqbar_df, bbbar_df])
# df_tot =  sig_tot_df
# df_tot =  qqbar_df
# df_tot =  bbbar_df

# Events passing td_cut (nSVDHits > 0 for both Ks_pip and Ks_pim,
# -15 < DeltaT < 15, DeltaTErr < 3) go to the time-dependent (TD) fit;
# all other events go to the time-integrated (TI) fit.
td_cut = "(Ks_pip_nSVDHits > 0 and Ks_pim_nSVDHits > 0 and DeltaT > -15 and DeltaT < 15 and DeltaTErr < 3)"
df_td, df_ti = df_tot.query(td_cut), df_tot.query(f"not {td_cut}")

for frame in (df_td, df_ti, df_tot):
    print(frame.shape[0])


# define the columns to keep, mapped to their fit-variable names
columns = ["DeltaT", "DeltaTErr", "qrGNN", "mod_mbc", "deltaE", "CSO_BDT_prime","CSO_BDT_mu"]
renamed_cols = {key: fu.renamed_cols[key] for key in columns}


def _to_fit_frame(frame):
    """Return *frame* reduced to the fit columns (renamed), plus r and tagflav.

    ``r = |qr|`` and ``tagflav = qr / r`` (+1 or -1; NaN where qr == 0).
    """
    fit_frame = frame[list(renamed_cols)].rename(columns=renamed_cols)
    fit_frame["r"] = np.abs(fit_frame["qr"])
    fit_frame["tagflav"] = fit_frame["qr"] / fit_frame["r"]
    return fit_frame


# ## Prepare the TD and TI Data
data_td = _to_fit_frame(df_td)
data_ti = _to_fit_frame(df_ti)

cont_num, bbbar_num = len(qqbar_df), len(bbbar_df)

print(sig_tot_num, cont_num, bbbar_num, len(df_tot))

# get distributions for r: one 7-bin histogram pdf of |qrGNN| per component
# (the prints below show them under the names histpdf_<name>).
fu.load_hist_pdf(np.abs(sig_tot_df["qrGNN"]), "r", reso_ws, 7, name="sig_r")
fu.load_hist_pdf(np.abs(bbbar_df["qrGNN"]), "r", reso_ws, 7, name="bbbar_r")
fu.load_hist_pdf(np.abs(qqbar_df["qrGNN"]), "r", reso_ws, 7, name="cont_r")
print(reso_ws.obj('histpdf_sig_r'))
print(reso_ws.obj('histpdf_bbbar_r'))
print(reso_ws.obj('histpdf_cont_r'))

# Yield parameters.  The range is widened *before* the start value is set:
# RooRealVar.setVal clips to the current range, so a start value outside a
# narrower stored range would otherwise be truncated (with only a warning,
# which is silenced here).
reso_ws.var("n_sig").setRange(0, 7000)
reso_ws.var("n_cont").setRange(0, 20000)
reso_ws.var("n_bbbar").setRange(0, 9000)

reso_ws.var("n_sig").setVal(500)
reso_ws.var("n_cont").setVal(500)
reso_ws.var("n_bbbar").setVal(262)

# Show the workspace objects bbbar_dt and res_core_cont used in the model below
# (prints None if either is missing).
print(reso_ws.obj("bbbar_dt"))
print(reso_ws.obj("res_core_cont"))

def get_bin_fractions(data, varname, bins):
    """Return the fraction of entries of ``|data[varname]|`` in each bin.

    The denominator is the total number of entries, including any that fall
    outside ``bins``, so the fractions need not sum to one.  An empty input
    gives an all-zero array with one entry per bin.
    """
    magnitudes = np.abs(data[varname])
    n_entries = len(magnitudes)
    per_bin = np.histogram(magnitudes, bins=bins)[0]
    if n_entries == 0:
        return np.zeros(len(per_bin))
    return per_bin / n_entries

def _td_fraction_per_bin(frame, varname, bins, cut):
    """Return, for each bin of ``|frame[varname]|``, the fraction of its entries passing *cut*.

    Bins with no entries get a fraction of 0.
    """
    all_counts, _ = np.histogram(np.abs(frame[varname]), bins=bins)
    td_counts, _ = np.histogram(np.abs(frame.query(cut)[varname]), bins=bins)
    fractions = np.zeros(len(all_counts))
    np.divide(td_counts, all_counts, out=fractions, where=all_counts > 0)
    return fractions


# fractd_<comp>_bin<i> is used in the extended terms as
#   n_<comp> * w_<comp><i> * fractd_<comp>_bin<i>        (TD)
#   n_<comp> * w_<comp><i> * (1 - fractd_<comp>_bin<i>)  (TI)
# where w_<comp><i> is the fraction of all events in bin i.  For these to be
# the TD/TI yields of bin i, fractd must be P(TD | bin i) = N_td(i) / N(i),
# not the share of all TD events that fall in bin i.
sig_fractions = _td_fraction_per_bin(sig_tot_df, "qrGNN", bin_edges, td_cut)
bbbar_fractions = _td_fraction_per_bin(bbbar_df, "qrGNN", bin_edges, td_cut)
cont_fractions = _td_fraction_per_bin(qqbar_df, "qrGNN", bin_edges, td_cut)

for i, frac in enumerate(sig_fractions):
    reso_ws.factory(f"fractd_sig_bin{i}[{frac}]")
for i, frac in enumerate(bbbar_fractions):
    reso_ws.factory(f"fractd_bbbar_bin{i}[{frac}]")
for i, frac in enumerate(cont_fractions):
    reso_ws.factory(f"fractd_cont_bin{i}[{frac}]")

print(f"{'Bin':<6} {'Signal':<10} {'BBbar':<10} {'Continuum':<10}")
for i in range(len(sig_fractions)):
    print(f"{i:<6} {sig_fractions[i]:<10.3f} {bbbar_fractions[i]:<10.3f} {cont_fractions[i]:<10.3f}")


## Weights in each r bin: fraction of all events (TD and TI) of each component
sig_fractions_no_cut = get_bin_fractions(sig_tot_df, "qrGNN", bin_edges)
bbbar_fractions_no_cut = get_bin_fractions(bbbar_df, "qrGNN", bin_edges)
cont_fractions_no_cut = get_bin_fractions(qqbar_df, "qrGNN", bin_edges)

# Store them as constants w_<component><bin> in the workspace.
weight_sets = (
    ("sig", sig_fractions_no_cut),
    ("bbbar", bbbar_fractions_no_cut),
    ("cont", cont_fractions_no_cut),
)
for component, fractions in weight_sets:
    for i, frac in enumerate(fractions):
        reso_ws.factory(f"w_{component}{i}[{frac}]")

print(f"{'Bin':<6} {'Signal':<10} {'BBbar':<10} {'Continuum':<10}")
for i, (s_frac, b_frac, c_frac) in enumerate(
    zip(sig_fractions_no_cut, bbbar_fractions_no_cut, cont_fractions_no_cut)
):
    print(f"{i:<6} {s_frac:<10.3f} {b_frac:<10.3f} {c_frac:<10.3f}")


# define the pdfs for the signal, bbbar and continuum, per tag and r bin.
#
# For each tag (btag: value = +1, bartag: value = -1) and r bin i:
#   normt1 / normt2  : flavour-tag coefficients; normt2 is divided by
#                      1 + (bmix_dm * B_tau)^2, the time-integration factor
#                      of a cos(dm * dt) term.
#   <comp>_q         : tag weight (signal: normt1 - normt2*CCP, bbbar: normt1,
#                      continuum: 1).
#   <comp>_pdf_qr    : weight (conditional on r) x histogram pdf of r.
#   <comp>norm       : integral of <comp>_pdf_qr over r, imported into reso_ws.
#   <comp>_pdf_dt_r  : dt pdf x r pdf.
#   <comp>_pdf_td/ti : the dt/r (TD) or tag/r (TI) part x the component pdf
#                      sig_pdf_tot / bbbar_pdf / cont_pdf.
# NOTE(review): the cos term here enters as "- normt2*CCP", while the TD
# coefficients coscb<i> use "+(1 - 2*w ...)*CCP"; confirm the CCP sign convention.
rvar = reso_ws.var("r")

for tag, value in {'btag': +1, 'bartag': -1}.items():
    for i in range(7):
        reso_ws.factory(f"expr::normt1_{tag}_{i}('1 - {value} * dw{i} + {value}  * mu{i} * (1 - 2 * w{i})', dw{i}, w{i}, mu{i})")
        reso_ws.factory(f"expr::normt2_{tag}_{i}('({value} * (1 - 2 * w{i}) + mu{i} * (1 - {value} * dw{i})) / (1 + (bmix_dm * bmix_dm) * (B_tau * B_tau))', dw{i}, w{i}, mu{i}, bmix_dm, B_tau)")
        print(reso_ws.obj(f"normt1_{tag}_{i}"))
        print(reso_ws.obj(f"normt2_{tag}_{i}"))

        reso_ws.factory(f"EXPR::sig_q_{tag}_{i}('(normt1_{tag}_{i} - normt2_{tag}_{i}*CCP)', normt1_{tag}_{i}, normt2_{tag}_{i}, CCP)")
        reso_ws.factory(f"EXPR::bbbar_q_{tag}_{i}('normt1_{tag}_{i}', normt1_{tag}_{i})")
        reso_ws.factory(f"EXPR::cont_q_{tag}_{i}('1',)")
        print(reso_ws.obj(f"sig_q_{tag}_{i}"))
        print(reso_ws.obj(f"bbbar_q_{tag}_{i}"))
        print(reso_ws.obj(f"cont_q_{tag}_{i}"))

        reso_ws.factory(f"PROD::sig_pdf_qr_{tag}_{i}( sig_q_{tag}_{i}|r, histpdf_sig_r )")
        reso_ws.factory(f"PROD::bbbar_pdf_qr_{tag}_{i}(bbbar_q_{tag}_{i}|r, histpdf_bbbar_r )")
        reso_ws.factory(f"PROD::cont_pdf_qr_{tag}_{i}(cont_q_{tag}_{i} | r, histpdf_cont_r)")
        print(reso_ws.obj(f"sig_pdf_qr_{tag}_{i}"))
        print(reso_ws.obj(f"bbbar_pdf_qr_{tag}_{i}"))
        print(reso_ws.obj(f"cont_pdf_qr_{tag}_{i}"))

        # Integrals over r of the tag/r pdfs; they feed the btag/bartag
        # shares computed in the next section.
        sigint   = reso_ws.pdf(f"sig_pdf_qr_{tag}_{i}").createIntegral(R.RooArgSet(rvar))
        bbbarint = reso_ws.pdf(f"bbbar_pdf_qr_{tag}_{i}").createIntegral(R.RooArgSet(rvar))
        contint  = reso_ws.pdf(f"cont_pdf_qr_{tag}_{i}").createIntegral(R.RooArgSet(rvar))

        print(f"i={i}, tag={tag} -> sigint={sigint.getVal():.6f}, "
              f"bbbarint={bbbarint.getVal():.6f}, contint={contint.getVal():.6f}")

        reso_ws.Import(sigint,   R.RooFit.RenameVariable(sigint.GetName(),   f"signorm_{tag}_{i}"))
        reso_ws.Import(bbbarint, R.RooFit.RenameVariable(bbbarint.GetName(), f"bbbarnorm_{tag}_{i}"))
        reso_ws.Import(contint,  R.RooFit.RenameVariable(contint.GetName(),  f"contnorm_{tag}_{i}"))

        reso_ws.factory(f"PROD::sig_pdf_dt_r_{tag}_{i}(sc_pdf_{tag}{i}|r, histpdf_sig_r )")
        reso_ws.factory(f"PROD::bbbar_pdf_dt_r_{tag}_{i}(bbbar_dt|r, histpdf_bbbar_r )")
        reso_ws.factory(f"PROD::cont_pdf_dt_r_{tag}_{i}(res_core_cont , histpdf_cont_r)") # No r initially
        # Print the objects just created (named *_pdf_dt_r_*).
        print(reso_ws.obj(f"sig_pdf_dt_r_{tag}_{i}"))
        print(reso_ws.obj(f"bbbar_pdf_dt_r_{tag}_{i}"))
        print(reso_ws.obj(f"cont_pdf_dt_r_{tag}_{i}"))

        reso_ws.factory(f"PROD::sig_pdf_td_{tag}_{i}( sig_pdf_dt_r_{tag}_{i}, sig_pdf_tot )")
        reso_ws.factory(f"PROD::bbbar_pdf_td_{tag}_{i}( bbbar_pdf_dt_r_{tag}_{i}, bbbar_pdf )")
        reso_ws.factory(f"PROD::cont_pdf_td_{tag}_{i}( cont_pdf_dt_r_{tag}_{i}, cont_pdf)")
        print(reso_ws.obj(f"sig_pdf_td_{tag}_{i}"))
        print(reso_ws.obj(f"bbbar_pdf_td_{tag}_{i}"))
        print(reso_ws.obj(f"cont_pdf_td_{tag}_{i}"))

        reso_ws.factory(f"PROD::sig_pdf_ti_{tag}_{i}( sig_pdf_qr_{tag}_{i}, sig_pdf_tot )")
        reso_ws.factory(f"PROD::bbbar_pdf_ti_{tag}_{i}( bbbar_pdf_qr_{tag}_{i}, bbbar_pdf )")
        reso_ws.factory(f"PROD::cont_pdf_ti_{tag}_{i}( cont_pdf_qr_{tag}_{i}, cont_pdf)")
        print(reso_ws.obj(f"sig_pdf_ti_{tag}_{i}"))
        print(reso_ws.obj(f"bbbar_pdf_ti_{tag}_{i}"))
        print(reso_ws.obj(f"cont_pdf_ti_{tag}_{i}"))

 

# Build the extended model for every (tag, r bin) category:
#  1. intFrac_<comp>_<tag>_<i>: this tag's share of the component's r-integral,
#     <comp>norm_<tag>_<i> / (<comp>norm_btag_<i> + <comp>norm_bartag_<i>).
#  2. The shares are evaluated now, rounded to 2 decimals and stored as the
#     constants intFrac_<comp>_<tag>_<i>_rounded.
#  3. extterm_<comp>_td_<tag>_<i> = n_<comp> * w_<comp><i> * fractd_<comp>_bin<i> * share
#     extterm_<comp>_ti_<tag>_<i> = n_<comp> * w_<comp><i> * (1 - fractd_<comp>_bin<i>) * share
#  4. pdf_full_<td|ti>_<tag>_<i>: SUM of the three components with those
#     coefficients (one coefficient per pdf, so they act as yields).
# NOTE(review): rounding turns the shares into constants evaluated at the
# current parameter values (signorm depends on CCP via sig_q), and two rounded
# shares need not sum to exactly 1; confirm this is intended.
# `value` is not used in this loop.
for tag, value in {'btag': +1, 'bartag': -1}.items():
    for i in range(7):
            reso_ws.factory(
                f"expr::intFrac_sig_{tag}_{i}("
                f"'signorm_{tag}_{i}/(signorm_btag_{i} + signorm_bartag_{i})', "
                f"signorm_btag_{i}, signorm_bartag_{i})"
            )
            print(reso_ws.obj(f"intFrac_sig_{tag}_{i}"))

            reso_ws.factory(
                f"expr::intFrac_bbbar_{tag}_{i}("
                f"'bbbarnorm_{tag}_{i}/(bbbarnorm_btag_{i} + bbbarnorm_bartag_{i})', "
                f"bbbarnorm_btag_{i}, bbbarnorm_bartag_{i})"
            )
            print(reso_ws.obj(f"intFrac_bbbar_{tag}_{i}"))

            reso_ws.factory(
                f"expr::intFrac_cont_{tag}_{i}("
                f"'contnorm_{tag}_{i}/(contnorm_btag_{i} + contnorm_bartag_{i})', "
                f"contnorm_btag_{i}, contnorm_bartag_{i})"
            )
            print(reso_ws.obj(f"intFrac_cont_{tag}_{i}"))

            # Evaluate and round each to 2 decimals
            sig_val   = round(reso_ws.obj(f"intFrac_sig_{tag}_{i}").getVal(),   2)
            bbbar_val = round(reso_ws.obj(f"intFrac_bbbar_{tag}_{i}").getVal(), 2)
            cont_val  = round(reso_ws.obj(f"intFrac_cont_{tag}_{i}").getVal(),  2)

            # Store as RooRealVars (correct syntax!)
            reso_ws.factory(f"intFrac_sig_{tag}_{i}_rounded[{sig_val}]")
            reso_ws.factory(f"intFrac_bbbar_{tag}_{i}_rounded[{bbbar_val}]")
            reso_ws.factory(f"intFrac_cont_{tag}_{i}_rounded[{cont_val}]")

            print(f"Stored rounded values for {tag}_{i}: sig={sig_val}, bbbar={bbbar_val}, cont={cont_val}")
            print(reso_ws.obj(f"intFrac_sig_{tag}_{i}_rounded"))

            # Expected yields of this category in the time-dependent (TD) sample.
            reso_ws.factory(f"expr::extterm_sig_td_{tag}_{i}('n_sig*w_sig{i}*fractd_sig_bin{i}*intFrac_sig_{tag}_{i}_rounded', n_sig, w_sig{i},intFrac_sig_{tag}_{i}_rounded,fractd_sig_bin{i})")
            reso_ws.factory(f"expr::extterm_bbbar_td_{tag}_{i}('n_bbbar*w_bbbar{i}*fractd_bbbar_bin{i}*intFrac_bbbar_{tag}_{i}_rounded', n_bbbar, w_bbbar{i},intFrac_bbbar_{tag}_{i}_rounded, fractd_bbbar_bin{i})")
            reso_ws.factory(f"expr::extterm_cont_td_{tag}_{i}('n_cont*w_cont{i}*fractd_cont_bin{i}*intFrac_cont_{tag}_{i}_rounded', n_cont, w_cont{i}, intFrac_cont_{tag}_{i}_rounded, fractd_cont_bin{i})")

            # Expected yields of this category in the time-integrated (TI) sample.
            reso_ws.factory(f"expr::extterm_sig_ti_{tag}_{i}('n_sig*w_sig{i}*(1-fractd_sig_bin{i})*intFrac_sig_{tag}_{i}_rounded', n_sig, w_sig{i},intFrac_sig_{tag}_{i}_rounded,fractd_sig_bin{i})")
            reso_ws.factory(f"expr::extterm_bbbar_ti_{tag}_{i}('n_bbbar*w_bbbar{i}*(1-fractd_bbbar_bin{i})*intFrac_bbbar_{tag}_{i}_rounded', n_bbbar, w_bbbar{i},intFrac_bbbar_{tag}_{i}_rounded, fractd_bbbar_bin{i})")
            reso_ws.factory(f"expr::extterm_cont_ti_{tag}_{i}('n_cont*w_cont{i}*(1-fractd_cont_bin{i})*intFrac_cont_{tag}_{i}_rounded', n_cont, w_cont{i}, intFrac_cont_{tag}_{i}_rounded, fractd_cont_bin{i})")

            # Extended sum pdfs for the TD and TI categories.
            reso_ws.factory(f"SUM::pdf_full_td_{tag}_{i}(extterm_bbbar_td_{tag}_{i}*bbbar_pdf_td_{tag}_{i}, extterm_sig_td_{tag}_{i}*sig_pdf_td_{tag}_{i}, extterm_cont_td_{tag}_{i}*cont_pdf_td_{tag}_{i})")
            reso_ws.factory(f"SUM::pdf_full_ti_{tag}_{i}(extterm_bbbar_ti_{tag}_{i}*bbbar_pdf_ti_{tag}_{i}, extterm_sig_ti_{tag}_{i}*sig_pdf_ti_{tag}_{i}, extterm_cont_ti_{tag}_{i}*cont_pdf_ti_{tag}_{i})")



    # Earlier per-tag (not per-bin) variants, kept for reference:
    # reso_ws.factory(f"expr::extterm_sig_td_{tag}('n_sig*fractd_sig*(signorm_{tag}/(signorm_btag + signorm_bartag))', n_sig, fractd_sig, signorm_btag, signorm_bartag)")
    # reso_ws.factory(f"expr::extterm_bbbar_td_{tag}('n_bbbar*fractd_bbbar*(bbbarnorm_{tag}/(bbbarnorm_btag + bbbarnorm_bartag))', n_bbbar, fractd_bbbar, bbbarnorm_btag, bbbarnorm_bartag)")
    # reso_ws.factory(f"expr::extterm_cont_td_{tag}('n_cont*fractd_cont*(contnorm_{tag}/(contnorm_btag + contnorm_bartag))', n_cont, fractd_cont, contnorm_btag, contnorm_bartag)")

    # reso_ws.factory(f"expr::extterm_sig_ti_{tag}('n_sig*(1-fractd_sig)*(signorm_{tag}/(signorm_btag + signorm_bartag))', n_sig, fractd_sig, signorm_btag, signorm_bartag)")
    # reso_ws.factory(f"expr::extterm_bbbar_ti_{tag}('n_bbbar*(1-fractd_bbbar)*(bbbarnorm_{tag}/(bbbarnorm_btag + bbbarnorm_bartag))', n_bbbar, fractd_bbbar, bbbarnorm_btag, bbbarnorm_bartag)")
    # reso_ws.factory(f"expr::extterm_cont_ti_{tag}('n_cont*(1-fractd_cont)*(contnorm_{tag}/(contnorm_btag + contnorm_bartag))', n_cont, fractd_cont, contnorm_btag, contnorm_bartag)")


    # reso_ws.factory(f"ExtendPdf::pdf_full_td_{tag}(cont_pdf_td_{tag}, extterm_cont_td_{tag})")
    # reso_ws.factory(f"ExtendPdf::pdf_full_ti_{tag}(cont_pdf_ti_{tag}, extterm_cont_ti_{tag})")

    # reso_ws.factory(f"ExtendPdf::pdf_full_td_{tag}(sig_pdf_td_{tag}, extterm_sig_td_{tag})")
    # reso_ws.factory(f"ExtendPdf::pdf_full_ti_{tag}(sig_pdf_ti_{tag}, extterm_sig_ti_{tag})")

    # reso_ws.factory(f"ExtendPdf::pdf_full_td_{tag}(bbbar_pdf_td_{tag}, extterm_bbbar_td_{tag})")
    # reso_ws.factory(f"ExtendPdf::pdf_full_ti_{tag}(bbbar_pdf_ti_{tag}, extterm_bbbar_ti_{tag})")


    # reso_ws.factory(f"SUM::pdf_full_td_{tag}(extterm_bbbar_td_{tag}*bbbar_pdf_td_{tag}, extterm_cont_td_{tag}*cont_pdf_td_{tag})")
    # reso_ws.factory(f"SUM::pdf_full_ti_{tag}(extterm_bbbar_ti_{tag}*bbbar_pdf_ti_{tag}, extterm_cont_ti_{tag}*cont_pdf_ti_{tag})")

    # reso_ws.factory(f"SUM::pdf_full_td_{tag}(extterm_sig_td_{tag}*sig_pdf_td_{tag}, extterm_cont_td_{tag}*cont_pdf_td_{tag})")
    # # reso_ws.factory(f"SUM::pdf_full_ti_{tag}(extterm_sig_ti_{tag}*sig_pdf_ti_{tag}, extterm_cont_ti_{tag}*cont_pdf_ti_{tag})")

    # reso_ws.factory(f"SUM::pdf_full_td_{tag}(extterm_bbbar_td_{tag}*bbbar_pdf_td_{tag}, extterm_sig_td_{tag}*sig_pdf_td_{tag})")
    # reso_ws.factory(f"SUM::pdf_full_ti_{tag}(extterm_bbbar_ti_{tag}*bbbar_pdf_ti_{tag}, extterm_sig_ti_{tag}*sig_pdf_ti_{tag})")


            # Still inside the per-bin loop: show the two full pdfs just built.
            print(reso_ws.obj(f"pdf_full_td_{tag}_{i}"))
            print(reso_ws.obj(f"pdf_full_ti_{tag}_{i}"))



# bbbar_cnstr = R.RooGaussian("bbbar_cnstr","bbbar_cnstr",reso_ws.var('n_bbbar'),R.RooFit.RooConst(bbbar_num),R.RooFit.RooConst(30)) 
# getattr(reso_ws, "import")(bbbar_cnstr, R.RooFit.RecycleConflictNodes(), R.RooFit.Silence())

# print(reso_ws.obj("bbbar_cnstr"))

# reso_ws.factory(f"PROD::pdf_full_td_btag_cnstr(pdf_full_td_btag,  bbbar_cnstr)")  # BBbar Gaussian constraint
# reso_ws.factory(f"PROD::pdf_full_td_bartag_cnstr(pdf_full_td_bartag,  bbbar_cnstr)")  # BBbar Gaussian constraint
# reso_ws.factory(f"PROD::pdf_full_ti_btag_cnstr(pdf_full_ti_btag,a  bbbar_cnstr)")  # BBbar Gaussian constraint
# reso_ws.factory(f"PROD::pdf_full_ti_bartag_cnstr(pdf_full_ti_bartag,  bbbar_cnstr)")  # BBbar Gaussian constraint

# for tag, value in {'btag': +1, 'bartag': -1}.items():
#                     print(reso_ws.obj(f"pdf_full_td_{tag}_cnstr"))

def _split_by_rbin(frame, prefix):
    """Return {"<prefix><label>": copy of the rows of *frame* with r_bin == label}."""
    return {
        f"{prefix}{label}": frame[frame["r_bin"] == label].copy()
        for label in bin_labels
    }


# Label every event with its r bin (pd.cut: right-closed intervals, the
# first one also including r == bin_edges[0]).
data_td["r_bin"] = pd.cut(data_td["r"], bins=bin_edges, labels=bin_labels, include_lowest=True)
data_ti["r_bin"] = pd.cut(data_ti["r"], bins=bin_edges, labels=bin_labels, include_lowest=True)

data_td_bins = _split_by_rbin(data_td, "data_td")
data_ti_bins = _split_by_rbin(data_ti, "data_ti")

# Individual per-bin frames, for convenience.
data_td0, data_td1, data_td2, data_td3, data_td4, data_td5, data_td6 = (
    data_td_bins[f"data_td{b}"] for b in range(7)
)
data_ti0, data_ti1, data_ti2, data_ti3, data_ti4, data_ti5, data_ti6 = (
    data_ti_bins[f"data_ti{b}"] for b in range(7)
)

for label, df in [*data_td_bins.items(), *data_ti_bins.items()]:
    print(f"{label}: {len(df)} entries")


# Observables of the time-dependent (TD) and time-integrated (TI) datasets.
# tagflav is not an observable: the samples are split by tag flavour instead
# (formerly: reso_ws.var("tagflav") in obs_td).
obs_td = {
    reso_ws.var(name)
    for name in ("deltat", "deltaterr", "r", "mod_mbc", "deltae", "csobdtmu")
}
obs_ti = {reso_ws.var(name) for name in ("r", "mod_mbc", "deltae", "csobdtmu")}


def _tagged_dataset(frame, selection, observables):
    """Return a RooDataSet over *observables* built from the rows of *frame* matching *selection*."""
    return R.RooDataSet.from_pandas(frame.query(selection), observables)


# TD datasets per r bin, B0-tagged (tagflav == +1) and B0bar-tagged (tagflav == -1).
data_set_td_btag_0 = _tagged_dataset(data_td0, 'tagflav == 1', obs_td)
data_set_td_btag_1 = _tagged_dataset(data_td1, 'tagflav == 1', obs_td)
data_set_td_btag_2 = _tagged_dataset(data_td2, 'tagflav == 1', obs_td)
data_set_td_btag_3 = _tagged_dataset(data_td3, 'tagflav == 1', obs_td)
data_set_td_btag_4 = _tagged_dataset(data_td4, 'tagflav == 1', obs_td)
data_set_td_btag_5 = _tagged_dataset(data_td5, 'tagflav == 1', obs_td)
data_set_td_btag_6 = _tagged_dataset(data_td6, 'tagflav == 1', obs_td)

data_set_td_bartag_0 = _tagged_dataset(data_td0, 'tagflav == -1', obs_td)
data_set_td_bartag_1 = _tagged_dataset(data_td1, 'tagflav == -1', obs_td)
data_set_td_bartag_2 = _tagged_dataset(data_td2, 'tagflav == -1', obs_td)
data_set_td_bartag_3 = _tagged_dataset(data_td3, 'tagflav == -1', obs_td)
data_set_td_bartag_4 = _tagged_dataset(data_td4, 'tagflav == -1', obs_td)
data_set_td_bartag_5 = _tagged_dataset(data_td5, 'tagflav == -1', obs_td)
data_set_td_bartag_6 = _tagged_dataset(data_td6, 'tagflav == -1', obs_td)

# TI datasets per r bin, same split.
data_set_ti_btag_0 = _tagged_dataset(data_ti0, "tagflav==1", obs_ti)
data_set_ti_btag_1 = _tagged_dataset(data_ti1, "tagflav==1", obs_ti)
data_set_ti_btag_2 = _tagged_dataset(data_ti2, "tagflav==1", obs_ti)
data_set_ti_btag_3 = _tagged_dataset(data_ti3, "tagflav==1", obs_ti)
data_set_ti_btag_4 = _tagged_dataset(data_ti4, "tagflav==1", obs_ti)
data_set_ti_btag_5 = _tagged_dataset(data_ti5, "tagflav==1", obs_ti)
data_set_ti_btag_6 = _tagged_dataset(data_ti6, "tagflav==1", obs_ti)

data_set_ti_bartag_0 = _tagged_dataset(data_ti0, "tagflav==-1", obs_ti)
data_set_ti_bartag_1 = _tagged_dataset(data_ti1, "tagflav==-1", obs_ti)
data_set_ti_bartag_2 = _tagged_dataset(data_ti2, "tagflav==-1", obs_ti)
data_set_ti_bartag_3 = _tagged_dataset(data_ti3, "tagflav==-1", obs_ti)
data_set_ti_bartag_4 = _tagged_dataset(data_ti4, "tagflav==-1", obs_ti)
data_set_ti_bartag_5 = _tagged_dataset(data_ti5, "tagflav==-1", obs_ti)
data_set_ti_bartag_6 = _tagged_dataset(data_ti6, "tagflav==-1", obs_ti)



# Fix the flavour-tagging calibration (dw<i>, mu<i>, w<i> for the 7 r bins)
# and Bbbar_tau in the fit.
extra_params = [f"{kind}{i}" for kind in ("dw", "mu", "w") for i in range(7)] + ["Bbbar_tau"]

for parname in extra_params:
    par = reso_ws.var(parname)
    if not par:
        print(f"Warning: Parameter {parname} not found in workspace")
        continue
    par.setConstant(True)

# Whole sample (TD + TI) with fit column names, r = |qr| and
# tagflav = qr / r (NaN where qr == 0), as one RooDataSet over obs_td.
data = (
    df_tot[list(renamed_cols)]
    .rename(columns=renamed_cols)
    .assign(r=lambda frame: np.abs(frame["qr"]))
    .assign(tagflav=lambda frame: frame["qr"] / frame["r"])
)

dataset = R.RooDataSet.from_pandas(data, obs_td)

# Number of B0 (+1) and B0bar (-1) tags, and their difference, per r bin and
# TD/TI category.
per_mode_bins = {"td": data_td_bins, "ti": data_ti_bins}
for i in range(7):
    for mode in ["td", "ti"]:
        df = per_mode_bins[mode][f"data_{mode}{i}"]
        pos = (df["tagflav"] == 1).sum()
        neg = (df["tagflav"] == -1).sum()
        diff = pos - neg
        print(f"data_{mode}_{i}:  +1 = {pos},  -1 = {neg},  diff = {diff}")
    print("-" * 50)

for label, flavour in (("tagflav == +1:", 1), ("tagflav == -1:", -1)):
    print(label, (data["tagflav"] == flavour).sum())

def _full_pdf(mode, tag, rbin):
    """Return the full fit pdf of one (TD/TI, tag, r-bin) category.

    Prefers the constrained variant ``pdf_full_<mode>_<tag>_<rbin>_cnstr`` and
    falls back to ``pdf_full_<mode>_<tag>_<rbin>``.  At the time of writing, no
    active code in this script creates the ``_cnstr`` pdfs (that code is
    commented out), so without the fallback ``reso_ws.pdf`` returns None.

    Raises:
        RuntimeError: if neither pdf exists in ``reso_ws``.
    """
    base = f"pdf_full_{mode}_{tag}_{rbin}"
    pdf = reso_ws.pdf(f"{base}_cnstr")
    if not pdf:
        pdf = reso_ws.pdf(base)
    if not pdf:
        raise RuntimeError(f"Neither {base}_cnstr nor {base} found in reso_ws")
    return pdf


def _td_nll(tag, rbin, data):
    """Extended NLL of a TD category, conditional on deltaterr."""
    return _full_pdf("td", tag, rbin).createNLL(
        data, ConditionalObservables=set([reso_ws.var("deltaterr")]), Extended=True
    )


def _ti_nll(tag, rbin, data):
    """Extended NLL of a TI category."""
    return _full_pdf("ti", tag, rbin).createNLL(data, Extended=True)


nll_td_btag_0 = _td_nll("btag", 0, data_set_td_btag_0)
nll_td_btag_1 = _td_nll("btag", 1, data_set_td_btag_1)
nll_td_btag_2 = _td_nll("btag", 2, data_set_td_btag_2)
nll_td_btag_3 = _td_nll("btag", 3, data_set_td_btag_3)
nll_td_btag_4 = _td_nll("btag", 4, data_set_td_btag_4)
nll_td_btag_5 = _td_nll("btag", 5, data_set_td_btag_5)
nll_td_btag_6 = _td_nll("btag", 6, data_set_td_btag_6)

nll_td_bartag_0 = _td_nll("bartag", 0, data_set_td_bartag_0)
nll_td_bartag_1 = _td_nll("bartag", 1, data_set_td_bartag_1)
nll_td_bartag_2 = _td_nll("bartag", 2, data_set_td_bartag_2)
nll_td_bartag_3 = _td_nll("bartag", 3, data_set_td_bartag_3)
nll_td_bartag_4 = _td_nll("bartag", 4, data_set_td_bartag_4)
nll_td_bartag_5 = _td_nll("bartag", 5, data_set_td_bartag_5)
nll_td_bartag_6 = _td_nll("bartag", 6, data_set_td_bartag_6)

#########################################################################

nll_ti_btag_0 = _ti_nll("btag", 0, data_set_ti_btag_0)
nll_ti_btag_1 = _ti_nll("btag", 1, data_set_ti_btag_1)
nll_ti_btag_2 = _ti_nll("btag", 2, data_set_ti_btag_2)
nll_ti_btag_3 = _ti_nll("btag", 3, data_set_ti_btag_3)
nll_ti_btag_4 = _ti_nll("btag", 4, data_set_ti_btag_4)
nll_ti_btag_5 = _ti_nll("btag", 5, data_set_ti_btag_5)
nll_ti_btag_6 = _ti_nll("btag", 6, data_set_ti_btag_6)

nll_ti_bartag_0 = _ti_nll("bartag", 0, data_set_ti_bartag_0)
nll_ti_bartag_1 = _ti_nll("bartag", 1, data_set_ti_bartag_1)
nll_ti_bartag_2 = _ti_nll("bartag", 2, data_set_ti_bartag_2)
nll_ti_bartag_3 = _ti_nll("bartag", 3, data_set_ti_bartag_3)
nll_ti_bartag_4 = _ti_nll("bartag", 4, data_set_ti_bartag_4)
nll_ti_bartag_5 = _ti_nll("bartag", 5, data_set_ti_bartag_5)
nll_ti_bartag_6 = _ti_nll("bartag", 6, data_set_ti_bartag_6)

# Sum the per-category NLLs: per (TD/TI, tag) first, then everything.
nll_td_btag_sum = R.RooAddition("nll_td_btag_sum", "nll_td_btag_sum", 
                                R.RooArgList({nll_td_btag_0, nll_td_btag_1, nll_td_btag_2, nll_td_btag_3, nll_td_btag_4, nll_td_btag_5, nll_td_btag_6})
                               )
nll_td_bartag_sum = R.RooAddition("nll_td_bartag_sum", "nll_td_bartag_sum", 
                                R.RooArgList({nll_td_bartag_0, nll_td_bartag_1, nll_td_bartag_2, nll_td_bartag_3, nll_td_bartag_4, nll_td_bartag_5, nll_td_bartag_6})
                               )
nll_ti_btag_sum = R.RooAddition("nll_ti_btag_sum", "nll_ti_btag_sum", 
                                R.RooArgList({nll_ti_btag_0, nll_ti_btag_1, nll_ti_btag_2, nll_ti_btag_3, nll_ti_btag_4, nll_ti_btag_5, nll_ti_btag_6})
                               )
nll_ti_bartag_sum = R.RooAddition("nll_ti_bartag_sum", "nll_ti_bartag_sum", 
                                R.RooArgList({nll_ti_bartag_0, nll_ti_bartag_1, nll_ti_bartag_2, nll_ti_bartag_3, nll_ti_bartag_4, nll_ti_bartag_5, nll_ti_bartag_6})
                               )
nll_sum = R.RooAddition("nll_sum","nll_sum",R.RooArgList(nll_td_btag_sum, nll_td_bartag_sum, nll_ti_btag_sum, nll_ti_bartag_sum))

print(nll_sum)

# Minimise the summed NLL with Minuit2: MIGRAD, then HESSE for the errors.
# The HESSE result is kept as m_result_hesse.
# NOTE(review): `min` shadows the Python builtin for the rest of the script;
# the name is kept because later code may refer to it.
min = R.RooMinimizer(nll_sum)
min.optimizeConst(True)
min.setEps(1)  # Minuit tolerance (convergence criterion); confirm the intended value
min.setOffsetting(True)  # offset the likelihood for better numerical precision
min.setMinimizerType("Minuit2")
min.migrad()
min.hesse()
m_result_hesse=min.save("hesse_result")


def get_val_err(varname, result=None):
    """Return ``(value, error)`` of a floating parameter in a fit result.

    Args:
        varname: Name of the parameter to look up in ``floatParsFinal()``.
        result: RooFitResult-like object to read from. Defaults to the
            module-level ``m_result_hesse`` produced by the fit above.

    Returns:
        Tuple ``(var.getVal(), var.getError())``.

    Raises:
        KeyError: If *varname* is not among the result's floating
            parameters (e.g. it was constant in the fit or is misspelled).
    """
    if result is None:
        result = m_result_hesse
    var = result.floatParsFinal().find(varname)
    # find() yields a null/None object for an unknown name; fail with a clear
    # message instead of an AttributeError on getVal().
    if not var:
        raise KeyError(
            f"parameter {varname!r} not found among the floating parameters of the fit result"
        )
    return var.getVal(), var.getError()

# Fitted (value, HESSE error) pairs for the yields and the CP parameters,
# unpacked into one module-level name per quantity.
(
    (n_sig, n_sig_err),
    (n_cont, n_cont_err),
    (n_bbbar, n_bbbar_err),
    (CCP, CCP_err),
    (SCP, SCP_err),
) = [get_val_err(par_name) for par_name in ("n_sig", "n_cont", "n_bbbar", "CCP", "SCP")]


# --- Table: expected vs fitted values and deviation ---
table_data = list()

# Helper function
def row(category, expected, fit_val, fit_err):
    """Build one comparison-table row: expectation vs. fitted value.

    Args:
        category: Label shown in the first column.
        expected: Expected value; shown rounded to 2 decimals.
        fit_val: Fitted central value.
        fit_err: Uncertainty on the fitted value.

    Returns:
        ``[category, round(expected, 2), "<fit_val> ± <fit_err>", "<dev> σ"]``
        where ``dev = (fit_val - expected) / fit_err``. For a zero error the
        deviation is reported as ``+inf``/``-inf`` following the sign of the
        difference, or as the (zero) difference itself when the fitted value
        equals the expectation exactly.
    """
    diff = fit_val - expected
    if fit_err != 0:
        deviation = diff / fit_err
    elif diff > 0:
        deviation = float("inf")
    elif diff < 0:
        deviation = float("-inf")
    else:
        # diff == 0 (or nan, which then propagates instead of becoming inf)
        deviation = diff
    return [
        category,
        round(expected, 2),
        f"{fit_val:.2f} ± {fit_err:.2f}",
        f"{deviation:.2f} σ",
    ]

# One row per quantity: (label, "TM expected" value, fitted value, fit error).
# CCP and SCP are compared against 0.
table_data.extend(
    row(label, expected, value, error)
    for label, expected, value, error in (
        ("Signal + SCF", sig_tot_num, n_sig, n_sig_err),
        ("Continuum", cont_num, n_cont, n_cont_err),
        ("bbbar", bbbar_num, n_bbbar, n_bbbar_err),
        ("CCP", 0, CCP, CCP_err),
        ("SCP", 0, SCP, SCP_err),
    )
)

# --- Print Table ---
print(
    tabulate(
        table_data,
        headers=["Category", "TM expected", "Fit Result ( With CSO_BDT_mu )", "Deviation (σ)"],
        tablefmt="grid",
        colalign=("left",) + ("center",) * 3,
    )
)

# --- deltat plots of the b-tag / b-bar-tag PDFs, one figure per r bin ---
plot_dir = "/home/belle2/vikasraj/b2kpi/KSPI0_TDCPV/belle/new_fits_3/results/plots/combined/sc"
# exist_ok=True replaces the racy "exists? then makedirs" check.
os.makedirs(plot_dir, exist_ok=True)


dataset_btag_qr = [
    data_set_td_btag_0, data_set_td_btag_1, data_set_td_btag_2,
    data_set_td_btag_3, data_set_td_btag_4, data_set_td_btag_5,
    data_set_td_btag_6
]
dataset_bartag_qr = [
    data_set_td_bartag_0, data_set_td_bartag_1, data_set_td_bartag_2,
    data_set_td_bartag_3, data_set_td_bartag_4, data_set_td_bartag_5,
    data_set_td_bartag_6
]

# Conditional observables handed to fu.plot_sc_1 for every r bin.
sc_cond_var_names = ("deltaterr", "mod_mbc", "deltae", "csobdtmu", "r")

for rbin, (data_btag, data_bartag) in enumerate(zip(dataset_btag_qr, dataset_bartag_qr)):
    pdf_btag = reso_ws.pdf(f"pdf_full_td_btag_{rbin}_cnstr")
    pdf_bartag = reso_ws.pdf(f"pdf_full_td_bartag_{rbin}_cnstr")

    fu.plot_sc_1(
        pdfs=[pdf_btag, pdf_bartag],
        datasets=[data_btag, data_bartag],
        var=reso_ws.var("deltat"),
        cond_vars=[reso_ws.var(name) for name in sc_cond_var_names],
        plotpath=f"{plot_dir}/sc_combined_r{rbin}.png",
    )


# Per-r-bin RooDataSets over `obs_ti`, each built from the concatenation of
# the matching data_td<i> and data_ti<i> pandas frames (td rows first).
dataset_0, dataset_1, dataset_2, dataset_3, dataset_4, dataset_5, dataset_6 = (
    R.RooDataSet.from_pandas(pd.concat([frame_td, frame_ti]), obs_ti)
    for frame_td, frame_ti in zip(
        (data_td0, data_td1, data_td2, data_td3, data_td4, data_td5, data_td6),
        (data_ti0, data_ti1, data_ti2, data_ti3, data_ti4, data_ti5, data_ti6),
    )
)


# --- 1D projections of pdfc_<rbin> onto mod_mbc, deltae and csobdtmu ---
# Every observable gets its own output directory, named after it. (A single
# `plot_dir` reassigned three times would send all plots into the last one.)
proj_plot_base = "/home/belle2/vikasraj/b2kpi/KSPI0_TDCPV/belle/new_fits_3/results/plots/combined"
proj_var_names = ("mod_mbc", "deltae", "csobdtmu")
proj_plot_dirs = {name: f"{proj_plot_base}/{name}" for name in proj_var_names}
for proj_dir in proj_plot_dirs.values():
    os.makedirs(proj_dir, exist_ok=True)


dataset_qr = [
    dataset_0, dataset_1, dataset_2,
    dataset_3, dataset_4, dataset_5,
    dataset_6
]

# NOTE(review): the PDFs and observables below come from `ws`, while the fit
# above minimised NLLs built from `reso_ws` (into which pdfc_* were imported
# at the top of the file). If that import made independent copies, these
# plots show `ws`'s parameter values rather than the combined-fit result --
# confirm whether `reso_ws` was intended here.
for rbin, dataset in enumerate(dataset_qr):
    pdf = ws.pdf(f"pdfc_{rbin}")
    if not pdf:
        print(f"pdfc_{rbin} not found in ws.")
        continue

    # Same call order as before: mod_mbc, deltae, csobdtmu.
    for var_name in proj_var_names:
        fu.plot_pdf_new_1(
            pdf=pdf,
            dataset=dataset,
            var=ws.var(var_name),
            components=["sig_pdf_tot", "cont_pdf", "bbbar_pdf"],
            plotpath=f"{proj_plot_dirs[var_name]}/{var_name}_4d_kde_rbin{rbin}.png",
            legends=("Data", "Fit", "Signal", "q#bar{q}", "B#bar{B}"),
            # optional: isData=False, isBelle=True, isLumi=True
        )


print("All Done!")