
Commit e037c6b

resolved conflict

2 parents 4206cc6 + 78fb1e7

File tree: 5 files changed (+122, -8 lines)

Lines changed: 58 additions & 0 deletions

@@ -0,0 +1,58 @@
+name: Feature Request
+description: Feature and enhancement proposals for the oracle-ads library
+title: "[FR]: "
+labels: [Task, Backlog]
+assignees:
+  - octocat
+body:
+  - type: markdown
+    attributes:
+      value: |
+        Before proceeding, please review the [Contributing to this repository](https://github.com/oracle/accelerated-data-science/blob/main/CONTRIBUTING.md) and the [Code of Conduct](https://github.com/oracle/.github/blob/main/CODE_OF_CONDUCT.md).
+
+        ---
+
+        Thank you for submitting a feature request.
+  - type: dropdown
+    id: contribution
+    attributes:
+      label: Willingness to contribute
+      description: Would you or another member of your organization be willing to contribute an implementation of this feature?
+      options:
+        - Yes. I can contribute this feature independently.
+        - Yes. I would be willing to contribute this feature with guidance from the oracle-ads team.
+        - No. I cannot contribute this feature at this time.
+    validations:
+      required: true
+  - type: textarea
+    attributes:
+      label: Proposal Summary
+      description: |
+        In a few sentences, provide a clear, high-level description of the feature request.
+    validations:
+      required: true
+  - type: textarea
+    attributes:
+      label: Motivation
+      description: |
+        - What is the use case for this feature?
+        - Why is this use case valuable to support for OCI DataScience users in general?
+        - Why is this use case valuable to support for your project(s) or organization?
+        - Why is it currently difficult to achieve this use case?
+      value: |
+        > #### What is the use case for this feature?
+
+        > #### Why is this use case valuable to support for OCI DataScience users in general?
+
+        > #### Why is this use case valuable to support for your project(s) or organization?
+
+        > #### Why is it currently difficult to achieve this use case?
+    validations:
+      required: true
+  - type: textarea
+    attributes:
+      label: Details
+      description: |
+        Use this section to include any additional information about the feature. If you have a proposal for how to implement this feature, please include it here. For implementation guidelines, please refer to the [Contributing to this repository](https://github.com/oracle/accelerated-data-science/blob/main/CONTRIBUTING.md).
+    validations:
+      required: false
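
As a side note, the issue form above can be sanity-checked locally before it is committed. The snippet below is a minimal sketch, not part of this commit: it assumes PyYAML is installed, and the file path is a placeholder since the template's location is not shown in this diff (GitHub issue forms are read from .github/ISSUE_TEMPLATE/).

# Minimal local sanity check for the issue-form YAML (illustrative only).
# The path below is a placeholder, not taken from this commit.
import yaml  # PyYAML

with open(".github/ISSUE_TEMPLATE/feature-request.yml") as f:
    form = yaml.safe_load(f)

# GitHub issue forms require `name`, `description`, and a non-empty `body` list.
assert {"name", "description", "body"} <= set(form)
assert isinstance(form["body"], list) and form["body"]
print(f"{form['name']}: {len(form['body'])} body blocks")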

ads/model/generic_model.py

Lines changed: 21 additions & 7 deletions

@@ -789,6 +789,7 @@ def prepare(
         ignore_pending_changes: bool = True,
         max_col_num: int = DATA_SCHEMA_MAX_COL_NUM,
         ignore_conda_error: bool = False,
+        score_py_uri: str = None,
         **kwargs: Dict,
     ) -> "GenericModel":
         """Prepare and save the score.py, serialized model and runtime.yaml file.
@@ -841,6 +842,10 @@ def prepare(
             number of features (columns).
         ignore_conda_error: (bool, optional). Defaults to False.
            Parameter to ignore error when collecting conda information.
+        score_py_uri: (str, optional). Defaults to None.
+            The URI of a customized score.py, which can be a local path or an OCI
+            Object Storage URI. When this attribute is provided, score.py is not
+            auto-generated; the provided score.py is copied into artifact_dir.
        kwargs:
            impute_values: (dict, optional).
                The dictionary where the key is the column index (or names is accepted
@@ -1001,13 +1006,22 @@
        jinja_template_filename = (
            "score-pkl" if self._serialize else "score_generic"
        )
-       self.model_artifact.prepare_score_py(
-           jinja_template_filename=jinja_template_filename,
-           model_file_name=self.model_file_name,
-           data_deserializer=self.model_input_serializer.name,
-           model_serializer=self.model_save_serializer.name,
-           **{**kwargs, **self._score_args},
-       )
+
+       if score_py_uri:
+           utils.copy_file(
+               uri_src=score_py_uri,
+               uri_dst=os.path.join(self.artifact_dir, "score.py"),
+               force_overwrite=force_overwrite,
+               auth=self.auth
+           )
+       else:
+           self.model_artifact.prepare_score_py(
+               jinja_template_filename=jinja_template_filename,
+               model_file_name=self.model_file_name,
+               data_deserializer=self.model_input_serializer.name,
+               model_serializer=self.model_save_serializer.name,
+               **{**kwargs, **self._score_args},
+           )

        self._summary_status.update_status(
            detail="Generated score.py", status=ModelState.DONE.value

docs/source/user_guide/model_registration/model_artifact.rst

Lines changed: 19 additions & 1 deletion

@@ -30,6 +30,7 @@ Auto generation of ``score.py`` with framework specific code for loading models
 
 To accommodate other frameworks that are unknown to ADS, a template ``score.py`` is generated in the provided artifact directory location.
 
+
 Prepare the Model Artifact
 --------------------------
 
@@ -98,8 +99,25 @@ ADS automatically captures:
 * ``UseCaseType`` in ``metadata_taxonomy`` cannot be automatically populated. One way to populate the use case is to pass ``use_case_type`` to the ``prepare`` method.
 * Model introspection is automatically triggered.
 
-.. include:: _template/score.rst
+Prepare with custom ``score.py``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. versionadded:: 2.8.4
 
+You can provide your own ``score.py`` by passing its location through ``score_py_uri`` in :py:meth:`~ads.model.GenericModel.prepare`.
+The provided ``score.py`` is added to the model artifact.
+
+.. code-block:: python3
+
+    tf_model.prepare(
+        inference_conda_env="generalml_p38_cpu_v1",
+        use_case_type=UseCaseType.MULTINOMIAL_CLASSIFICATION,
+        X_sample=trainx,
+        y_sample=trainy,
+        score_py_uri="/path/to/score.py"
+    )
+
+.. include:: _template/score.rst
 
 Model Introspection
 -------------------
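
To connect the documentation change to the rest of the commit, here is a hedged end-to-end sketch of the documented option. The `Square` estimator, artifact directory, and file paths are placeholders, not values from this repository.

# Hedged end-to-end sketch: a toy estimator, prepare() with a custom score.py,
# and a local verify() call. Names and paths are placeholders.
from ads.model.generic_model import GenericModel


class Square:  # stand-in estimator, not from this repository
    def predict(self, x):
        return [item ** 2 for item in x]


model = GenericModel(estimator=Square(), artifact_dir="./artifact")
model.prepare(
    inference_conda_env="generalml_p38_cpu_v1",
    model_file_name="model.pkl",
    score_py_uri="/path/to/score.py",  # local path or OCI Object Storage URI
    force_overwrite=True,
)
print(model.verify(data="test"))  # routed through the custom predict()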
tests/unitary/with_extras/model/test_files/custom_score.py

Lines changed: 11 additions & 0 deletions

@@ -0,0 +1,11 @@
+# THIS IS A CUSTOM SCORE.PY
+
+model_name = "model.pkl"
+
+
+def load_model(model_file_name=model_name):
+    return model_file_name
+
+
+def predict(data, model=load_model()):
+    return {"prediction": "This is a custom score.py."}

tests/unitary/with_extras/model/test_generic_model.py

Lines changed: 13 additions & 0 deletions

@@ -313,6 +313,19 @@ def test_prepare_both_conda_env(self, mock_signer):
             == "3.7"
         )
 
+    @patch("ads.common.auth.default_signer")
+    def test_prepare_with_custom_scorepy(self, mock_signer):
+        """Test preparing a trained model with a custom score.py."""
+        self.generic_model.prepare(
+            INFERENCE_CONDA_ENV,
+            model_file_name="fake_model_name",
+            score_py_uri=f"{os.path.dirname(os.path.abspath(__file__))}/test_files/custom_score.py",
+        )
+        assert os.path.exists(os.path.join("fake_folder", "score.py"))
+
+        prediction = self.generic_model.verify(data="test")["prediction"]
+        assert prediction == "This is a custom score.py."
+
     @patch("ads.common.auth.default_signer")
     def test_verify_without_reload(self, mock_signer):
         """Test verify input data without reload artifacts."""
