Skip to content

Commit 4842b96

Browse files
1vndeliahu
authored andcommitted
Add Iris TensorFlow example (#208)
1 parent ed79e34 commit 4842b96

File tree

4 files changed

+128
-3
lines changed

4 files changed

+128
-3
lines changed

docs/apis/packaging-models.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
Zip the exported estimator output in your checkpoint directory:
66

77
```text
8-
$ ls export/estimator
8+
$ ls export/estimator/1560263597/
99
saved_model.pb variables/
1010
1111
$ zip -r model.zip export/estimator

examples/iris/tensorflow/model.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# sources copied/modified from https://github.com/tensorflow/models/blob/master/samples/core/get_started/
2+
3+
import tensorflow as tf
4+
from sklearn.datasets import load_iris
5+
from sklearn.model_selection import train_test_split
6+
import shutil
7+
import os
8+
9+
EXPORT_DIR = "iris_tf_export"
10+
11+
12+
def input_fn(features, labels, batch_size, mode):
13+
"""An input function for training"""
14+
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
15+
if mode == tf.estimator.ModeKeys.TRAIN:
16+
dataset = dataset.shuffle(1000).repeat()
17+
dataset = dataset.batch(batch_size)
18+
dataset_it = dataset.make_one_shot_iterator()
19+
irises, labels = dataset_it.get_next()
20+
return {"irises": irises}, labels
21+
22+
23+
def json_serving_input_fn():
24+
inputs = tf.placeholder(shape=[4], dtype=tf.float64)
25+
features = {"irises": tf.expand_dims(inputs, 0)}
26+
return tf.estimator.export.ServingInputReceiver(features=features, receiver_tensors=inputs)
27+
28+
29+
def my_model(features, labels, mode, params):
30+
"""DNN with three hidden layers and learning_rate=0.1."""
31+
net = features["irises"]
32+
for units in params["hidden_units"]:
33+
net = tf.layers.dense(net, units=units, activation=tf.nn.relu)
34+
35+
logits = tf.layers.dense(net, params["n_classes"], activation=None)
36+
37+
predicted_classes = tf.argmax(logits, 1)
38+
if mode == tf.estimator.ModeKeys.PREDICT:
39+
predictions = {
40+
"class_ids": predicted_classes[:, tf.newaxis],
41+
"probabilities": tf.nn.softmax(logits),
42+
"logits": logits,
43+
}
44+
return tf.estimator.EstimatorSpec(
45+
mode=mode,
46+
predictions=predictions,
47+
export_outputs={
48+
"predict": tf.estimator.export.PredictOutput(
49+
{
50+
"class_ids": predicted_classes[:, tf.newaxis],
51+
"probabilities": tf.nn.softmax(logits),
52+
}
53+
)
54+
},
55+
)
56+
57+
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
58+
59+
accuracy = tf.metrics.accuracy(labels=labels, predictions=predicted_classes, name="acc_op")
60+
metrics = {"accuracy": accuracy}
61+
tf.summary.scalar("accuracy", accuracy[1])
62+
63+
if mode == tf.estimator.ModeKeys.EVAL:
64+
return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)
65+
66+
optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
67+
train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
68+
return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
69+
70+
71+
iris = load_iris()
72+
X, y = iris.data, iris.target
73+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.8, random_state=42)
74+
75+
classifier = tf.estimator.Estimator(
76+
model_fn=my_model, model_dir=EXPORT_DIR, params={"hidden_units": [10, 10], "n_classes": 3}
77+
)
78+
79+
80+
train_input_fn = lambda: input_fn(X_train, y_train, 100, tf.estimator.ModeKeys.TRAIN)
81+
eval_input_fn = lambda: input_fn(X_test, y_test, 100, tf.estimator.ModeKeys.EVAL)
82+
serving_input_fn = lambda: json_serving_input_fn()
83+
exporter = tf.estimator.FinalExporter("estimator", serving_input_fn, as_text=False)
84+
train_spec = tf.estimator.TrainSpec(train_input_fn, max_steps=1000)
85+
eval_spec = tf.estimator.EvalSpec(eval_input_fn, exporters=[exporter], name="estimator-eval")
86+
87+
tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)
88+
89+
# zip the estimator export dir (the exported path looks like iris_tf_export/export/estimator/1562353043/)
90+
estimator_dir = EXPORT_DIR + "/export/estimator"
91+
shutil.make_archive("tensorflow", "zip", os.path.join(estimator_dir))
92+
93+
# clean up
94+
shutil.rmtree(EXPORT_DIR)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
tensorflow
2+
sklearn

pkg/workloads/tf_api/api.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,30 @@ def predict(deployment_name, api_name):
391391
return jsonify(response)
392392

393393

394+
def validate_model_dir(model_dir):
395+
"""
396+
validates that model_dir has the expected directory tree.
397+
398+
For example (your TF serving version number may be different):
399+
400+
1562353043/
401+
saved_model.pb
402+
variables/
403+
variables.data-00000-of-00001
404+
variables.index
405+
"""
406+
version = os.listdir(model_dir)[0]
407+
if not version.isdigit():
408+
raise UserException(
409+
"No versions of servable default found under base path in model_dir. See docs.cortex.dev for how to properly package your TensorFlow model"
410+
)
411+
412+
if "saved_model.pb" not in os.listdir(os.path.join(model_dir, version)):
413+
raise UserException(
414+
'Expected packaged model to have a "saved_model.pb" file. See docs.cortex.dev for how to properly package your TensorFlow model'
415+
)
416+
417+
394418
def start(args):
395419
ctx = Context(s3_path=args.context, cache_dir=args.cache_dir, workload_id=args.workload_id)
396420

@@ -406,8 +430,7 @@ def start(args):
406430
package.install_packages(ctx.python_packages, ctx.storage)
407431
if not os.path.isdir(args.model_dir):
408432
ctx.storage.download_and_unzip_external(api["model"], args.model_dir)
409-
410-
if util.is_resource_ref(api["model"]):
433+
else:
411434
package.install_packages(ctx.python_packages, ctx.storage)
412435
model_name = util.get_resource_ref(api["model"])
413436
model = ctx.models[model_name]
@@ -446,6 +469,12 @@ def start(args):
446469
model["input"]["target_vocab"], None, False
447470
)
448471

472+
try:
473+
validate_model_dir(args.model_dir)
474+
except Exception as e:
475+
logger.exception(e)
476+
sys.exit(1)
477+
449478
channel = grpc.insecure_channel("localhost:" + str(args.tf_serve_port))
450479
local_cache["stub"] = prediction_service_pb2_grpc.PredictionServiceStub(channel)
451480

0 commit comments

Comments
 (0)