Skip to content

Commit a88caec

Browse files
gustavocidornelaswhoseoyster
authored andcommitted
added keras example
1 parent 5747439 commit a88caec

File tree

1 file changed

+78
-1
lines changed

1 file changed

+78
-1
lines changed

unboxapi/models.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,84 @@ def transformers(self) -> str:
468468

469469
@property
470470
def keras(self) -> str:
471-
"""For models built with `Keras <https://keras.io/>`_."""
471+
"""For models built with `Keras <https://keras.io/>`_.
472+
473+
Examples
474+
--------
475+
.. seealso::
476+
Our `sample notebooks <https://github.com/unboxai/unboxapi-python-client/tree/cid/api-docs-improvements/examples/text-classification/tensorflow>`_ and
477+
`tutorials <https://unbox.readme.io/docs/overview-of-tutorial-tracks>`_.
478+
479+
Let's say you have trained a ``keras`` binary classifier. Your training pipeline might look like this:
480+
481+
>>> from tensorflow import keras
482+
>>>
483+
>>> model.compile(optimizer='adam',
484+
... loss='binary_crossentropy',
485+
... metrics=['accuracy'])
486+
>>>
487+
>>> model.fit(X_train, y_train, epochs=30, batch_size=512)
488+
489+
You must next define a ``predict_proba`` function that adheres to the signature defined below.
490+
491+
**If your task type is text classification...**
492+
493+
>>> def predict_proba(model, text_list: List[str], **kwargs):
494+
... # Optional pre-processing of text_list
495+
... preds = model(text_list)
496+
... # Optional re-weighting of preds
497+
... return preds
498+
499+
The ``model`` arg must be the actual trained model object, and the ``text_list`` arg must be a list of
500+
strings.
501+
502+
**If your task type is tabular classification...**
503+
504+
>>> def predict_proba(model, input_features: np.ndarray, **kwargs):
505+
... # Optional pre-processing of input_features
506+
... preds = model(input_features)
507+
... # Optional re-weighting of preds
508+
... return preds
509+
510+
The ``model`` arg must be the actual trained model object, and the ``input_features`` arg must be a 2D numpy array
511+
containing a batch of features that will be passed to the model as inputs.
512+
513+
On both cases, you can optionally include other kwargs in the function, including tokenizers, variables, encoders etc.
514+
You simply pass those kwargs to the :meth:`unboxapi.UnboxClient.add_model` function call when you upload the model.
515+
516+
To upload the model to Unbox, first instantiate the client
517+
518+
>>> import unboxapi
519+
>>> client = unboxapi.UnboxClient('YOUR_API_KEY_HERE')
520+
521+
Now, you can use the ``client.add_model()`` method:
522+
523+
**If your task type is text classification...**
524+
525+
>>> model = client.add_model(
526+
... function=predict_proba,
527+
... model=model,
528+
... model_type=ModelType.keras,
529+
... task_type=TaskType.TextClassification,
530+
... class_names=['Negative', 'Positive'],
531+
... name='My Keras model',
532+
... description='this is my keras model',
533+
... )
534+
>>> model.to_dict()
535+
536+
**If your task type is tabular classification...**
537+
538+
>>> model = client.add_model(
539+
... function=predict_proba,
540+
... model=model,
541+
... model_type=ModelType.keras,
542+
... task_type=TaskType.TabularClassification,
543+
... class_names=['Exited', 'Retained'],
544+
... name='My keras model',
545+
... description='this is my keras model',
546+
... )
547+
>>> model.to_dict()
548+
"""
472549
return "KerasModelArtifact"
473550

474551
@property

0 commit comments

Comments
 (0)