#!/usr/bin/env python3


import multiprocessing
import ROOT
from multiprocessing.dummy import Pool as ThreadPool 
import math

def fit_thermal(x, par):
    """Evaluate the three-parameter thermal function at ``x[0]``.

    The value is ``dndy / (T * (m0 + T)) * exp(-(mt - m0) / T)`` where
    ``mt = sqrt(x[0]**2 + m0**2)``.  The ``(x, par)`` signature is what is
    handed to ``ROOT.TF1`` when the function is registered in ``main``.

    Args:
        x: Indexable coordinate container; only ``x[0]`` is read.
        par: Indexable parameter container read as ``(dndy, T, m0)``.

    Returns:
        The function value as a float.
    """
    yield_norm, slope, rest_mass = par[0], par[1], par[2]
    abscissa = x[0]
    # mt combines the abscissa and m0 in quadrature.
    transverse = math.sqrt(abscissa * abscissa + rest_mass * rest_mass)
    prefactor = yield_norm / (slope * (rest_mass + slope))
    return prefactor * math.exp(-(transverse - rest_mass) / slope)


def apply_fit(hist, fit, *frange):
    """Fit ``fit`` to ``hist`` and hand the fitted function back.

    Args:
        hist: Object exposing ``Fit(function, option, goption, *range)``.
        fit: Function object exposing ``GetNpar``, ``GetParameter`` and
            ``SetParameter``.
        *frange: Extra positional arguments forwarded to ``hist.Fit`` after
            the two empty option strings (``main`` passes the lower and
            upper fit bounds).

    Returns:
        The same ``fit`` object that was passed in, after fitting.
    """
    # Re-assign every parameter to its current value before fitting.
    # This replaces the narrower ``fit.SetParameter(0, 1)`` workaround while
    # waiting for https://github.com/root-project/root/issues/16184
    for index in range(fit.GetNpar()):
        current = fit.GetParameter(index)
        fit.SetParameter(index, current)

    fit_option = ""
    graphics_option = ""
    hist.Fit(fit, fit_option, graphics_option, *frange)
    return fit

def main():
    """Fit several functions to one histogram in parallel and plot them.

    Builds a 64-bin histogram on [-4, 4] filled from ROOT's ``"gaus"``,
    submits one ``apply_fit`` job per function (``gaus``, ``expo``, ``pol3``
    and the Python-defined ``thermal``) to a four-process pool, prints the
    returned functions and draws them over the histogram into ``test.pdf``.
    """
    c = ROOT.TCanvas()
    h = ROOT.TH1F("myHist", "myTitle", 64, -4, 4)
    h.FillRandom("gaus")
    fit_range = [-3, 3]

    thermal = ROOT.TF1("thermal", fit_thermal, 0, 2, 3)
    thermal.SetParNames("dN/dy", "T", "m0")
    thermal.SetParameters(20, 0.1, 0.938)

    f1 = ROOT.TF1("f1", "gaus")
    f2 = ROOT.TF1("f2", "expo")
    f3 = ROOT.TF1("f3", "pol3")
    # f4 is constructed but not submitted as a fit job.
    f4 = ROOT.TF1("f4", "landau")

    # One (histogram, function, range...) argument tuple per fit job.
    jobs = [(h, func, *fit_range) for func in (f1, f2, f3, thermal)]

    # Thread-based alternative kept for reference: pool = ThreadPool(4)
    pool = multiprocessing.Pool(processes=4)
    try:
        result = pool.starmap(apply_fit, jobs)
    finally:
        # Always shut the workers down, even when a fit job raises;
        # otherwise the pool's processes are left without a close()/join().
        pool.close()
        pool.join()
    print(result)

    h.Draw()
    for fitted in result:
        fitted.Draw('same')
    c.Print('test.pdf')

# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
