From ece7c032f3d292a9bcdaac9d49a6c8bf9f47c13f Mon Sep 17 00:00:00 2001 From: Simon Cello <87865105+Phmonski@users.noreply.github.com> Date: Fri, 14 Nov 2025 11:50:21 +0100 Subject: [PATCH] [RF][HS3] Patched ParamHistFuncs | Minor clean up * Fix Constant Flag in Data axes * ParamHistFuncs accept custom modifiers * Store binning Information with ParamHistFunc * Add Warning to ParamHistFunc Compatibility * clang-format clean-up * Update warning message for histogram binnings version * Update warning message for default binning (cherry picked from commit e42fe7c9fabd99851d455a139a518aadce73cf32) --- .../hs3/inc/RooFitHS3/RooJSONFactoryWSTool.h | 4 +- roofit/hs3/src/JSONFactories_HistFactory.cxx | 12 +- roofit/hs3/src/JSONFactories_RooFitCore.cxx | 112 ++++++++++++++++++ roofit/hs3/src/RooFitHS3_wsexportkeys.cxx | 14 +-- .../src/RooFitHS3_wsfactoryexpressions.cxx | 14 +-- roofit/hs3/src/RooJSONFactoryWSTool.cxx | 30 ++--- .../inc/RooFit/Detail/JSONInterface.h | 6 + 7 files changed, 161 insertions(+), 31 deletions(-) diff --git a/roofit/hs3/inc/RooFitHS3/RooJSONFactoryWSTool.h b/roofit/hs3/inc/RooFitHS3/RooJSONFactoryWSTool.h index c656d60f73358..d4e10d3ba1dc8 100644 --- a/roofit/hs3/inc/RooFitHS3/RooJSONFactoryWSTool.h +++ b/roofit/hs3/inc/RooFitHS3/RooJSONFactoryWSTool.h @@ -229,8 +229,8 @@ class RooJSONFactoryWSTool { void importVariable(const RooFit::Detail::JSONNode &p); void importDependants(const RooFit::Detail::JSONNode &n); - void exportVariable(const RooAbsArg *v, RooFit::Detail::JSONNode &p); - void exportVariables(const RooArgSet &allElems, RooFit::Detail::JSONNode &n); + void exportVariable(const RooAbsArg *v, RooFit::Detail::JSONNode &n, bool storeConstant, bool storeBins); + void exportVariables(const RooArgSet &allElems, RooFit::Detail::JSONNode &n, bool storeConstant, bool storeBins); void exportAllObjects(RooFit::Detail::JSONNode &n); diff --git a/roofit/hs3/src/JSONFactories_HistFactory.cxx b/roofit/hs3/src/JSONFactories_HistFactory.cxx index 2634b752f3535..a294050033b16 100644 --- a/roofit/hs3/src/JSONFactories_HistFactory.cxx +++ b/roofit/hs3/src/JSONFactories_HistFactory.cxx @@ -823,6 +823,16 @@ void collectElements(RooArgSet &elems, RooAbsArg *arg) } } +bool allRooRealVar(const RooAbsCollection &list) +{ + for (auto *var : list) { + if (!dynamic_cast(var)) { + return false; + } + } + return true; +} + struct Sample { std::string name; std::vector hist; @@ -920,7 +930,7 @@ Channel readChannel(RooJSONFactoryWSTool *tool, const std::string &pdfname, cons addNormFactor(par, sample, ws); } else if (auto hf = dynamic_cast(e)) { updateObservables(hf->dataHist()); - } else if (auto phf = dynamic_cast(e)) { + } else if (ParamHistFunc *phf = dynamic_cast(e); phf && allRooRealVar(phf->paramList())) { phfs.push_back(phf); } else if (auto fip = dynamic_cast(e)) { // some (modified) histfactory models have several instances of FlexibleInterpVar diff --git a/roofit/hs3/src/JSONFactories_RooFitCore.cxx b/roofit/hs3/src/JSONFactories_RooFitCore.cxx index 7cc8754df560b..a149f647ee34b 100644 --- a/roofit/hs3/src/JSONFactories_RooFitCore.cxx +++ b/roofit/hs3/src/JSONFactories_RooFitCore.cxx @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -33,6 +34,7 @@ #include #include #include +#include #include #include #include @@ -532,6 +534,71 @@ class RooMultiVarGaussianFactory : public RooFit::JSONIO::Importer { } }; +class ParamHistFuncFactory : public RooFit::JSONIO::Importer { +public: + bool importArg(RooJSONFactoryWSTool *tool, const JSONNode &p) const override + { + std::string name(RooJSONFactoryWSTool::name(p)); + RooArgList varList = tool->requestArgList(p, "variables"); + if (!p.has_child("axes")) { + std::stringstream ss; + ss << "No axes given in '" << name << "'" + << ". Using default binning (uniform; nbins=100). If needed, export the Workspace to JSON with a newer " + << "Root version that supports custom ParamHistFunc binnings(>=6.38.00)." << std::endl; + RooJSONFactoryWSTool::warning(ss.str()); + tool->wsEmplace(name, varList, tool->requestArgList(p, "parameters")); + return true; + } + tool->wsEmplace(name, readBinning(p, varList), tool->requestArgList(p, "parameters")); + return true; + } + +private: + RooArgList readBinning(const JSONNode &topNode, const RooArgList &varList) const + { + // Temporary map from variable name → RooRealVar + std::map> varMap; + + // Build variables from JSON + for (const JSONNode &node : topNode["axes"].children()) { + const std::string name = node["name"].val(); + std::unique_ptr obs; + + if (node.has_child("edges")) { + std::vector edges; + for (const auto &bound : node["edges"].children()) { + edges.push_back(bound.val_double()); + } + obs = std::make_unique(name.c_str(), name.c_str(), edges.front(), edges.back()); + RooBinning bins(obs->getMin(), obs->getMax()); + for (auto b : edges) + bins.addBoundary(b); + obs->setBinning(bins); + } else { + obs = std::make_unique(name.c_str(), name.c_str(), node["min"].val_double(), + node["max"].val_double()); + obs->setBins(node["nbins"].val_int()); + } + + varMap[name] = std::move(obs); + } + + // Now build the final list following the order in varList + RooArgList vars; + for (int i = 0; i < varList.getSize(); ++i) { + const auto *refVar = dynamic_cast(varList.at(i)); + if (!refVar) + continue; + + auto it = varMap.find(refVar->GetName()); + if (it != varMap.end()) { + vars.addOwned(std::move(it->second)); // preserve ownership + } + } + return vars; + } +}; + /////////////////////////////////////////////////////////////////////////////////////////////////////// // specialized exporter implementations /////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -696,6 +763,7 @@ class RooFormulaArgStreamer : public RooFit::JSONIO::Exporter { expr.ReplaceAll("TMath::Sin", "sin"); expr.ReplaceAll("TMath::Sqrt", "sqrt"); expr.ReplaceAll("TMath::Power", "pow"); + expr.ReplaceAll("TMath::Erf", "erf"); } }; template @@ -952,6 +1020,47 @@ class RooExtendPdfStreamer : public RooFit::JSONIO::Exporter { } }; +class ParamHistFuncStreamer : public RooFit::JSONIO::Exporter { +public: + std::string const &key() const override; + bool exportObject(RooJSONFactoryWSTool *, const RooAbsArg *func, JSONNode &elem) const override + { + auto *pdf = static_cast(func); + elem["type"] << key(); + RooJSONFactoryWSTool::fillSeq(elem["variables"], pdf->dataVars()); + RooJSONFactoryWSTool::fillSeq(elem["parameters"], pdf->paramList()); + writeBinningInfo(pdf, elem); + return true; + } + +private: + void writeBinningInfo(const ParamHistFunc *pdf, JSONNode &elem) const + { + auto &observablesNode = elem["axes"].set_seq(); + // axes have to be ordered to get consistent bin indices + for (auto *var : static_range_cast(pdf->dataVars())) { + std::string name = var->GetName(); + RooJSONFactoryWSTool::testValidName(name, false); + JSONNode &obsNode = observablesNode.append_child().set_map(); + obsNode["name"] << name; + if (var->getBinning().isUniform()) { + obsNode["min"] << var->getMin(); + obsNode["max"] << var->getMax(); + obsNode["nbins"] << var->getBins(); + } else { + auto &edges = obsNode["edges"]; + edges.set_seq(); + double val = var->getBinning().binLow(0); + edges.append_child() << val; + for (int i = 0; i < var->getBinning().numBins(); ++i) { + val = var->getBinning().binHigh(i); + edges.append_child() << val; + } + } + } + } +}; + #define DEFINE_EXPORTER_KEY(class_name, name) \ std::string const &class_name::key() const \ { \ @@ -989,6 +1098,7 @@ DEFINE_EXPORTER_KEY(RooRealIntegralStreamer, "integral"); DEFINE_EXPORTER_KEY(RooDerivativeStreamer, "derivative"); DEFINE_EXPORTER_KEY(RooFFTConvPdfStreamer, "fft_conv_pdf"); DEFINE_EXPORTER_KEY(RooExtendPdfStreamer, "extend_pdf"); +DEFINE_EXPORTER_KEY(ParamHistFuncStreamer, "step"); /////////////////////////////////////////////////////////////////////////////////////////////////////// // instantiate all importers and exporters @@ -1021,6 +1131,7 @@ STATIC_EXECUTE([]() { registerImporter("derivative", false); registerImporter("fft_conv_pdf", false); registerImporter("extend_pdf", false); + registerImporter("step", false); registerExporter>(RooAddPdf::Class(), false); registerExporter>(RooAddModel::Class(), false); @@ -1047,6 +1158,7 @@ STATIC_EXECUTE([]() { registerExporter(RooDerivative::Class(), false); registerExporter(RooFFTConvPdf::Class(), false); registerExporter(RooExtendPdf::Class(), false); + registerExporter(ParamHistFunc::Class(), false); }); } // namespace diff --git a/roofit/hs3/src/RooFitHS3_wsexportkeys.cxx b/roofit/hs3/src/RooFitHS3_wsexportkeys.cxx index dacbf6dbf6456..20423610811ac 100644 --- a/roofit/hs3/src/RooFitHS3_wsexportkeys.cxx +++ b/roofit/hs3/src/RooFitHS3_wsexportkeys.cxx @@ -62,6 +62,13 @@ auto RooFitHS3_wsexportkeys = R"({ "sigmaR": "sigma_R" } }, + "RooEffProd": { + "type": "efficiency_product_pdf_dist", + "proxies": { + "pdf": "pdf", + "eff": "eff" + } + }, "RooGamma": { "type": "gamma_dist", "proxies": { @@ -79,13 +86,6 @@ auto RooFitHS3_wsexportkeys = R"({ "sigma": "sigma" } }, - "ParamHistFunc": { - "type": "step", - "proxies": { - "dataVars": "variables", - "paramSet": "parameters" - } - }, "RooLandau": { "type": "landau_dist", "proxies": { diff --git a/roofit/hs3/src/RooFitHS3_wsfactoryexpressions.cxx b/roofit/hs3/src/RooFitHS3_wsfactoryexpressions.cxx index 62c27a0038fca..6076465bcb8ea 100644 --- a/roofit/hs3/src/RooFitHS3_wsfactoryexpressions.cxx +++ b/roofit/hs3/src/RooFitHS3_wsfactoryexpressions.cxx @@ -43,6 +43,13 @@ auto RooFitHS3_wsfactoryexpressions = R"({ "coefficients" ] }, + "efficiency_product_pdf_dist": { + "class": "RooEffProd", + "arguments": [ + "pdf", + "eff" + ] + }, "gamma_dist": { "class": "RooGamma", "arguments": [ @@ -112,13 +119,6 @@ auto RooFitHS3_wsfactoryexpressions = R"({ "observables" ] }, - "step": { - "class": "ParamHistFunc", - "arguments": [ - "variables", - "parameters" - ] - }, "sum": { "class": "RooAddition", "arguments": [ diff --git a/roofit/hs3/src/RooJSONFactoryWSTool.cxx b/roofit/hs3/src/RooJSONFactoryWSTool.cxx index b41f537d7f01b..35a86f109dbbd 100644 --- a/roofit/hs3/src/RooJSONFactoryWSTool.cxx +++ b/roofit/hs3/src/RooJSONFactoryWSTool.cxx @@ -965,7 +965,7 @@ RooAbsReal *RooJSONFactoryWSTool::requestImpl(const std::string &obj * @param node The JSONNode to which the variable will be exported. * @return void */ -void RooJSONFactoryWSTool::exportVariable(const RooAbsArg *v, JSONNode &node) +void RooJSONFactoryWSTool::exportVariable(const RooAbsArg *v, JSONNode &node, bool storeConstant, bool storeBins) { auto *cv = dynamic_cast(v); auto *rrv = dynamic_cast(v); @@ -984,10 +984,10 @@ void RooJSONFactoryWSTool::exportVariable(const RooAbsArg *v, JSONNode &node) var["const"] << true; } else if (rrv) { var["value"] << rrv->getVal(); - if (rrv->isConstant()) { + if (rrv->isConstant() && storeConstant) { var["const"] << rrv->isConstant(); } - if (rrv->getBins() != 100) { + if (rrv->getBins() != 100 && storeBins) { var["nbins"] << rrv->getBins(); } _domains->readVariable(*rrv); @@ -1004,12 +1004,12 @@ void RooJSONFactoryWSTool::exportVariable(const RooAbsArg *v, JSONNode &node) * @param n The JSONNode to which the variables will be exported. * @return void */ -void RooJSONFactoryWSTool::exportVariables(const RooArgSet &allElems, JSONNode &n) +void RooJSONFactoryWSTool::exportVariables(const RooArgSet &allElems, JSONNode &n, bool storeConstant, bool storeBins) { // export a list of RooRealVar objects n.set_seq(); for (RooAbsArg *arg : allElems) { - exportVariable(arg, n); + exportVariable(arg, n, storeConstant, storeBins); } } @@ -1070,7 +1070,7 @@ void RooJSONFactoryWSTool::exportObject(RooAbsArg const &func, std::set(&func) || dynamic_cast(&func)) { - exportVariable(&func, *_varsNode); + exportVariable(&func, *_varsNode, true, false); return; } @@ -1554,7 +1554,7 @@ void RooJSONFactoryWSTool::exportData(RooAbsData const &data) // this really is an unbinned dataset output["type"] << "unbinned"; - exportVariables(variables, output["axes"]); + exportVariables(variables, output["axes"], false, true); auto &coords = output["entries"].set_seq(); std::vector weightVals; bool hasNonUnityWeights = false; @@ -1562,10 +1562,6 @@ void RooJSONFactoryWSTool::exportData(RooAbsData const &data) data.get(i); coords.append_child().fill_seq(variables, [](auto x) { return static_cast(x)->getVal(); }); std::string datasetName = data.GetName(); - /*if (datasetName.find("combData_ZvvH126.5") != std::string::npos) { - file << dynamic_cast(data.get(i)->find("atlas_invMass_PttEtaConvVBFCat1"))->getVal() << - std::endl; - }*/ if (data.isWeighted()) { weightVals.push_back(data.weight()); if (data.weight() != 1.) @@ -1575,7 +1571,6 @@ void RooJSONFactoryWSTool::exportData(RooAbsData const &data) if (data.isWeighted() && hasNonUnityWeights) { output["weights"].fill_seq(weightVals); } - // file.close(); } /** @@ -1960,7 +1955,8 @@ void RooJSONFactoryWSTool::exportAllObjects(JSONNode &n) snapshotSorted.sort(); std::string name(snsh->GetName()); if (name != "default_values") { - this->exportVariables(snapshotSorted, appendNamedChild(n["parameter_points"], name)["parameters"]); + this->exportVariables(snapshotSorted, appendNamedChild(n["parameter_points"], name)["parameters"], true, + false); } } _varsNode = nullptr; @@ -2240,8 +2236,14 @@ void RooJSONFactoryWSTool::importAllNodes(const JSONNode &n) combineDatasets(*_rootnodeInput, datasets); for (auto const &d : datasets) { - if (d) + if (d) { _workspace.import(*d); + for (auto const &obs : *d->get()) { + if (auto *rrv = dynamic_cast(obs)) { + _workspace.var(rrv->GetName())->setBinning(rrv->getBinning()); + } + } + } } _rootnodeInput = nullptr; diff --git a/roofit/jsoninterface/inc/RooFit/Detail/JSONInterface.h b/roofit/jsoninterface/inc/RooFit/Detail/JSONInterface.h index 27b46b9feb51e..245673fcf577f 100644 --- a/roofit/jsoninterface/inc/RooFit/Detail/JSONInterface.h +++ b/roofit/jsoninterface/inc/RooFit/Detail/JSONInterface.h @@ -265,6 +265,12 @@ inline RooFit::Detail::JSONNode &operator<<(RooFit::Detail::JSONNode &n, std::sp return n; } +inline RooFit::Detail::JSONNode &operator<<(RooFit::Detail::JSONNode &n, std::span v) +{ + n.fill_seq(v); + return n; +} + template RooFit::Detail::JSONNode & operator<<(RooFit::Detail::JSONNode &n, const std::unordered_map &m)