This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit c762954

w-hat authored and copybara-github committed
Documentation update.
PiperOrigin-RevId: 311614763
1 parent a9da963 commit c762954

File tree

1 file changed: +51 −49 lines changed

docs/new_model.md

Lines changed: 51 additions & 49 deletions
@@ -15,43 +15,42 @@ Here we show how to create your own model in T2T.
`T2TModel` has three typical usages:

1.  Estimator: The method `make_estimator_model_fn` builds a `model_fn` for the
    tf.Estimator workflow of training, evaluation, and prediction. It performs
    the method `call`, which performs the core computation, followed by
    `estimator_spec_train`, `estimator_spec_eval`, or `estimator_spec_predict`
    depending on the tf.Estimator mode.
2.  Layer: The method `call` enables `T2TModel` to be used as a callable by
    itself. It calls the following methods:

    *   `bottom`, which transforms features according to `problem_hparams`'
        input and target `Modality`s;
    *   `body`, which takes features and performs the core model computation to
        return output and any auxiliary loss terms;
    *   `top`, which takes features and the body output, and transforms them
        according to `problem_hparams`' input and target `Modality`s to return
        the final logits;
    *   `loss`, which takes the logits, forms any missing training loss, and
        sums all loss terms.

3.  Inference: The method `infer` enables `T2TModel` to make sequence
    predictions by itself.

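The Layer usage above can be sketched in plain Python as a toy illustration of the call order only; `ToyModel` and its arithmetic are invented here, and the real methods operate on Tensors and `Modality` objects rather than lists:

```python
# Toy sketch of the bottom -> body -> top -> loss chain described above.
# It mimics the call order of T2TModel.call, nothing more.

class ToyModel:
    def bottom(self, features):
        # Map raw feature values into a "hidden" representation.
        return {k: [v * 2 for v in vals] for k, vals in features.items()}

    def body(self, features):
        # Core computation: here, just an element-wise sum.
        return [a + b for a, b in zip(features["inputs"], features["targets"])]

    def top(self, body_output, features):
        # Transform body output into final "logits".
        return [x + 1 for x in body_output]

    def loss(self, logits, features):
        # Sum a simple per-element loss term.
        return sum(abs(l, ) if False else abs(l - t) for l, t in zip(logits, features["targets"]))

    def call(self, features):
        transformed = self.bottom(features)
        body_output = self.body(transformed)
        logits = self.top(body_output, transformed)
        return logits, self.loss(logits, transformed)

logits, loss = ToyModel().call({"inputs": [1, 2], "targets": [3, 4]})
```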
## Creating your own model

1.  Create a class that extends `T2TModel`. This example creates a copy of an
    existing basic fully-connected network:

    ```python
    from tensor2tensor.utils import t2t_model

    class MyFC(t2t_model.T2TModel):
      pass
    ```

2.  Implement the `body` method:

    ```python
    class MyFC(t2t_model.T2TModel):
      def body(self, features):
        hparams = self.hparams
@@ -63,43 +62,46 @@ Here we show how to create your own model in T2T.
        x = tf.nn.dropout(x, keep_prob=1.0 - hparams.dropout)
        x = tf.nn.relu(x)
        return tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)  # 4D For T2T.
    ```

    Method Signature:

    *   Args:

        *   features: dict of str to Tensor, where each Tensor has shape
            [batch_size, ..., hidden_size]. It typically contains keys
            `inputs` and `targets`.

    *   Returns one of:

        *   output: Tensor of pre-logit activations with shape
            [batch_size, ..., hidden_size].
        *   losses: Either a single loss as a scalar, a list, a Tensor (to be
            averaged), or a dictionary of losses. If losses is a dictionary
            with the key "training", losses["training"] is considered the
            final training loss and output is considered the logits; self.top
            and self.loss will be skipped.
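The return contract can be illustrated with a small hedged sketch, where plain Python lists stand in for Tensors and `toy_body` is an invented name:

```python
def toy_body(features):
    # features: dict of str -> nested lists standing in for Tensors of shape
    # [batch_size, ..., hidden_size].
    x = features["inputs"]
    # "Pre-logit activations": a ReLU applied element-wise, same shape as x.
    output = [[max(v, 0.0) for v in row] for row in x]
    # Auxiliary losses; there is no "training" key, so T2T would still apply
    # self.top and self.loss after this body.
    losses = {"extra": 0.1}
    return output, losses  # a body may also return output alone

out, losses = toy_body({"inputs": [[-1.0, 2.0]]})
```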

3.  Register your model:

    ```python
    from tensor2tensor.utils import registry

    @registry.register_model
    class MyFC(t2t_model.T2TModel):
      # ...
    ```

4.  Use it with t2t tools as any other model:

    Keep in mind that names are translated from CamelCase to snake_case
    (`MyFC` -> `my_fc`) and that you need to point t2t to the directory
    containing your model with the `--t2t_usr_dir` flag. For example, if you
    want to train a model on gcloud with 1 GPU worker on the IMDB sentiment
    task, you can run your model by executing the following command from your
    model class directory.

    ```bash
    t2t-trainer \
      --model=my_fc \
      --t2t_usr_dir=.
@@ -111,4 +113,4 @@ Method Signature:
      --hparams_set=basic_fc_small \
      --train_steps=10000 \
      --eval_steps=10
    ```
