88import pydantic as pdt
99from sklearn import compose , ensemble , pipeline , preprocessing
1010
11- from bikes import schemas
11+ from bikes . core import schemas
1212
1313# %% TYPES
1414
15+ # Model params
1516ParamKey = str
1617ParamValue = T .Any
1718Params = dict [ParamKey , ParamValue ]
1819
1920# %% MODELS
2021
2122
22- class Model (abc .ABC , pdt .BaseModel , strict = True ):
23- """Base class for a model.
23+ class Model (abc .ABC , pdt .BaseModel , strict = True , frozen = False , extra = "forbid" ):
24+ """Base class for a project model.
2425
2526 Use a model to adapt AI/ML frameworks.
2627 e.g., to swap easily one model with another.
@@ -32,7 +33,7 @@ def get_params(self, deep: bool = True) -> Params:
3233 """Get the model params.
3334
3435 Args:
35- deep (bool, optional): ignored. Defaults to True.
36+ deep (bool, optional): ignored.
3637
3738 Returns:
3839 Params: internal model parameters.
@@ -62,7 +63,7 @@ def fit(self, inputs: schemas.Inputs, targets: schemas.Targets) -> T.Self:
6263 targets (schemas.Targets): model training targets.
6364
6465 Returns:
65- Model : instance of the model.
66+ T.Self : instance of the model.
6667 """
6768
6869 @abc .abstractmethod
@@ -76,11 +77,22 @@ def predict(self, inputs: schemas.Inputs) -> schemas.Outputs:
7677 schemas.Outputs: model prediction outputs.
7778 """
7879
80+ def get_internal_model (self ) -> T .Any :
81+ """Return the internal model in the object.
82+
83+ Raises:
84+ NotImplementedError: method not implemented.
85+
86+ Returns:
87+ T.Any: any internal model (either empty or fitted).
88+ """
89+ raise NotImplementedError ()
90+
7991
8092class BaselineSklearnModel (Model ):
81- """Simple baseline model with sklearn .
93+ """Simple baseline model based on scikit-learn .
8294
83- Attributes :
95+ Parameters :
8496 max_depth (int): maximum depth of the random forest.
8597 n_estimators (int): number of estimators in the random forest.
8698 random_state (int, optional): random state of the machine learning pipeline.
@@ -142,12 +154,19 @@ def fit(self, inputs: schemas.Inputs, targets: schemas.Targets) -> "BaselineSkle
142154
143155 @T .override
144156 def predict (self , inputs : schemas .Inputs ) -> schemas .Outputs :
145- assert self ._pipeline is not None , "Model should be fitted first!"
146- prediction = self . _pipeline . predict (inputs ) # return an np.ndarray
157+ model = self .get_internal_model ()
158+ prediction = model . predict (inputs )
147159 outputs = schemas .Outputs (
148160 {schemas .OutputsSchema .prediction : prediction }, index = inputs .index
149161 )
150162 return outputs
151163
164+ @T .override
165+ def get_internal_model (self ) -> pipeline .Pipeline :
166+ model = self ._pipeline
167+ if model is None :
168+ raise ValueError ("Model is not fitted yet!" )
169+ return model
170+
152171
153172ModelKind = BaselineSklearnModel
0 commit comments