Skip to content

Commit 78fb1e7

Browse files
authored
ODSC-39738: allow for attaching customized scorepy in prepare (#133)
2 parents 5887287 + d857985 commit 78fb1e7

File tree

4 files changed

+67
-11
lines changed

4 files changed

+67
-11
lines changed

ads/model/generic_model.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,7 @@ def prepare(
789789
ignore_pending_changes: bool = True,
790790
max_col_num: int = DATA_SCHEMA_MAX_COL_NUM,
791791
ignore_conda_error: bool = False,
792+
score_py_uri: str = None,
792793
**kwargs: Dict,
793794
) -> "GenericModel":
794795
"""Prepare and save the score.py, serialized model and runtime.yaml file.
@@ -841,6 +842,10 @@ def prepare(
841842
number of features(columns).
842843
ignore_conda_error: (bool, optional). Defaults to False.
843844
Parameter to ignore error when collecting conda information.
845+
score_py_uri: (str, optional). Defaults to None.
846+
The uri of the customized score.py, which can be local path or OCI object storage URI.
847+
When provide with this attibute, the `score.py` will not be auto generated, and the
848+
provided `score.py` will be added into artifact_dir.
844849
kwargs:
845850
impute_values: (dict, optional).
846851
The dictionary where the key is the column index(or names is accepted
@@ -1001,13 +1006,22 @@ def prepare(
10011006
jinja_template_filename = (
10021007
"score-pkl" if self._serialize else "score_generic"
10031008
)
1004-
self.model_artifact.prepare_score_py(
1005-
jinja_template_filename=jinja_template_filename,
1006-
model_file_name=self.model_file_name,
1007-
data_deserializer=self.model_input_serializer.name,
1008-
model_serializer=self.model_save_serializer.name,
1009-
**{**kwargs, **self._score_args},
1010-
)
1009+
1010+
if score_py_uri:
1011+
utils.copy_file(
1012+
uri_src=score_py_uri,
1013+
uri_dst=os.path.join(self.artifact_dir, "score.py"),
1014+
force_overwrite=force_overwrite,
1015+
auth=self.auth
1016+
)
1017+
else:
1018+
self.model_artifact.prepare_score_py(
1019+
jinja_template_filename=jinja_template_filename,
1020+
model_file_name=self.model_file_name,
1021+
data_deserializer=self.model_input_serializer.name,
1022+
model_serializer=self.model_save_serializer.name,
1023+
**{**kwargs, **self._score_args},
1024+
)
10111025

10121026
self._summary_status.update_status(
10131027
detail="Generated score.py", status=ModelState.DONE.value

docs/source/user_guide/model_registration/model_artifact.rst

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ Auto generation of ``score.py`` with framework specific code for loading models
3030

3131
To accomodate for other frameworks that are unknown to ADS, a template code for ``score.py`` is generated in the provided artificat directory location.
3232

33+
3334
Prepare the Model Artifact
3435
--------------------------
3536

@@ -98,8 +99,25 @@ ADS automatically captures:
9899
* ``UseCaseType`` in ``metadata_taxonomy`` cannot be automatically populated. One way to populate the use case is to pass ``use_case_type`` to the ``prepare`` method.
99100
* Model introspection is automatically triggered.
100101

101-
.. include:: _template/score.rst
102+
Prepare with custom ``score.py``
103+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
104+
105+
.. versionadded:: 2.8.4
102106

107+
You could provide the location of your own ``score.py`` by ``score_py_uri`` in :py:meth:`~ads.model.GenericModel.prepare`.
108+
The provided ``score.py`` will be added into model artifact.
109+
110+
.. code-block:: python3
111+
112+
tf_model.prepare(
113+
inference_conda_env="generalml_p38_cpu_v1",
114+
use_case_type=UseCaseType.MULTINOMIAL_CLASSIFICATION,
115+
X_sample=trainx,
116+
y_sample=trainy,
117+
score_py_uri="/path/to/score.py"
118+
)
119+
120+
.. include:: _template/score.rst
103121

104122
Model Introspection
105123
-------------------
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# THIS IS A CUSTOM SCORE.PY
2+
3+
model_name = "model.pkl"
4+
5+
6+
def load_model(model_file_name=model_name):
7+
return model_file_name
8+
9+
10+
def predict(data, model=load_model()):
11+
return {"prediction": "This is a custom score.py."}

tests/unitary/with_extras/model/test_generic_model.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,14 +168,16 @@
168168
"training_script": None,
169169
}
170170

171+
INFERENCE_CONDA_ENV = "oci://bucket@namespace/<path_to_service_pack>"
172+
TRAINING_CONDA_ENV = "oci://bucket@namespace/<path_to_service_pack>"
173+
171174

172175
class TestEstimator:
173176
def predict(self, x):
174177
return x**2
175178

176179

177180
class TestGenericModel:
178-
179181
iris = load_iris()
180182
X, y = iris.data, iris.target
181183
X_train, X_test, y_train, y_test = train_test_split(X, y)
@@ -297,6 +299,19 @@ def test_prepare_both_conda_env(self, mock_signer):
297299
== "3.7"
298300
)
299301

302+
@patch("ads.common.auth.default_signer")
303+
def test_prepare_with_custom_scorepy(self, mock_signer):
304+
"""Test prepare a trained model with custom score.py."""
305+
self.generic_model.prepare(
306+
INFERENCE_CONDA_ENV,
307+
model_file_name="fake_model_name",
308+
score_py_uri=f"{os.path.dirname(os.path.abspath(__file__))}/test_files/custom_score.py",
309+
)
310+
assert os.path.exists(os.path.join("fake_folder", "score.py"))
311+
312+
prediction = self.generic_model.verify(data="test")["prediction"]
313+
assert prediction == "This is a custom score.py."
314+
300315
@patch("ads.common.auth.default_signer")
301316
def test_verify_without_reload(self, mock_signer):
302317
"""Test verify input data without reload artifacts."""
@@ -795,7 +810,6 @@ def test_predict_success__serialize_input(self, mock_client, mock_signer):
795810
with patch.object(
796811
GenericModel, "get_data_serializer"
797812
) as mock_get_data_serializer:
798-
799813
mock_get_data_serializer.return_value.data = df.to_json()
800814
mock_state.return_value = ModelDeploymentState.ACTIVE
801815
with patch.object(ModelDeployment, "predict") as mock_predict:
@@ -1782,7 +1796,6 @@ def test_upload_artifact_fail(self):
17821796
def test_upload_artifact_success(self):
17831797
"""Tests uploading model artifacts to the provided `uri`."""
17841798
with tempfile.TemporaryDirectory() as tmp_dir:
1785-
17861799
# copy test artifacts to the temp folder
17871800
shutil.copytree(
17881801
os.path.join(self.curr_dir, "test_files/valid_model_artifacts"),

0 commit comments

Comments
 (0)