@@ -166,36 +166,60 @@ def create_api(
166166 if not inspect .isclass (predictor ):
167167 raise ValueError ("`predictor` parameter must be a class definition" )
168168
169- with open (project_dir / "predictor.pickle" , "wb" ) as pickle_file :
170- dill .dump (predictor , pickle_file )
171- if api_spec .get ("predictor" ) is None :
172- api_spec ["predictor" ] = {}
169+ impl_rel_path = self ._save_impl (predictor , project_dir , "predictor" )
170+ if api_spec .get ("predictor" ) is None :
171+ api_spec ["predictor" ] = {}
173172
174- if predictor .__name__ == "PythonPredictor" :
175- predictor_type = "python"
176- if predictor .__name__ == "TensorFlowPredictor" :
177- predictor_type = "tensorflow"
178- if predictor .__name__ == "ONNXPredictor" :
179- predictor_type = "onnx"
173+ if predictor .__name__ == "PythonPredictor" :
174+ predictor_type = "python"
175+ if predictor .__name__ == "TensorFlowPredictor" :
176+ predictor_type = "tensorflow"
177+ if predictor .__name__ == "ONNXPredictor" :
178+ predictor_type = "onnx"
180179
181- api_spec ["predictor" ]["path" ] = "predictor.pickle"
182- api_spec ["predictor" ]["type" ] = predictor_type
180+ api_spec ["predictor" ]["path" ] = impl_rel_path
181+ api_spec ["predictor" ]["type" ] = predictor_type
183182
184183 if api_kind == "TaskAPI" :
185184 if not callable (task ):
186185 raise ValueError (
187186 "`task` parameter must be a callable (e.g. a function definition or a class definition called `Task` with a `__call__` method implemented"
188187 )
189- with open ( project_dir / "task.pickle" , "wb" ) as pickle_file :
190- dill . dump (task , pickle_file )
191- if api_spec .get ("definition" ) is None :
192- api_spec ["definition" ] = {}
193- api_spec ["definition" ]["path" ] = "task.pickle"
188+
189+ impl_rel_path = self . _save_impl (task , project_dir , "task" )
190+ if api_spec .get ("definition" ) is None :
191+ api_spec ["definition" ] = {}
192+ api_spec ["definition" ]["path" ] = impl_rel_path
194193
195194 with open (cortex_yaml_path , "w" ) as f :
196195 yaml .dump ([api_spec ], f ) # write a list
197196 return self ._deploy (cortex_yaml_path , force = force , wait = wait )
198197
198+ def _save_impl (self , impl , project_dir : Path , filename : str ) -> str :
199+ import __main__ as main
200+
201+ is_interactive = not hasattr (main , "__file__" )
202+
203+ if is_interactive and impl .__module__ == "__main__" :
204+ # class is defined in a REPL (e.g. jupyter)
205+ filename += ".pickle"
206+ with open (project_dir / filename , "wb" ) as pickle_file :
207+
208+ dill .dump (impl , pickle_file )
209+ return filename
210+
211+ filename += ".py"
212+ if not is_interactive and impl .__module__ == "__main__" :
213+ # class is defined in the same file as main
214+ with open (project_dir / filename , "w" ) as f :
215+ f .write (dill .source .importable (impl , source = True ))
216+ return filename
217+
218+ if not is_interactive and not impl .__module__ == "__main__" :
219+ # class is imported, copy file containing the class
220+ shutil .copy (inspect .getfile (impl ), project_dir / filename )
221+ return filename
222+
199223 def _deploy (
200224 self ,
201225 config_file : str ,
0 commit comments