Skip to content

Commit 728c816

Browse files
gustavocidornelaswhoseoyster
authored andcommitted
added torch example
1 parent a88caec commit 728c816

File tree

1 file changed

+16
-23
lines changed

1 file changed

+16
-23
lines changed

unboxapi/models.py

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -239,28 +239,20 @@ def pytorch(self) -> str:
239239
Examples
240240
--------
241241
.. seealso::
242-
Our `sample notebooks <https://github.com/unboxai/unboxapi-python-client/tree/cid/api-docs-improvements/examples/text-classification/tensorflow>`_ and
242+
Our `sample notebooks <https://github.com/unboxai/unboxapi-python-client/tree/cid/api-docs-improvements/examples/text-classification/pytorch>`_ and
243243
`tutorials <https://unbox.readme.io/docs/overview-of-tutorial-tracks>`_.
244244
245-
Let's say you have trained a ``torch`` model that performs text classification. Your training pipeline might look like this:
246-
247-
>>> import tensorflow as tf
248-
>>> from tensorflow import keras
249-
>>>
250-
>>> model.compile(optimizer='adam',
251-
... loss='binary_crossentropy',
252-
... metrics=['accuracy'])
253-
>>>
254-
>>> model.fit(X_train, y_train, epochs=30, batch_size=512)
245+
Let's say you have trained a ``torch`` model that performs text classification.
255246
256247
You must next define a ``predict_proba`` function that adheres to the signature defined below.
257248
258249
**If your task type is text classification...**
259250
260251
>>> def predict_proba(model, text_list: List[str], **kwargs):
261-
... # Optional pre-processing of text_list
262-
... preds = model(text_list)
263-
... # Optional re-weighting of preds
252+
... with torch.no_grad():
253+
... # Optional pre-processing of text_list
254+
... preds = model(text_list)
255+
... # Optional re-weighting of preds
264256
... return preds
265257
266258
The ``model`` arg must be the actual trained model object, and the ``text_list`` arg must be a list of
@@ -269,9 +261,10 @@ def pytorch(self) -> str:
269261
**If your task type is tabular classification...**
270262
271263
>>> def predict_proba(model, input_features: np.ndarray, **kwargs):
272-
... # Optional pre-processing of input_features
273-
... preds = model(input_features)
274-
... # Optional re-weighting of preds
264+
... with torch.no_grad():
265+
... # Optional pre-processing of input_features
266+
... preds = model(input_features)
267+
... # Optional re-weighting of preds
275268
... return preds
276269
277270
The ``model`` arg must be the actual trained model object, and the ``input_features`` arg must be a 2D numpy array
@@ -292,11 +285,11 @@ def pytorch(self) -> str:
292285
>>> model = client.add_model(
293286
... function=predict_proba,
294287
... model=model,
295-
... model_type=ModelType.tensorflow,
288+
... model_type=ModelType.pytorch,
296289
... task_type=TaskType.TextClassification,
297290
... class_names=['Negative', 'Positive'],
298-
... name='My Tensorflow model',
299-
... description='this is my tensorflow model',
291+
... name='My Torch model',
292+
... description='this is my torch model',
300293
... )
301294
>>> model.to_dict()
302295
@@ -305,11 +298,11 @@ def pytorch(self) -> str:
305298
>>> model = client.add_model(
306299
... function=predict_proba,
307300
... model=model,
308-
... model_type=ModelType.tensorflow,
301+
... model_type=ModelType.pytorch,
309302
... task_type=TaskType.TabularClassification,
310303
... class_names=['Exited', 'Retained'],
311-
... name='My Tensorflow model',
312-
... description='this is my tensorflow model',
304+
... name='My Torch model',
305+
... description='this is my torch model',
313306
... )
314307
>>> model.to_dict()
315308
"""

0 commit comments

Comments
 (0)