Skip to content

Commit 6368348

Browse files
committed
added tests
1 parent 6eb17f6 commit 6368348

File tree

2 files changed

+89
-0
lines changed

2 files changed

+89
-0
lines changed
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# score.py 1.0 generated by ADS 2.8.3-test-custom-scorepy on 20230331_192511
2+
# THIS IS A CUSTOM SCORE.PY
3+
import json
4+
import os
5+
import cloudpickle
6+
import pandas as pd
7+
import numpy as np
8+
from functools import lru_cache
9+
10+
11+
model_name = 'model.pkl'
12+
13+
14+
"""
15+
Inference script. This script is used for prediction by scoring server when schema is known.
16+
"""
17+
18+
19+
@lru_cache(maxsize=10)
20+
def load_model(model_file_name=model_name):
21+
"""
22+
Loads model from the serialized format
23+
24+
Returns
25+
-------
26+
model: a model instance on which predict API can be invoked
27+
"""
28+
return model_file_name
29+
30+
def pre_inference(data):
31+
"""
32+
Preprocess data
33+
34+
Parameters
35+
----------
36+
data: Data format as expected by the predict API of the core estimator.
37+
38+
Returns
39+
-------
40+
data: Data format after any processing.
41+
42+
"""
43+
return data
44+
45+
46+
def post_inference(yhat):
47+
"""
48+
Post-process the model results
49+
50+
Parameters
51+
----------
52+
yhat: Data format after calling model.predict.
53+
54+
Returns
55+
-------
56+
yhat: Data format after any processing.
57+
58+
"""
59+
return yhat
60+
61+
def predict(data, model=load_model()):
62+
"""
63+
Returns prediction given the model and data to predict
64+
65+
Parameters
66+
----------
67+
model: Model instance returned by load_model API
68+
data: Data format as expected by the predict API of the core estimator. For eg. in case of sckit models it could be numpy array/List of list/Pandas DataFrame
69+
70+
Returns
71+
-------
72+
predictions: Output from scoring server
73+
Format: {'prediction': output from model.predict method}
74+
75+
"""
76+
return {'prediction': "This is a custom score.py."}

tests/unitary/with_extras/model/test_generic_model.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,19 @@ def test_prepare_both_conda_env(self, mock_signer):
297297
== "3.7"
298298
)
299299

300+
@patch("ads.common.auth.default_signer")
301+
def test_prepare_with_custom_scorepy(self, mock_signer):
302+
"""Test prepare a trained model with custom score.py."""
303+
self.generic_model.prepare(
304+
"oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1",
305+
model_file_name="fake_model_name",
306+
score_py_uri=f"{os.path.dirname(os.path.abspath(__file__))}/test_files/custom_score.py"
307+
)
308+
assert os.path.exists(os.path.join("fake_folder", "score.py"))
309+
310+
prediction = self.generic_model.verify(data="test")["prediction"]
311+
assert prediction == "This is a custom score.py."
312+
300313
@patch("ads.common.auth.default_signer")
301314
def test_verify_without_reload(self, mock_signer):
302315
"""Test verify input data without reload artifacts."""

0 commit comments

Comments
 (0)