@@ -63,7 +63,7 @@ def get_model_info(model, X, y=None):
6363
6464 # Most PyTorch models are actually subclasses of torch.nn.Module, so checking module
6565 # name alone is not sufficient.
66- elif torch and isinstance (model , torch .nn .Module ):
66+ if torch and isinstance (model , torch .nn .Module ):
6767 return PyTorchModelInfo (model , X , y )
6868
6969 raise ValueError (f"Unrecognized model type { type (model )} received." )
@@ -200,7 +200,8 @@ class OnnxModelInfo(ModelInfo):
200200 def __init__ (self , model , X , y = None ):
201201 if onnx is None :
202202 raise RuntimeError (
203- "The onnx package must be installed to work with ONNX models. Please `pip install onnx`."
203+ "The onnx package must be installed to work with ONNX models. "
204+ "Please `pip install onnx`."
204205 )
205206
206207 self ._model = model
@@ -214,38 +215,19 @@ def __init__(self, model, X, y=None):
214215
215216 if len (inputs ) > 1 :
216217 warnings .warn (
217- f"The ONNX model has { len (inputs )} inputs but only the first input will be captured in Model Manager."
218+ f"The ONNX model has { len (inputs )} inputs but only the first input "
219+ f"will be captured in Model Manager."
218220 )
219221
220222 if len (outputs ) > 1 :
221223 warnings .warn (
222- f"The ONNX model has { len (outputs )} outputs but only the first input will be captured in Model Manager."
224+ f"The ONNX model has { len (outputs )} outputs but only the first output "
225+ f"will be captured in Model Manager."
223226 )
224227
225228 self ._X_df = inputs [0 ]
226229 self ._y_df = outputs [0 ]
227230
228- # initializer (static params)
229-
230- # for field in model.ListFields():
231- # doc_string
232- # domain
233- # metadata_props
234- # model_author
235- # model_license
236- # model_version
237- # producer_name
238- # producer_version
239- # training_info
240-
241- # irVersion
242- # producerName
243- # producerVersion
244- # opsetImport
245-
246- # # list of (FieldDescriptor, value)
247- # fields = model.ListFields()
248-
249231 @staticmethod
250232 def _tensor_to_dataframe (tensor ):
251233 """
@@ -272,7 +254,7 @@ def _tensor_to_dataframe(tensor):
272254 name = tensor .get ("name" , "Var" )
273255 type_ = tensor ["type" ]
274256
275- if not "tensorType" in type_ :
257+ if "tensorType" not in type_ :
276258 raise ValueError (f"Received an unexpected ONNX input type: { type_ } ." )
277259
278260 dtype = onnx .helper .tensor_dtype_to_np_dtype (type_ ["tensorType" ]["elemType" ])
@@ -374,8 +356,6 @@ def __init__(self, model, X, y=None):
374356 raise ValueError (
375357 f"Expected input data to be a numpy array or PyTorch tensor, received { type (X )} ."
376358 )
377- # if X.ndim != 2:
378- # raise ValueError(f"Expected input date with shape (n_samples, n_dim), received shape {X.shape}.")
379359
380360 # Ensure each input is a PyTorch Tensor
381361 X = tuple (x if isinstance (x , torch .Tensor ) else torch .tensor (x ) for x in X )
@@ -395,8 +375,6 @@ def __init__(self, model, X, y=None):
395375 )
396376
397377 self ._model = model
398-
399- # TODO: convert X and y to DF with arbitrary names
400378 self ._X = X
401379 self ._y = y
402380
0 commit comments