Skip to content

Commit 4c545d7

Browse files
authored
[RF][HS3] Clean Expressions in RooGenPdf and RooFormularVar
- Changes to Expressions to exclude TMath:: - Code clean-up - Include RooResolutionModels to nameSanitization - Data axes changed due to sorting variable names when reading .json Files This commit fixes HS3 Issue: scipp-atlas/pyhs3#69 (comment)
1 parent 68bfed5 commit 4c545d7

File tree

2 files changed

+31
-40
lines changed

2 files changed

+31
-40
lines changed

roofit/hs3/src/JSONFactories_RooFitCore.cxx

Lines changed: 15 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -543,12 +543,6 @@ class RooAddPdfStreamer : public RooFit::JSONIO::Exporter {
543543
{
544544
const RooArg_t *pdf = static_cast<const RooArg_t *>(func);
545545
elem["type"] << key();
546-
std::string name = elem["name"].val();
547-
/*elem["name"] << RooJSONFactoryWSTool::sanitizeName(name);
548-
RooJSONFactoryWSTool::fillSeqSanitizedName(elem["summands"], pdf->pdfList());
549-
RooJSONFactoryWSTool::fillSeqSanitizedName(elem["coefficients"], pdf->coefList());
550-
*/
551-
elem["name"] << name;
552546
RooJSONFactoryWSTool::fillSeq(elem["summands"], pdf->pdfList());
553547
RooJSONFactoryWSTool::fillSeq(elem["coefficients"], pdf->coefList());
554548
elem["extended"] << (pdf->extendMode() != RooArg_t::CanNotBeExtended);
@@ -563,12 +557,6 @@ class RooRealSumPdfStreamer : public RooFit::JSONIO::Exporter {
563557
{
564558
const RooRealSumPdf *pdf = static_cast<const RooRealSumPdf *>(func);
565559
elem["type"] << key();
566-
std::string name = elem["name"].val();
567-
/*elem["name"] << RooJSONFactoryWSTool::sanitizeName(name);
568-
RooJSONFactoryWSTool::fillSeqSanitizedName(elem["samples"], pdf->funcList());
569-
RooJSONFactoryWSTool::fillSeqSanitizedName(elem["coefficients"], pdf->coefList());
570-
*/
571-
elem["name"] << name;
572560
RooJSONFactoryWSTool::fillSeq(elem["samples"], pdf->funcList());
573561
RooJSONFactoryWSTool::fillSeq(elem["coefficients"], pdf->coefList());
574562
elem["extended"] << (pdf->extendMode() != RooAbsPdf::CanNotBeExtended);
@@ -583,12 +571,6 @@ class RooRealSumFuncStreamer : public RooFit::JSONIO::Exporter {
583571
{
584572
const RooRealSumFunc *pdf = static_cast<const RooRealSumFunc *>(func);
585573
elem["type"] << key();
586-
std::string name = elem["name"].val();
587-
/*elem["name"] << RooJSONFactoryWSTool::sanitizeName(name);
588-
RooJSONFactoryWSTool::fillSeqSanitizedName(elem["samples"], pdf->funcList());
589-
RooJSONFactoryWSTool::fillSeqSanitizedName(elem["coefficients"], pdf->coefList());
590-
*/
591-
elem["name"] << name;
592574
RooJSONFactoryWSTool::fillSeq(elem["samples"], pdf->funcList());
593575
RooJSONFactoryWSTool::fillSeq(elem["coefficients"], pdf->coefList());
594576
return true;
@@ -687,6 +669,7 @@ class RooFormulaArgStreamer : public RooFit::JSONIO::Exporter {
687669
const RooArg_t *pdf = static_cast<const RooArg_t *>(func);
688670
elem["type"] << key();
689671
TString expression(pdf->expression());
672+
cleanExpression(expression);
690673
// If the tokens follow the "x[#]" convention, the square braces enclosing each number
691674
// ensures that there is a unique mapping between the token and parameter name
692675
// If the tokens follow the "@#" convention, the numbers are not enclosed by braces.
@@ -701,6 +684,19 @@ class RooFormulaArgStreamer : public RooFit::JSONIO::Exporter {
701684
elem["expression"] << expression.Data();
702685
return true;
703686
}
687+
688+
private:
689+
void cleanExpression(TString &expr) const
690+
{
691+
expr.ReplaceAll("TMath::Exp", "exp");
692+
expr.ReplaceAll("TMath::Min", "min");
693+
expr.ReplaceAll("TMath::Max", "max");
694+
expr.ReplaceAll("TMath::Log", "log");
695+
expr.ReplaceAll("TMath::Cos", "cos");
696+
expr.ReplaceAll("TMath::Sin", "sin");
697+
expr.ReplaceAll("TMath::Sqrt", "sqrt");
698+
expr.ReplaceAll("TMath::Power", "pow");
699+
}
704700
};
705701
template <class RooArg_t>
706702
class RooPolynomialStreamer : public RooFit::JSONIO::Exporter {
@@ -784,9 +780,6 @@ class RooTruthModelStreamer : public RooFit::JSONIO::Exporter {
784780
{
785781
auto *pdf = static_cast<const RooTruthModel *>(func);
786782
elem["type"] << key();
787-
std::string name = elem["name"].val();
788-
// elem["name"] << RooJSONFactoryWSTool::sanitizeName(name);
789-
elem["name"] << name;
790783
elem["x"] << pdf->convVar().GetName();
791784

792785
return true;
@@ -800,9 +793,6 @@ class RooGaussModelStreamer : public RooFit::JSONIO::Exporter {
800793
{
801794
auto *pdf = static_cast<const RooGaussModel *>(func);
802795
elem["type"] << key();
803-
std::string name = elem["name"].val();
804-
// elem["name"] << RooJSONFactoryWSTool::sanitizeName(name);
805-
elem["name"] << name;
806796
elem["x"] << pdf->convVar().GetName();
807797
elem["mean"] << pdf->getMean().GetName();
808798
elem["sigma"] << pdf->getSigma().GetName();
@@ -913,10 +903,6 @@ class RooRealIntegralStreamer : public RooFit::JSONIO::Exporter {
913903
bool exportObject(RooJSONFactoryWSTool *, const RooAbsArg *func, JSONNode &elem) const override
914904
{
915905
auto *integral = static_cast<const RooRealIntegral *>(func);
916-
std::string name = elem["name"].val();
917-
// elem["name"] << RooJSONFactoryWSTool::sanitizeName(name);
918-
elem["name"] << name;
919-
920906
elem["type"] << key();
921907
std::string integrand = integral->integrand().GetName();
922908
// elem["integrand"] << RooJSONFactoryWSTool::sanitizeName(integrand);
@@ -1060,7 +1046,7 @@ STATIC_EXECUTE([]() {
10601046
registerExporter<RooRealIntegralStreamer>(RooRealIntegral::Class(), false);
10611047
registerExporter<RooDerivativeStreamer>(RooDerivative::Class(), false);
10621048
registerExporter<RooFFTConvPdfStreamer>(RooFFTConvPdf::Class(), false);
1063-
registerExporter<RooExtendPdfStreamer>(RooExtendPdf::Class(), false);
1049+
registerExporter<RooExtendPdfStreamer>(RooExtendPdf::Class(), false);
10641050
});
10651051

10661052
} // namespace

roofit/hs3/src/RooJSONFactoryWSTool.cxx

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -483,18 +483,13 @@ void exportAttributes(const RooAbsArg *arg, JSONNode &rootnode)
483483
*
484484
* @param ws The RooWorkspace in which the observables will be created.
485485
* @param node The JSONNode containing information about the observables to be created.
486-
* @param out The RooArgSet to which the created observables will be added.
486+
* @param out The RooAbsCollection to which the created observables will be added.
487487
* @return void
488488
*/
489-
void getObservables(RooWorkspace const &ws, const JSONNode &node, RooArgSet &out)
489+
void getObservables(RooWorkspace const &ws, const JSONNode &node, RooAbsCollection &out)
490490
{
491-
std::map<std::string, Var> vars;
492491
for (const auto &p : node["axes"].children()) {
493-
vars.emplace(RooJSONFactoryWSTool::name(p), Var(p));
494-
}
495-
496-
for (auto v : vars) {
497-
std::string name(v.first);
492+
std::string name(RooJSONFactoryWSTool::name(p));
498493
if (ws.var(name)) {
499494
out.add(*ws.var(name));
500495
} else {
@@ -528,9 +523,9 @@ std::unique_ptr<RooAbsData> loadData(const JSONNode &p, RooWorkspace &workspace)
528523
return RooJSONFactoryWSTool::readBinnedData(p, name, RooJSONFactoryWSTool::readAxes(p));
529524
} else if (type == "unbinned") {
530525
// unbinned
531-
RooArgSet vars;
532-
getObservables(workspace, p, vars);
533-
RooArgList varlist(vars);
526+
RooArgList varlist;
527+
getObservables(workspace, p, varlist);
528+
RooArgSet vars(varlist);
534529
auto data = std::make_unique<RooDataSet>(name, name, vars, RooFit::WeightVar());
535530
auto &coords = p["entries"];
536531
if (!coords.is_seq()) {
@@ -2503,6 +2498,10 @@ RooWorkspace RooJSONFactoryWSTool::cleanWS(const RooWorkspace &ws, bool onlyMode
25032498
tmpWS.import(*obj);
25042499
}
25052500

2501+
for (auto *obj : ws.allResolutionModels()) {
2502+
tmpWS.import(*obj);
2503+
}
2504+
25062505
/*
25072506
if (auto* mc = dynamic_cast<RooStats::ModelConfig*>(obj)) {
25082507
// Import the PDF
@@ -2578,6 +2577,12 @@ RooWorkspace RooJSONFactoryWSTool::sanitizeWS(const RooWorkspace &ws)
25782577
}
25792578
}
25802579

2580+
// Resolution Models
2581+
for (auto *obj : tmpWS.allResolutionModels()) {
2582+
if (!isValidName(obj->GetName())) {
2583+
obj->SetName(sanitizeName(obj->GetName()).c_str());
2584+
}
2585+
}
25812586
// Datasets
25822587
for (auto *data : tmpWS.allData()) {
25832588
// Sanitize dataset name

0 commit comments

Comments
 (0)