Skip to content

Commit b844b28

Browse files
committed
feat: add vectorize method
1 parent 365690a commit b844b28

File tree

4 files changed

+34
-3
lines changed

4 files changed

+34
-3
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Changelog
1818
Ver 0.1.*
1919
---------
2020

21+
* |Feature| |API| Add :meth:`vectorize` for faster ensemble model inference using :mod:`functorch` (requiring :mod:`torch` version >= 1.13.0) | `@xuyxu <https://github.com/xuyxu>`__
2122
* |Feature| |API| Add ``voting_strategy`` parameter for :class:`VotingClassifer`, :class:`NeuralForestClassifier`, and :class:`SnapshotEnsembleClassifier` | `@LukasGardberg <https://github.com/LukasGardberg>`__
2223
* |Fix| Fix the sampling issue in :class:`BaggingClassifier` and :class:`BaggingRegressor` | `@SunHaozhe <https://github.com/SunHaozhe>`__
2324
* |Feature| |API| Add :class:`NeuralForestClassifier` and :class:`NeuralForestRegressor` | `@xuyxu <https://github.com/xuyxu>`__

torchensemble/_base.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,10 @@ def get_doc(item):
2929
__doc = {
3030
"model": const.__model_doc,
3131
"seq_model": const.__seq_model_doc,
32-
"tree_ensmeble_model": const.__tree_ensemble_doc,
32+
"tree_ensemble_model": const.__tree_ensemble_doc,
3333
"fit": const.__fit_doc,
3434
"predict": const.__predict_doc,
35+
"vectorize": const.__vectorize_doc,
3536
"set_optimizer": const.__set_optimizer_doc,
3637
"set_scheduler": const.__set_scheduler_doc,
3738
"set_criterion": const.__set_criterion_doc,
@@ -198,6 +199,21 @@ def predict(self, *x):
198199
pred = self.forward(*x_device)
199200
pred = pred.cpu()
200201
return pred
202+
203+
def vectorize(self):
204+
"""Docstrings decorated by downstream ensembles."""
205+
try:
206+
from functorch import combine_state_for_ensemble
207+
except Exception:
208+
msg = (
209+
"Failed to import functorch utils, please make sure the"
210+
" Pytorch version >= 1.13.0."
211+
)
212+
raise RuntimeError(msg)
213+
214+
self.eval()
215+
fmodel, params, buffers = combine_state_for_ensemble(self.estimators_)
216+
return fmodel, params, buffers
201217

202218

203219
class BaseTreeEnsemble(BaseModule):

torchensemble/_constants.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,20 @@
173173
"""
174174

175175

176+
__vectorize_doc = """
177+
Return the vectorization result of the ensemble using functorch.
178+
179+
Returns
180+
-------
181+
fmodel : FunctionalModuleWithBuffers
182+
Functional version of one of the models in the ensemble.
183+
params : tuple
184+
Tuple of stacked model parameters in the ensemble.
185+
buffers : tuple
186+
Tuple of buffers, empty if not exists.
187+
"""
188+
189+
176190
__classification_forward_doc = """
177191
Parameters
178192
----------

torchensemble/voting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def predict(self, *x):
309309

310310

311311
@torchensemble_model_doc(
312-
"""Implementation on the NeuralForestClassifier.""", "tree_ensmeble_model"
312+
"""Implementation on the NeuralForestClassifier.""", "tree_ensemble_model"
313313
)
314314
class NeuralForestClassifier(BaseTreeEnsemble, VotingClassifier):
315315
def __init__(self, voting_strategy="soft", **kwargs):
@@ -561,7 +561,7 @@ def predict(self, *x):
561561

562562

563563
@torchensemble_model_doc(
564-
"""Implementation on the NeuralForestRegressor.""", "tree_ensmeble_model"
564+
"""Implementation on the NeuralForestRegressor.""", "tree_ensemble_model"
565565
)
566566
class NeuralForestRegressor(BaseTreeEnsemble, VotingRegressor):
567567
@torchensemble_model_doc(

0 commit comments

Comments
 (0)