#!/usr/bin/env python
# coding: utf-8

# # Fit templates 2d


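# Usage: python3 Constrain_Fit_histPDF_2d.py <sig_events> <bkg_events> <output_prefix> <seed> <data_file>
# Example: python3 Constrain_Fit_histPDF_2d.py 0 8 output 1 fileoutput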

import ROOT
import numpy as np

ROOT.gROOT.SetBatch(True)
import sys

#sys.path.append("/home/marcela/root_functions/")
#import root_functions as rf

# Command-line arguments:
#   argv[1] sig_events  -- true signal yield used to generate the toy dataset
#   argv[2] bkg_events  -- true background yield; also the centre of the
#                          Gaussian constraint on n_bkg
#   argv[3] output_file -- prefix for the workspace and plot file names
#   argv[4] set_seed    -- seed for ROOT's global random generator
#   argv[5] data_file   -- NOTE(review): stored but not referenced in the code shown
if len(sys.argv) < 6:
    # Exit with a usage message instead of an IndexError traceback.
    sys.exit(
        f"usage: {sys.argv[0]} <sig_events> <bkg_events> "
        "<output_prefix> <seed> <data_file>"
    )
try:
    sig_events = float(sys.argv[1])
    bkg_events = float(sys.argv[2])
    set_seed = int(sys.argv[4])
except ValueError as err:
    sys.exit(f"invalid numeric argument: {err}")
output_file = sys.argv[3]
data_file = sys.argv[5]
# NOTE(review): path_histos is not referenced in the code shown.
path_histos = '/home/marcela/Documents/optNewLines2022/muon_results/HistosData/'

# Presumably a histogram bin count; `bines` is not referenced in the code shown.
bines = 20

# Observable ranges: rot_scale_x in [a, b], rot_scale_y in [c, d].
a = -3
b = 3
c = -3
d = 3
sigma = 2  # "2 sigmas" per the original note; not referenced in the code shown

import glob
import root_pandas

# Load the 2D signal template histogram 'h2sig'. The file handle is kept in a
# module global so the histogram read from it stays valid for the whole script.
file_input='/home/marcela/Documents/optNewLines2022/muon_eph3/template_sig.root'
templatesfile = ROOT.TFile(file_input, 'READ')
if templatesfile.IsZombie():
    # ROOT only prints an error for an unreadable file; stop here instead of
    # failing later with an obscure error inside RooDataHist.
    sys.exit(f"cannot open signal template file: {file_input}")
h2sig = templatesfile.Get('h2sig')
if not h2sig:
    # Get() yields a null object when the key is missing.
    sys.exit(f"histogram 'h2sig' not found in {file_input}")

import matplotlib.pyplot as plt

# Fit observables: rot_scale_x in [a, b] and rot_scale_y in [c, d]
# (name doubles as title; empty unit string).
x, y = (
    ROOT.RooRealVar(var_name, var_name, low, high, "")
    for var_name, low, high in (("rot_scale_x", a, b), ("rot_scale_y", c, d))
)

from pyroofit.data import df2roo
from pyroofit.data import roo2hist

#########################
##########################
# Binned signal template over (x, y) built from h2sig, and the histogram pdf
# built on top of it. The trailing 0 is the RooHistPdf interpolation order.
datahist_sig = ROOT.RooDataHist(
    "sigD", "sigD",
    ROOT.RooArgList(x, y),
    h2sig,
)
sigpdf = ROOT.RooHistPdf(
    "sigpdf", "sigpdf",
    ROOT.RooArgSet(x, y),
    datahist_sig,
    0,
)

# Allowed ranges for the fitted yields. For the special true yields 0 and 0.1
# a fixed window is used; otherwise the window scales with the true yield.
if sig_events in (0, 0.1):
    rs1, rs2 = -25, 25
else:
    rs1, rs2 = -10 * sig_events, 10 * sig_events
if bkg_events in (0, 0.1):
    rb1, rb2 = -10, 25
else:
    rb1, rb2 = -5 * bkg_events, 10 * bkg_events

print("------------------------------")
print("range", rs1, rs2, rb1, rb2)
print("------------------------------")

# Width of the Gaussian constraint on n_bkg: sqrt(bkg_events) rounded to two
# decimals. bkg_events == 0 is explicitly supported above, but it would give a
# zero-width Gaussian that cannot be evaluated, so fall back to a unit width
# (this also catches NaN from a negative bkg_events).
sqbkg = round(np.sqrt(bkg_events), 2)
if not sqbkg > 0:
    sqbkg = 1.0

##########Model################
myWS = ROOT.RooWorkspace("myWS")
# Importing sigpdf also brings in its observables rot_scale_x / rot_scale_y,
# which the factory expressions below refer to by name. ('import' is a Python
# keyword, hence getattr.)
getattr(myWS, 'import')(sigpdf)

# Yields, free within the ranges computed above.
myWS.factory(f"n_sig[{rs1},{rs2}]")
myWS.factory(f"n_bkg[{rb1},{rb2}]")
# Flat background in both observables.
myWS.factory("Uniform::gx(rot_scale_x)")
myWS.factory("Uniform::gy(rot_scale_y)")
myWS.factory("PROD::bkgpdf(gx,gy)")
# Signal + background with the yields as coefficients (documented
# CLASS::name(...) syntax, consistent with the lines above).
myWS.factory("SUM::model_sum(n_sig*sigpdf, n_bkg*bkgpdf)")
# Gaussian constraint on n_bkg centred on the true background yield.
myWS.factory(f"Gaussian::constrain_bkg(n_bkg,{bkg_events},{sqbkg})")
myWS.factory("PROD::model(model_sum,constrain_bkg)")
if not myWS.pdf("model"):
    # The factory only prints an error on a malformed expression; fail early.
    raise RuntimeError("RooFactory failed to build pdf 'model'")

####### mc ##########333
# RooStats ModelConfig for this model: n_sig is the parameter of interest,
# n_bkg the nuisance parameter, (rot_scale_x, rot_scale_y) the observables.
set_n_sig=ROOT.RooArgSet(myWS.var("n_sig"))
set_n_bkg=ROOT.RooArgSet(myWS.var("n_bkg"))

mc = ROOT.RooStats.ModelConfig("ModelConfig",myWS)
mc.SetPdf(myWS.pdf('model'))
mc.SetParametersOfInterest(set_n_sig)
mc.SetObservables(ROOT.RooArgSet(myWS.var("rot_scale_x"),myWS.var("rot_scale_y")))
mc.SetNuisanceParameters(set_n_bkg)
# NOTE(review): the snapshot records n_sig's value at this point. The true
# signal yield is only set later (just before data generation), so this stores
# the value n_sig had when the factory created it. Confirm this is intended.
mc.SetSnapshot(ROOT.RooArgSet(myWS.var("n_sig")))
mc.Print()
# Store the ModelConfig in the workspace so it is written out together with it.
getattr(myWS,'import')(mc)
print("----------------------ok")

########################Generate data####################
# ---- Toy data generation ---------------------------------------------------
# Set the yields to the requested true values, seed ROOT's global RNG with
# set_seed, then generate one dataset of the two observables from 'model' and
# store it in the workspace as "obsData".
print(sig_events)
print(bkg_events)
_n_sig = myWS.var("n_sig")
_n_bkg = myWS.var("n_bkg")
_n_sig.setVal(sig_events)
_n_bkg.setVal(bkg_events)
ROOT.RooRandom.randomGenerator().SetSeed(set_seed)
print("----------------------aftersetval")
print(_n_sig.getVal())
print(_n_bkg.getVal())
print("----------------------ok")
# NOTE(review): no NumEvents/Extended argument is given, so the event count
# comes from RooFit's default for this pdf (presumably the expected yield,
# without Poisson fluctuation). Confirm this is the intended toy setup.
_observables = ROOT.RooArgSet(myWS.var("rot_scale_x"), myWS.var("rot_scale_y"))
data = myWS.pdf("model").generate(_observables)
data.SetName("obsData")
getattr(myWS, 'import')(data)
myWS.Print()
print("----------------------After set val, generate")
for _yield_var in (_n_sig, _n_bkg):
    print(_yield_var.getVal())
for _yield_var in (_n_sig, _n_bkg):
    print(_yield_var.getRange())
print("----------------------ok")
print("------------------------------------------")

##########save the workspace in a ROOT file
# Rename the workspace to "w" and write it, including everything imported so
# far, to <output_file>_<seed>_workspace.root (recreate=True overwrites).
myWS.SetName("w")
output_file_root = f"{output_file}_{set_seed}_workspace.root"
myWS.writeToFile(output_file_root, True)

myWS.Print()

pdf = myWS.pdf('model')
# Extended maximum-likelihood fit with the Gaussian constraint on n_bkg and
# MINOS errors (default minimizer / strategy otherwise).
fitResult = pdf.fitTo(
    data,
    ROOT.RooFit.Extended(True),
    ROOT.RooFit.Constrain(ROOT.RooArgSet(myWS.var("n_bkg"))),
    ROOT.RooFit.Minos(True),
    ROOT.RooFit.SumW2Error(False),
    ROOT.RooFit.Save(True),
    ROOT.RooFit.NumCPU(12),
)

fitResult.Print('v')
if fitResult.status() != 0:
    # Do not abort: the values and plots are still written, but flag them.
    print(f"WARNING: fit status {fitResult.status()} "
          f"(covQual={fitResult.covQual()}); results may be unreliable")

# NLL for the likelihood scans. It is configured like the fit (extended term
# plus n_bkg constraint) so the scans show the likelihood that was actually
# minimized, whatever this ROOT version's createNLL defaults are.
nll = pdf.createNLL(
    data,
    ROOT.RooFit.Extended(True),
    ROOT.RooFit.Constrain(ROOT.RooArgSet(myWS.var("n_bkg"))),
)

print('###########Final vals##################################')
# Post-fit values and (parabolic) errors of the yields.
v1 = myWS.var("n_bkg").getVal()
v2 = myWS.var("n_bkg").getError()
v11 = myWS.var("n_sig").getVal()
v22 = myWS.var("n_sig").getError()
print(v1, v2)
print(v11, v22)
print('#############################################')

def _clip_scan_window(var, lo, hi):
    """Return (lo, hi) clipped to the limits of *var*.

    Falls back to the full variable range when the window is empty or not
    finite (e.g. the fit returned a zero or NaN error), or when it lies
    entirely outside the variable's limits.
    """
    var_lo, var_hi = var.getMin(), var.getMax()
    if not (np.isfinite(lo) and np.isfinite(hi) and lo < hi):
        return var_lo, var_hi
    lo, hi = max(lo, var_lo), min(hi, var_hi)
    if lo >= hi:
        return var_lo, var_hi
    return lo, hi


def _draw_nll_panel(var, lo, hi):
    """Draw the NLL (shifted to 0) and the profile NLL of *var* on the current pad.

    The frame covers only [lo, hi] (clipped to the variable limits), so the
    costly profile is evaluated only where it is displayed. Returns
    (frame, profile) so the caller keeps both alive until the canvas is saved.
    """
    lo, hi = _clip_scan_window(var, lo, hi)
    frame = var.frame(ROOT.RooFit.Range(lo, hi), ROOT.RooFit.Title(" "))
    profile = nll.createProfile(var)
    nll.plotOn(frame, ROOT.RooFit.ShiftToZero())
    profile.plotOn(frame, ROOT.RooFit.LineColor(ROOT.kRed), ROOT.RooFit.LineStyle(ROOT.kDashed))
    # Both curves are zero at their minimum, so the axis shows a difference.
    frame.GetYaxis().SetTitle('#Delta(-log L)')
    frame.SetMinimum(0)
    frame.SetMaximum(10)
    frame.Draw()
    return frame, profile


nllplots = True
if nllplots:
    # +-5 sigma windows around the fitted yields (clipped when drawing).
    rb2_5s = v1 + (5*v2)
    rb1_5s = v1 - (5*v2)
    rs2_5s = v11 + (5*v22)
    rs1_5s = v11 - (5*v22)
    canvas_nevv = ROOT.TCanvas("canvas_nevv", "canvas_nevv", 800, 720)
    canvas_nevv.Divide(2, 1)
    canvas_nevv.cd(1)
    nevv_frame, profile_ll_nevv = _draw_nll_panel(myWS.var("n_sig"), rs1_5s, rs2_5s)
    canvas_nevv.cd(2)
    nbkg_frame, profile_ll_nbkg = _draw_nll_panel(myWS.var("n_bkg"), rb1_5s, rb2_5s)
    canvas_nevv.SaveAs(output_file+f"smfit_nbkg_nll_pll_sumw2_{set_seed}.png")

    
