Skip to content

Commit ec4d66b

Browse files
authored
Improve Predictor export strategy in the python client (based on runtime and source) (#1883)
1 parent c61e568 commit ec4d66b

File tree

1 file changed

+41
-17
lines changed

1 file changed

+41
-17
lines changed

pkg/cortex/client/cortex/client.py

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -166,36 +166,60 @@ def create_api(
166166
if not inspect.isclass(predictor):
167167
raise ValueError("`predictor` parameter must be a class definition")
168168

169-
with open(project_dir / "predictor.pickle", "wb") as pickle_file:
170-
dill.dump(predictor, pickle_file)
171-
if api_spec.get("predictor") is None:
172-
api_spec["predictor"] = {}
169+
impl_rel_path = self._save_impl(predictor, project_dir, "predictor")
170+
if api_spec.get("predictor") is None:
171+
api_spec["predictor"] = {}
173172

174-
if predictor.__name__ == "PythonPredictor":
175-
predictor_type = "python"
176-
if predictor.__name__ == "TensorFlowPredictor":
177-
predictor_type = "tensorflow"
178-
if predictor.__name__ == "ONNXPredictor":
179-
predictor_type = "onnx"
173+
if predictor.__name__ == "PythonPredictor":
174+
predictor_type = "python"
175+
if predictor.__name__ == "TensorFlowPredictor":
176+
predictor_type = "tensorflow"
177+
if predictor.__name__ == "ONNXPredictor":
178+
predictor_type = "onnx"
180179

181-
api_spec["predictor"]["path"] = "predictor.pickle"
182-
api_spec["predictor"]["type"] = predictor_type
180+
api_spec["predictor"]["path"] = impl_rel_path
181+
api_spec["predictor"]["type"] = predictor_type
183182

184183
if api_kind == "TaskAPI":
185184
if not callable(task):
186185
raise ValueError(
187186
"`task` parameter must be a callable (e.g. a function definition or a class definition called `Task` with a `__call__` method implemented"
188187
)
189-
with open(project_dir / "task.pickle", "wb") as pickle_file:
190-
dill.dump(task, pickle_file)
191-
if api_spec.get("definition") is None:
192-
api_spec["definition"] = {}
193-
api_spec["definition"]["path"] = "task.pickle"
188+
189+
impl_rel_path = self._save_impl(task, project_dir, "task")
190+
if api_spec.get("definition") is None:
191+
api_spec["definition"] = {}
192+
api_spec["definition"]["path"] = impl_rel_path
194193

195194
with open(cortex_yaml_path, "w") as f:
196195
yaml.dump([api_spec], f) # write a list
197196
return self._deploy(cortex_yaml_path, force=force, wait=wait)
198197

198+
def _save_impl(self, impl, project_dir: Path, filename: str) -> str:
199+
import __main__ as main
200+
201+
is_interactive = not hasattr(main, "__file__")
202+
203+
if is_interactive and impl.__module__ == "__main__":
204+
# class is defined in a REPL (e.g. jupyter)
205+
filename += ".pickle"
206+
with open(project_dir / filename, "wb") as pickle_file:
207+
208+
dill.dump(impl, pickle_file)
209+
return filename
210+
211+
filename += ".py"
212+
if not is_interactive and impl.__module__ == "__main__":
213+
# class is defined in the same file as main
214+
with open(project_dir / filename, "w") as f:
215+
f.write(dill.source.importable(impl, source=True))
216+
return filename
217+
218+
if not is_interactive and not impl.__module__ == "__main__":
219+
# class is imported, copy file containing the class
220+
shutil.copy(inspect.getfile(impl), project_dir / filename)
221+
return filename
222+
199223
def _deploy(
200224
self,
201225
config_file: str,

0 commit comments

Comments
 (0)