Skip to content

Commit 173652d

Browse files
committed
make_prediction
1 parent 5aac5ea commit 173652d

File tree

3 files changed

+80
-66
lines changed

3 files changed

+80
-66
lines changed

include/albatross/src/core/model.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ constexpr bool DEFAULT_USE_ASYNC = false;
2121

2222
template <typename ModelType> class ModelBase : public ParameterHandlingMixin {
2323

24-
friend class JointPredictor;
25-
friend class MarginalPredictor;
26-
friend class MeanPredictor;
24+
friend class detail::JointPredictor;
25+
friend class detail::MarginalPredictor;
26+
friend class detail::MeanPredictor;
2727

2828
template <typename T, typename FeatureType> friend class fit_model_type;
2929

include/albatross/src/core/prediction.hpp

Lines changed: 65 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ namespace albatross {
1919
// which behave different conditional on the type of predictions desired.
2020
template <typename T> struct PredictTypeIdentity { typedef T type; };
2121

22+
namespace detail {
2223
/*
2324
* MeanPredictor is responsible for determining if a valid form of
2425
* predicting exists for a given set of model, feature, fit. The
@@ -33,8 +34,8 @@ class MeanPredictor {
3334
typename std::enable_if<
3435
has_valid_predict_mean<ModelType, FeatureType, FitType>::value,
3536
int>::type = 0>
36-
Eigen::VectorXd _mean(const ModelType &model, const FitType &fit,
37-
const std::vector<FeatureType> &features) const {
37+
static Eigen::VectorXd _mean(const ModelType &model, const FitType &fit,
38+
const std::vector<FeatureType> &features) {
3839
return model.predict_(features, fit,
3940
PredictTypeIdentity<Eigen::VectorXd>());
4041
}
@@ -46,8 +47,8 @@ class MeanPredictor {
4647
has_valid_predict_marginal<ModelType, FeatureType,
4748
FitType>::value,
4849
int>::type = 0>
49-
Eigen::VectorXd _mean(const ModelType &model, const FitType &fit,
50-
const std::vector<FeatureType> &features) const {
50+
static Eigen::VectorXd _mean(const ModelType &model, const FitType &fit,
51+
const std::vector<FeatureType> &features) {
5152
return model
5253
.predict_(features, fit, PredictTypeIdentity<MarginalDistribution>())
5354
.mean;
@@ -61,8 +62,8 @@ class MeanPredictor {
6162
FitType>::value &&
6263
has_valid_predict_joint<ModelType, FeatureType, FitType>::value,
6364
int>::type = 0>
64-
Eigen::VectorXd _mean(const ModelType &model, const FitType &fit,
65-
const std::vector<FeatureType> &features) const {
65+
static Eigen::VectorXd _mean(const ModelType &model, const FitType &fit,
66+
const std::vector<FeatureType> &features) {
6667
return model
6768
.predict_(features, fit, PredictTypeIdentity<JointDistribution>())
6869
.mean;
@@ -75,9 +76,9 @@ class MarginalPredictor {
7576
typename std::enable_if<has_valid_predict_marginal<
7677
ModelType, FeatureType, FitType>::value,
7778
int>::type = 0>
78-
MarginalDistribution
79+
static MarginalDistribution
7980
_marginal(const ModelType &model, const FitType &fit,
80-
const std::vector<FeatureType> &features) const {
81+
const std::vector<FeatureType> &features) {
8182
return model.predict_(features, fit,
8283
PredictTypeIdentity<MarginalDistribution>());
8384
}
@@ -88,9 +89,9 @@ class MarginalPredictor {
8889
!has_valid_predict_marginal<ModelType, FeatureType, FitType>::value &&
8990
has_valid_predict_joint<ModelType, FeatureType, FitType>::value,
9091
int>::type = 0>
91-
MarginalDistribution
92+
static MarginalDistribution
9293
_marginal(const ModelType &model, const FitType &fit,
93-
const std::vector<FeatureType> &features) const {
94+
const std::vector<FeatureType> &features) {
9495
const auto joint_pred =
9596
model.predict_(features, fit, PredictTypeIdentity<JointDistribution>());
9697
return joint_pred.marginal();
@@ -103,13 +104,56 @@ class JointPredictor {
103104
typename std::enable_if<
104105
has_valid_predict_joint<ModelType, FeatureType, FitType>::value,
105106
int>::type = 0>
106-
JointDistribution _joint(const ModelType &model, const FitType &fit,
107-
const std::vector<FeatureType> &features) const {
107+
static JointDistribution _joint(const ModelType &model, const FitType &fit,
108+
const std::vector<FeatureType> &features) {
108109
return model.predict_(features, fit,
109110
PredictTypeIdentity<JointDistribution>());
110111
}
111112
};
112113

114+
template <
115+
typename ModelType, typename FeatureType, typename FitType,
116+
typename std::enable_if<can_predict_joint<JointPredictor, ModelType,
117+
FeatureType, FitType>::value,
118+
int>::type = 0>
119+
auto make_prediction(const ModelType &model, const FitType &fit,
120+
const std::vector<FeatureType> &features,
121+
PredictTypeIdentity<JointDistribution> &&) {
122+
return JointPredictor::_joint(model, fit, features);
123+
}
124+
125+
template <
126+
typename ModelType, typename FeatureType, typename FitType,
127+
typename std::enable_if<can_predict_marginal<MarginalPredictor, ModelType,
128+
FeatureType, FitType>::value,
129+
int>::type = 0>
130+
auto make_prediction(const ModelType &model, const FitType &fit,
131+
const std::vector<FeatureType> &features,
132+
PredictTypeIdentity<MarginalDistribution> &&) {
133+
return MarginalPredictor::_marginal(model, fit, features);
134+
}
135+
136+
template <typename ModelType, typename FeatureType, typename FitType,
137+
typename std::enable_if<can_predict_mean<MeanPredictor, ModelType,
138+
FeatureType, FitType>::value,
139+
int>::type = 0>
140+
auto make_prediction(const ModelType &model, const FitType &fit,
141+
const std::vector<FeatureType> &features,
142+
PredictTypeIdentity<Eigen::VectorXd> &&) {
143+
return MeanPredictor::_mean(model, fit, features);
144+
}
145+
} // namespace detail
146+
147+
template <typename PredictType, typename ModelType, typename FeatureType,
148+
typename FitType>
149+
auto make_prediction(
150+
const ModelType &model, const FitType &fit,
151+
const std::vector<FeatureType> &features,
152+
PredictTypeIdentity<PredictType> = PredictTypeIdentity<PredictType>()) {
153+
return detail::make_prediction(model, fit, features,
154+
PredictTypeIdentity<PredictType>());
155+
}
156+
113157
template <typename ModelType, typename FeatureType, typename FitType>
114158
class Prediction {
115159

@@ -126,62 +170,30 @@ class Prediction {
126170
: model_(std::move(model)), fit_(std::move(fit)), features_(features) {}
127171

128172
// Mean
129-
template <typename DummyType = FeatureType,
130-
typename std::enable_if<can_predict_mean<MeanPredictor, ModelType,
131-
DummyType, FitType>::value,
132-
int>::type = 0>
133-
Eigen::VectorXd mean() const {
173+
template <typename DummyType = FeatureType> Eigen::VectorXd mean() const {
134174
static_assert(std::is_same<DummyType, FeatureType>::value,
135175
"never do prediction.mean<T>()");
136-
return MeanPredictor()._mean(model_, fit_, features_);
176+
return make_prediction(model_, fit_, features_,
177+
PredictTypeIdentity<Eigen::VectorXd>());
137178
}
138179

139-
template <
140-
typename DummyType = FeatureType,
141-
typename std::enable_if<!can_predict_mean<MeanPredictor, ModelType,
142-
DummyType, FitType>::value,
143-
int>::type = 0>
144-
Eigen::VectorXd mean() const = delete; // No valid predict method found.
145-
146180
// Marginal
147-
template <
148-
typename DummyType = FeatureType,
149-
typename std::enable_if<can_predict_marginal<MarginalPredictor, ModelType,
150-
DummyType, FitType>::value,
151-
int>::type = 0>
181+
template <typename DummyType = FeatureType>
152182
MarginalDistribution marginal() const {
153183
static_assert(std::is_same<DummyType, FeatureType>::value,
154184
"never do prediction.mean<T>()");
155-
return MarginalPredictor()._marginal(model_, fit_, features_);
185+
return make_prediction(model_, fit_, features_,
186+
PredictTypeIdentity<MarginalDistribution>());
156187
}
157188

158-
template <typename DummyType = FeatureType,
159-
typename std::enable_if<
160-
!can_predict_marginal<MarginalPredictor, ModelType, DummyType,
161-
FitType>::value,
162-
int>::type = 0>
163-
MarginalDistribution
164-
marginal() const = delete; // No valid predict method found.
165-
166189
// Joint
167-
template <
168-
typename DummyType = FeatureType,
169-
typename std::enable_if<can_predict_joint<JointPredictor, ModelType,
170-
DummyType, FitType>::value,
171-
int>::type = 0>
172-
JointDistribution joint() const {
190+
template <typename DummyType = FeatureType> JointDistribution joint() const {
173191
static_assert(std::is_same<DummyType, FeatureType>::value,
174192
"never do prediction.mean<T>()");
175-
return JointPredictor()._joint(model_, fit_, features_);
193+
return make_prediction(model_, fit_, features_,
194+
PredictTypeIdentity<JointDistribution>());
176195
}
177196

178-
template <
179-
typename DummyType = FeatureType,
180-
typename std::enable_if<!can_predict_joint<JointPredictor, ModelType,
181-
DummyType, FitType>::value,
182-
int>::type = 0>
183-
JointDistribution joint() const = delete; // No valid predict method found.
184-
185197
template <typename PredictType>
186198
PredictType get(PredictTypeIdentity<PredictType> =
187199
PredictTypeIdentity<PredictType>()) const {

include/albatross/src/core/traits.hpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -148,28 +148,30 @@ DEFINE_CLASS_METHOD_TRAITS(_mean);
148148
template <typename T, typename ModelType, typename FeatureType,
149149
typename FitType>
150150
struct can_predict_mean
151-
: public has__mean<T, typename const_ref<ModelType>::type,
152-
typename const_ref<FitType>::type,
153-
typename const_ref<std::vector<FeatureType>>::type> {};
151+
: public has__mean_with_return_type<
152+
T, Eigen::VectorXd, typename const_ref<ModelType>::type,
153+
typename const_ref<FitType>::type,
154+
typename const_ref<std::vector<FeatureType>>::type> {};
154155

155156
DEFINE_CLASS_METHOD_TRAITS(_marginal);
156157

157158
template <typename T, typename ModelType, typename FeatureType,
158159
typename FitType>
159160
struct can_predict_marginal
160-
: public has__marginal<T, typename const_ref<ModelType>::type,
161-
typename const_ref<FitType>::type,
162-
typename const_ref<std::vector<FeatureType>>::type> {
163-
};
161+
: public has__marginal_with_return_type<
162+
T, MarginalDistribution, typename const_ref<ModelType>::type,
163+
typename const_ref<FitType>::type,
164+
typename const_ref<std::vector<FeatureType>>::type> {};
164165

165166
DEFINE_CLASS_METHOD_TRAITS(_joint);
166167

167168
template <typename T, typename ModelType, typename FeatureType,
168169
typename FitType>
169170
struct can_predict_joint
170-
: public has__joint<T, typename const_ref<ModelType>::type,
171-
typename const_ref<FitType>::type,
172-
typename const_ref<std::vector<FeatureType>>::type> {};
171+
: public has__joint_with_return_type<
172+
T, JointDistribution, typename const_ref<ModelType>::type,
173+
typename const_ref<FitType>::type,
174+
typename const_ref<std::vector<FeatureType>>::type> {};
173175

174176
/*
175177
* Methods for inspecting `Prediction` types.

0 commit comments

Comments
 (0)