@@ -239,28 +239,20 @@ def pytorch(self) -> str:
239239 Examples
240240 --------
241241 .. seealso::
242- Our `sample notebooks <https://github.com/unboxai/unboxapi-python-client/tree/cid/api-docs-improvements/examples/text-classification/tensorflow >`_ and
242+ Our `sample notebooks <https://github.com/unboxai/unboxapi-python-client/tree/cid/api-docs-improvements/examples/text-classification/pytorch >`_ and
243243 `tutorials <https://unbox.readme.io/docs/overview-of-tutorial-tracks>`_.
244244
245- Let's say you have trained a ``torch`` model that performs text classification. Your training pipeline might look like this:
246-
247- >>> import tensorflow as tf
248- >>> from tensorflow import keras
249- >>>
250- >>> model.compile(optimizer='adam',
251- ... loss='binary_crossentropy',
252- ... metrics=['accuracy'])
253- >>>
254- >>> model.fit(X_train, y_train, epochs=30, batch_size=512)
245+ Let's say you have trained a ``torch`` model that performs text classification.
255246
256247 You must next define a ``predict_proba`` function that adheres to the signature defined below.
257248
258249 **If your task type is text classification...**
259250
260251 >>> def predict_proba(model, text_list: List[str], **kwargs):
261- ... # Optional pre-processing of text_list
262- ... preds = model(text_list)
263- ... # Optional re-weighting of preds
252+ ... with torch.no_grad():
253+ ... # Optional pre-processing of text_list
254+ ... preds = model(text_list)
255+ ... # Optional re-weighting of preds
264256 ... return preds
265257
266258 The ``model`` arg must be the actual trained model object, and the ``text_list`` arg must be a list of
@@ -269,9 +261,10 @@ def pytorch(self) -> str:
269261 **If your task type is tabular classification...**
270262
271263 >>> def predict_proba(model, input_features: np.ndarray, **kwargs):
272- ... # Optional pre-processing of input_features
273- ... preds = model(input_features)
274- ... # Optional re-weighting of preds
264+ ... with torch.no_grad():
265+ ... # Optional pre-processing of input_features
266+ ... preds = model(input_features)
267+ ... # Optional re-weighting of preds
275268 ... return preds
276269
277270 The ``model`` arg must be the actual trained model object, and the ``input_features`` arg must be a 2D numpy array
@@ -292,11 +285,11 @@ def pytorch(self) -> str:
292285 >>> model = client.add_model(
293286 ... function=predict_proba,
294287 ... model=model,
295- ... model_type=ModelType.tensorflow ,
288+ ... model_type=ModelType.pytorch ,
296289 ... task_type=TaskType.TextClassification,
297290 ... class_names=['Negative', 'Positive'],
298- ... name='My Tensorflow model',
299- ... description='this is my tensorflow model',
291+ ... name='My Torch model',
292+ ... description='this is my torch model',
300293 ... )
301294 >>> model.to_dict()
302295
@@ -305,11 +298,11 @@ def pytorch(self) -> str:
305298 >>> model = client.add_model(
306299 ... function=predict_proba,
307300 ... model=model,
308- ... model_type=ModelType.tensorflow ,
301+ ... model_type=ModelType.pytorch ,
309302 ... task_type=TaskType.TabularClassification,
310303 ... class_names=['Exited', 'Retained'],
311- ... name='My Tensorflow model',
312- ... description='this is my tensorflow model',
304+ ... name='My Torch model',
305+ ... description='this is my torch model',
313306 ... )
314307 >>> model.to_dict()
315308 """
0 commit comments