Skip to content

Commit cbdfb65

Browse files
committed
added score_py_uri in prepare() to accept custom scorepy
1 parent 48b5254 commit cbdfb65

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

ads/model/generic_model.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,7 @@ def prepare(
789789
ignore_pending_changes: bool = True,
790790
max_col_num: int = DATA_SCHEMA_MAX_COL_NUM,
791791
ignore_conda_error: bool = False,
792+
score_py_uri: str = None,
792793
**kwargs: Dict,
793794
) -> "GenericModel":
794795
"""Prepare and save the score.py, serialized model and runtime.yaml file.
@@ -841,6 +842,8 @@ def prepare(
841842
number of features(columns).
842843
ignore_conda_error: (bool, optional). Defaults to False.
843844
Parameter to ignore error when collecting conda information.
845+
score_py_uri: (str, optional). Defaults to False.
846+
The uri of the customized score.py. The provided score.py will be added into artifact_dir.
844847
kwargs:
845848
impute_values: (dict, optional).
846849
The dictionary where the key is the column index(or names is accepted
@@ -1001,13 +1004,21 @@ def prepare(
10011004
jinja_template_filename = (
10021005
"score-pkl" if self._serialize else "score_generic"
10031006
)
1004-
self.model_artifact.prepare_score_py(
1005-
jinja_template_filename=jinja_template_filename,
1006-
model_file_name=self.model_file_name,
1007-
data_deserializer=self.model_input_serializer.name,
1008-
model_serializer=self.model_save_serializer.name,
1009-
**{**kwargs, **self._score_args},
1010-
)
1007+
1008+
if score_py_uri:
1009+
utils.copy_file(
1010+
uri_src=score_py_uri,
1011+
uri_dst=self.artifact_dir,
1012+
force_overwrite=force_overwrite,auth=self.auth
1013+
)
1014+
else:
1015+
self.model_artifact.prepare_score_py(
1016+
jinja_template_filename=jinja_template_filename,
1017+
model_file_name=self.model_file_name,
1018+
data_deserializer=self.model_input_serializer.name,
1019+
model_serializer=self.model_save_serializer.name,
1020+
**{**kwargs, **self._score_args},
1021+
)
10111022

10121023
self._summary_status.update_status(
10131024
detail="Generated score.py", status=ModelState.DONE.value

0 commit comments

Comments
 (0)