@@ -19,6 +19,7 @@ namespace albatross {
1919// which behave different conditional on the type of predictions desired.
2020template <typename T> struct PredictTypeIdentity { typedef T type; };
2121
22+ namespace detail {
2223/*
2324 * MeanPredictor is responsible for determining if a valid form of
2425 * predicting exists for a given set of model, feature, fit. The
@@ -33,8 +34,8 @@ class MeanPredictor {
3334 typename std::enable_if<
3435 has_valid_predict_mean<ModelType, FeatureType, FitType>::value,
3536 int >::type = 0 >
36- Eigen::VectorXd _mean (const ModelType &model, const FitType &fit,
37- const std::vector<FeatureType> &features) const {
37+ static Eigen::VectorXd _mean (const ModelType &model, const FitType &fit,
38+ const std::vector<FeatureType> &features) {
3839 return model.predict_ (features, fit,
3940 PredictTypeIdentity<Eigen::VectorXd>());
4041 }
@@ -46,8 +47,8 @@ class MeanPredictor {
4647 has_valid_predict_marginal<ModelType, FeatureType,
4748 FitType>::value,
4849 int >::type = 0 >
49- Eigen::VectorXd _mean (const ModelType &model, const FitType &fit,
50- const std::vector<FeatureType> &features) const {
50+ static Eigen::VectorXd _mean (const ModelType &model, const FitType &fit,
51+ const std::vector<FeatureType> &features) {
5152 return model
5253 .predict_ (features, fit, PredictTypeIdentity<MarginalDistribution>())
5354 .mean ;
@@ -61,8 +62,8 @@ class MeanPredictor {
6162 FitType>::value &&
6263 has_valid_predict_joint<ModelType, FeatureType, FitType>::value,
6364 int >::type = 0 >
64- Eigen::VectorXd _mean (const ModelType &model, const FitType &fit,
65- const std::vector<FeatureType> &features) const {
65+ static Eigen::VectorXd _mean (const ModelType &model, const FitType &fit,
66+ const std::vector<FeatureType> &features) {
6667 return model
6768 .predict_ (features, fit, PredictTypeIdentity<JointDistribution>())
6869 .mean ;
@@ -75,9 +76,9 @@ class MarginalPredictor {
7576 typename std::enable_if<has_valid_predict_marginal<
7677 ModelType, FeatureType, FitType>::value,
7778 int >::type = 0 >
78- MarginalDistribution
79+ static MarginalDistribution
7980 _marginal (const ModelType &model, const FitType &fit,
80- const std::vector<FeatureType> &features) const {
81+ const std::vector<FeatureType> &features) {
8182 return model.predict_ (features, fit,
8283 PredictTypeIdentity<MarginalDistribution>());
8384 }
@@ -88,9 +89,9 @@ class MarginalPredictor {
8889 !has_valid_predict_marginal<ModelType, FeatureType, FitType>::value &&
8990 has_valid_predict_joint<ModelType, FeatureType, FitType>::value,
9091 int >::type = 0 >
91- MarginalDistribution
92+ static MarginalDistribution
9293 _marginal (const ModelType &model, const FitType &fit,
93- const std::vector<FeatureType> &features) const {
94+ const std::vector<FeatureType> &features) {
9495 const auto joint_pred =
9596 model.predict_ (features, fit, PredictTypeIdentity<JointDistribution>());
9697 return joint_pred.marginal ();
@@ -103,13 +104,56 @@ class JointPredictor {
103104 typename std::enable_if<
104105 has_valid_predict_joint<ModelType, FeatureType, FitType>::value,
105106 int >::type = 0 >
106- JointDistribution _joint (const ModelType &model, const FitType &fit,
107- const std::vector<FeatureType> &features) const {
107+ static JointDistribution _joint (const ModelType &model, const FitType &fit,
108+ const std::vector<FeatureType> &features) {
108109 return model.predict_ (features, fit,
109110 PredictTypeIdentity<JointDistribution>());
110111 }
111112};
112113
114+ template <
115+ typename ModelType, typename FeatureType, typename FitType,
116+ typename std::enable_if<can_predict_joint<JointPredictor, ModelType,
117+ FeatureType, FitType>::value,
118+ int >::type = 0 >
119+ auto make_prediction (const ModelType &model, const FitType &fit,
120+ const std::vector<FeatureType> &features,
121+ PredictTypeIdentity<JointDistribution> &&) {
122+ return JointPredictor::_joint (model, fit, features);
123+ }
124+
125+ template <
126+ typename ModelType, typename FeatureType, typename FitType,
127+ typename std::enable_if<can_predict_marginal<MarginalPredictor, ModelType,
128+ FeatureType, FitType>::value,
129+ int >::type = 0 >
130+ auto make_prediction (const ModelType &model, const FitType &fit,
131+ const std::vector<FeatureType> &features,
132+ PredictTypeIdentity<MarginalDistribution> &&) {
133+ return MarginalPredictor::_marginal (model, fit, features);
134+ }
135+
136+ template <typename ModelType, typename FeatureType, typename FitType,
137+ typename std::enable_if<can_predict_mean<MeanPredictor, ModelType,
138+ FeatureType, FitType>::value,
139+ int >::type = 0 >
140+ auto make_prediction (const ModelType &model, const FitType &fit,
141+ const std::vector<FeatureType> &features,
142+ PredictTypeIdentity<Eigen::VectorXd> &&) {
143+ return MeanPredictor::_mean (model, fit, features);
144+ }
145+ } // namespace detail
146+
147+ template <typename PredictType, typename ModelType, typename FeatureType,
148+ typename FitType>
149+ auto make_prediction (
150+ const ModelType &model, const FitType &fit,
151+ const std::vector<FeatureType> &features,
152+ PredictTypeIdentity<PredictType> = PredictTypeIdentity<PredictType>()) {
153+ return detail::make_prediction (model, fit, features,
154+ PredictTypeIdentity<PredictType>());
155+ }
156+
113157template <typename ModelType, typename FeatureType, typename FitType>
114158class Prediction {
115159
@@ -126,62 +170,30 @@ class Prediction {
126170 : model_(std::move(model)), fit_(std::move(fit)), features_(features) {}
127171
128172 // Mean
129- template <typename DummyType = FeatureType,
130- typename std::enable_if<can_predict_mean<MeanPredictor, ModelType,
131- DummyType, FitType>::value,
132- int >::type = 0 >
133- Eigen::VectorXd mean () const {
173+ template <typename DummyType = FeatureType> Eigen::VectorXd mean () const {
134174 static_assert (std::is_same<DummyType, FeatureType>::value,
135175 " never do prediction.mean<T>()" );
136- return MeanPredictor ()._mean (model_, fit_, features_);
176+ return make_prediction (model_, fit_, features_,
177+ PredictTypeIdentity<Eigen::VectorXd>());
137178 }
138179
139- template <
140- typename DummyType = FeatureType,
141- typename std::enable_if<!can_predict_mean<MeanPredictor, ModelType,
142- DummyType, FitType>::value,
143- int >::type = 0 >
144- Eigen::VectorXd mean () const = delete; // No valid predict method found.
145-
146180 // Marginal
147- template <
148- typename DummyType = FeatureType,
149- typename std::enable_if<can_predict_marginal<MarginalPredictor, ModelType,
150- DummyType, FitType>::value,
151- int >::type = 0 >
181+ template <typename DummyType = FeatureType>
152182 MarginalDistribution marginal () const {
153183 static_assert (std::is_same<DummyType, FeatureType>::value,
154184 " never do prediction.mean<T>()" );
155- return MarginalPredictor ()._marginal (model_, fit_, features_);
185+ return make_prediction (model_, fit_, features_,
186+ PredictTypeIdentity<MarginalDistribution>());
156187 }
157188
158- template <typename DummyType = FeatureType,
159- typename std::enable_if<
160- !can_predict_marginal<MarginalPredictor, ModelType, DummyType,
161- FitType>::value,
162- int >::type = 0 >
163- MarginalDistribution
164- marginal () const = delete ; // No valid predict method found.
165-
166189 // Joint
167- template <
168- typename DummyType = FeatureType,
169- typename std::enable_if<can_predict_joint<JointPredictor, ModelType,
170- DummyType, FitType>::value,
171- int >::type = 0 >
172- JointDistribution joint () const {
190+ template <typename DummyType = FeatureType> JointDistribution joint () const {
173191 static_assert (std::is_same<DummyType, FeatureType>::value,
174192 " never do prediction.mean<T>()" );
175- return JointPredictor ()._joint (model_, fit_, features_);
193+ return make_prediction (model_, fit_, features_,
194+ PredictTypeIdentity<JointDistribution>());
176195 }
177196
178- template <
179- typename DummyType = FeatureType,
180- typename std::enable_if<!can_predict_joint<JointPredictor, ModelType,
181- DummyType, FitType>::value,
182- int >::type = 0 >
183- JointDistribution joint () const = delete; // No valid predict method found.
184-
185197 template <typename PredictType>
186198 PredictType get (PredictTypeIdentity<PredictType> =
187199 PredictTypeIdentity<PredictType>()) const {
0 commit comments