Skip to content

Commit 826309a

Browse files
authored
Tensor2Tensor Example and transform_tensorflow feature (#29)
1 parent aa2e801 commit 826309a

File tree

9 files changed

+144
-45
lines changed

9 files changed

+144
-45
lines changed

docs/applications/implementations/models.md

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ def create_estimator(run_config, model_config):
4848

4949
## Pre-installed Packages
5050

51+
You can import PyPI packages or your own Python packages to help create more complex models. See [Python Packages](../advanced/python-packages.md) for more details.
52+
5153
The following packages have been pre-installed and can be used in your implementations:
5254

5355
```text
@@ -60,3 +62,41 @@ packaging==19.0.0
6062
```
6163

6264
You can install additional PyPI packages and import your own Python packages. See [Python Packages](../advanced/python-packages.md) for more details.
65+
66+
67+
# TensorFlow Transformations
68+
You can preprocess the features and labels passed to your model by defining a `transform_tensorflow` function. The tensor transformations it defines are applied to the feature and label tensors before they are passed to the model.
69+
70+
## Implementation
71+
72+
```python
73+
def transform_tensorflow(features, labels, model_config):
74+
"""Define tensor transformations for the feature and label tensors.
75+
76+
Args:
77+
features: A feature dictionary of column names to feature tensors.
78+
79+
labels: The label tensor.
80+
81+
model_config: The Cortex configuration for the model.
82+
Note: nested resources are expanded (e.g. model_config["target_column"]
83+
will be the configuration for the target column, rather than the
84+
name of the target column).
85+
86+
87+
Returns:
88+
features and labels tensors.
89+
"""
90+
return features, labels
91+
```
92+
93+
## Example
94+
95+
```python
96+
import tensorflow as tf
97+
98+
def transform_tensorflow(features, labels, model_config):
99+
hparams = model_config["hparams"]
100+
features["image_pixels"] = tf.reshape(features["image_pixels"], hparams["input_shape"])
101+
return features, labels
102+
```

examples/mnist/implementations/models/basic.py renamed to examples/mnist/implementations/models/custom.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,7 @@ def create_estimator(run_config, model_config):
55
hparams = model_config["hparams"]
66

77
def model_fn(features, labels, mode, params):
8-
images = features["image_pixels"]
9-
images = tf.reshape(images, [-1] + hparams["input_shape"])
10-
x = images
11-
8+
x = features["image_pixels"]
129
for i, feature_count in enumerate(hparams["hidden_units"]):
1310
with tf.variable_scope("layer_%d" % i):
1411
if hparams["layer_type"] == "conv":
@@ -55,3 +52,10 @@ def model_fn(features, labels, mode, params):
5552

5653
estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
5754
return estimator
55+
56+
57+
def transform_tensorflow(features, labels, model_config):
    """Reshape the `image_pixels` feature to the configured input shape.

    Args:
        features: Dictionary of feature tensors; must contain "image_pixels".
            It is updated in place.
        labels: The label tensor; returned unchanged.
        model_config: Model configuration;
            `model_config["hparams"]["input_shape"]` is the target shape.

    Returns:
        The (features, labels) pair, with features["image_pixels"] reshaped.
    """
    target_shape = model_config["hparams"]["input_shape"]
    pixels = features["image_pixels"]
    features["image_pixels"] = tf.reshape(pixels, target_shape)
    return features, labels
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import tensorflow as tf
2+
from tensor2tensor.utils import trainer_lib
3+
from tensor2tensor import models # pylint: disable=unused-import
4+
from tensor2tensor import problems # pylint: disable=unused-import
5+
from tensor2tensor.data_generators import problem_hparams
6+
from tensor2tensor.utils import registry
7+
8+
9+
def create_estimator(run_config, model_config):
    """Build a Tensor2Tensor estimator for the "image_mnist" problem.

    Args:
        run_config: Estimator run configuration; the attributes t2t reads
            are set on it in place.
        model_config: The Cortex configuration for the model (not used here).

    Returns:
        The estimator returned by `trainer_lib.create_estimator` for the
        "basic_fc_relu" model.
    """
    # t2t reads these attributes off of run_config
    run_config.data_parallelism = None
    run_config.t2t_device_info = {"num_async_replicas": 1}

    # start from one of t2t's own registered hyperparameter sets
    t2t_hparams = trainer_lib.create_hparams("basic_fc_small")
    mnist_problem = registry.problem("image_mnist")
    # the problem hparams are computed before `problem` is attached
    problem_specific_hparams = mnist_problem.get_hparams(t2t_hparams)
    t2t_hparams.problem = mnist_problem
    t2t_hparams.problem_hparams = problem_specific_hparams

    # eval metrics are not needed, so the problem reports none
    mnist_problem.eval_metrics = lambda: []

    # t2t reads this attribute as well
    t2t_hparams.warm_start_from = None

    return trainer_lib.create_estimator("basic_fc_relu", t2t_hparams, run_config)
29+
30+
31+
def transform_tensorflow(features, labels, model_config):
    """Add the feature keys the t2t model reads to the feature dictionary.

    Args:
        features: Dictionary of feature tensors; must contain "image_pixels".
            It is updated in place.
        labels: The label tensor.
        model_config: Model configuration;
            `model_config["hparams"]["input_shape"]` is the shape that
            "image_pixels" is reshaped to.

    Returns:
        The (features, labels) pair. `features` gains an "inputs" entry (the
        reshaped pixels) and a "targets" entry (the labels with a new leading
        axis); `labels` is returned unchanged.
    """
    input_shape = model_config["hparams"]["input_shape"]

    # the t2t model does its own flattening and looks its input up as "inputs"
    reshaped_pixels = tf.reshape(features["image_pixels"], input_shape)
    features["inputs"] = reshaped_pixels

    # t2t looks the labels up as "targets", with an extra leading dimension
    features["targets"] = tf.expand_dims(labels, 0)

    return features, labels

examples/mnist/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
pillow==5.4.1
2+
tensor2tensor==1.10.0

examples/mnist/resources/apis.yaml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
- kind: api
2-
name: dense-classifier
3-
model_name: dense
2+
name: dnn-classifier
3+
model_name: dnn
44
compute:
55
replicas: 1
66

@@ -9,3 +9,9 @@
99
model_name: conv
1010
compute:
1111
replicas: 1
12+
13+
- kind: api
14+
name: t2t-classifier
15+
model_name: t2t
16+
compute:
17+
replicas: 1

examples/mnist/resources/models.yaml

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,22 @@
11
- kind: model
2-
name: dense
3-
path: implementations/models/basic.py
2+
name: dnn
3+
path: implementations/models/dnn.py
44
type: classification
55
target_column: label
66
feature_columns:
77
- image_pixels
88
hparams:
9-
layer_type: basic
109
learning_rate: 0.01
1110
input_shape: [784]
1211
output_shape: [10]
13-
hidden_units: [100, 200, 10]
12+
hidden_units: [100, 200]
1413
data_partition_ratio:
1514
training: 0.7
1615
evaluation: 0.3
17-
training:
18-
batch_size: 64
19-
num_epochs: 5
2016

2117
- kind: model
2218
name: conv
23-
path: implementations/models/basic.py
19+
path: implementations/models/custom.py
2420
type: classification
2521
target_column: label
2622
feature_columns:
@@ -30,7 +26,7 @@
3026
learning_rate: 0.01
3127
input_shape: [28, 28, 1]
3228
output_shape: [10]
33-
kernel_size: 2
29+
kernel_size: 4
3430
hidden_units: [10, 10, 10]
3531
data_partition_ratio:
3632
training: 0.7
@@ -39,18 +35,17 @@
3935
batch_size: 64
4036
num_epochs: 5
4137

38+
4239
- kind: model
43-
name: dnn
44-
path: implementations/models/dnn.py
40+
name: t2t
41+
path: implementations/models/t2t.py
4542
type: classification
4643
target_column: label
4744
feature_columns:
4845
- image_pixels
46+
prediction_key: outputs
4947
hparams:
50-
learning_rate: 0.01
5148
input_shape: [28, 28, 1]
52-
output_shape: [10]
53-
hidden_units: [100, 200]
5449
data_partition_ratio:
5550
training: 0.7
5651
evaluation: 0.3

pkg/workloads/lib/context.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,8 @@ def resource_status_key(self, resource):
460460

461461

462462
MODEL_IMPL_VALIDATION = {
463-
"required": [{"name": "create_estimator", "args": ["run_config", "model_config"]}]
463+
"required": [{"name": "create_estimator", "args": ["run_config", "model_config"]}],
464+
"optional": [{"name": "transform_tensorflow", "args": ["features", "labels", "model_config"]}],
464465
}
465466

466467
AGGREGATOR_IMPL_VALIDATION = {

pkg/workloads/tf_api/api.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,11 @@ def transform_sample(sample):
8989

9090
def create_prediction_request(transformed_sample):
9191
ctx = local_cache["ctx"]
92-
92+
signatureDef = local_cache["metadata"]["signatureDef"]
93+
signature_key = list(signatureDef.keys())[0]
9394
prediction_request = predict_pb2.PredictRequest()
9495
prediction_request.model_spec.name = "default"
95-
prediction_request.model_spec.signature_name = list(
96-
local_cache["metadata"]["signatureDef"].keys()
97-
)[0]
96+
prediction_request.model_spec.signature_name = signature_key
9897

9998
for column_name, value in transformed_sample.items():
10099
data_type = tf_lib.CORTEX_TYPE_TO_TF_TYPE[ctx.columns[column_name]["type"]]

pkg/workloads/tf_train/train_util.py

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,24 @@ def get_input_placeholder(model_name, ctx, training=True):
3333
return input_placeholder
3434

3535

36+
def get_label_placeholder(model_name, ctx):
    """Create a placeholder for the target column of a model.

    Args:
        model_name: Key of the model in `ctx.models`.
        ctx: Context providing the `models` and `columns` lookups.

    Returns:
        A `tf.placeholder` of shape [None] whose dtype is the TensorFlow type
        that `tf_lib.CORTEX_TYPE_TO_TF_TYPE` maps the target column's type to.
    """
    target_name = ctx.models[model_name]["target_column"]
    cortex_type = ctx.columns[target_name]["type"]
    label_dtype = tf_lib.CORTEX_TYPE_TO_TF_TYPE[cortex_type]
    return tf.placeholder(shape=[None], dtype=label_dtype)
42+
43+
44+
def get_transform_tensor_fn(ctx, model_impl, model_name):
    """Bind a model's config to its `transform_tensorflow` implementation.

    Args:
        ctx: Context providing the `models` lookup and `model_config()`.
        model_impl: Model implementation that defines `transform_tensorflow`.
        model_name: Key of the model in `ctx.models`.

    Returns:
        A function of `(inputs, labels)` that forwards to
        `model_impl.transform_tensorflow(inputs, labels, model_config)`.
    """
    # the config is resolved once here, not on each call of the returned function
    bound_config = ctx.model_config(ctx.models[model_name]["name"])

    def transform_tensor_fn_wrapper(inputs, labels):
        return model_impl.transform_tensorflow(inputs, labels, bound_config)

    return transform_tensor_fn_wrapper
52+
53+
3654
def generate_example_parsing_fn(model_name, ctx, training=True):
3755
model = ctx.models[model_name]
3856

@@ -47,7 +65,7 @@ def _parse_example(example_proto):
4765

4866

4967
# Mode must be "training" or "evaluation"
50-
def generate_input_fn(model_name, ctx, mode):
68+
def generate_input_fn(model_name, ctx, mode, model_impl):
5169
model = ctx.models[model_name]
5270

5371
filenames = ctx.get_training_data_parts(model_name, mode)
@@ -66,6 +84,9 @@ def _input_fn():
6684
if model[mode]["shuffle"]:
6785
dataset = dataset.shuffle(buffer_size)
6886

87+
if hasattr(model_impl, "transform_tensorflow"):
88+
dataset = dataset.map(get_transform_tensor_fn(ctx, model_impl, model_name))
89+
6990
dataset = dataset.batch(model[mode]["batch_size"])
7091
dataset = dataset.prefetch(buffer_size)
7192
dataset = dataset.repeat()
@@ -77,27 +98,19 @@ def _input_fn():
7798
return _input_fn
7899

79100

80-
def generate_json_serving_input_fn(model_name, ctx):
101+
def generate_json_serving_input_fn(model_name, ctx, model_impl):
81102
def _json_serving_input_fn():
82103
inputs = get_input_placeholder(model_name, ctx, training=False)
83-
features = {key: tf.expand_dims(tensor, -1) for key, tensor in inputs.items()}
84-
return tf.estimator.export.ServingInputReceiver(features=features, receiver_tensors=inputs)
85-
86-
return _json_serving_input_fn
104+
labels = get_label_placeholder(model_name, ctx)
87105

106+
features = {key: tensor for key, tensor in inputs.items()}
107+
if hasattr(model_impl, "transform_tensorflow"):
108+
features, _ = get_transform_tensor_fn(ctx, model_impl, model_name)(features, labels)
88109

89-
def generate_example_serving_input_fn(model_name, ctx):
90-
def _example_serving_input_fn():
91-
feature_spec = tf_lib.get_feature_spec(model_name, ctx, training=False)
92-
example_bytestring = tf.placeholder(shape=[None], dtype=tf.string)
93-
feature_scalars = tf.parse_single_example(example_bytestring, feature_spec)
94-
features = {key: tf.expand_dims(tensor, -1) for key, tensor in feature_scalars.items()}
95-
96-
return tf.estimator.export.ServingInputReceiver(
97-
features=features, receiver_tensors={"example_proto": example_bytestring}
98-
)
110+
features = {key: tf.expand_dims(tensor, 0) for key, tensor in features.items()}
111+
return tf.estimator.export.ServingInputReceiver(features=features, receiver_tensors=inputs)
99112

100-
return _example_serving_input_fn
113+
return _json_serving_input_fn
101114

102115

103116
def get_regression_eval_metrics(labels, predictions):
@@ -130,9 +143,9 @@ def train(model_name, model_impl, ctx, model_dir):
130143
model_dir=model_dir,
131144
)
132145

133-
train_input_fn = generate_input_fn(model_name, ctx, "training")
134-
eval_input_fn = generate_input_fn(model_name, ctx, "evaluation")
135-
serving_input_fn = generate_json_serving_input_fn(model_name, ctx)
146+
train_input_fn = generate_input_fn(model_name, ctx, "training", model_impl)
147+
eval_input_fn = generate_input_fn(model_name, ctx, "evaluation", model_impl)
148+
serving_input_fn = generate_json_serving_input_fn(model_name, ctx, model_impl)
136149
exporter = tf.estimator.FinalExporter("estimator", serving_input_fn, as_text=False)
137150

138151
dataset_metadata = aws.read_json_from_s3(model["dataset"]["metadata_key"], ctx.bucket)

0 commit comments

Comments
 (0)