
Commit ee947c9

T2T Team authored and Ryan Sepassi committed
Fix transformer's encode docstring.
PiperOrigin-RevId: 179928442
1 parent f2b620f commit ee947c9

File tree

1 file changed (+2, -1 lines)


tensor2tensor/models/transformer.py

Lines changed: 2 additions & 1 deletion
@@ -53,7 +53,8 @@ def encode(self, inputs, target_space, hparams, features=None):
     """Encode transformer inputs.
 
     Args:
-      inputs: Transformer inputs [batch_size, input_length, hidden_dim]
+      inputs: Transformer inputs [batch_size, input_length, input_height,
+        hidden_dim] which will be flattened along the two spatial dimensions.
       target_space: scalar, target space ID.
       hparams: hyperparmeters for model.
       features: optionally pass the entire features dictionary as well.
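
For readers unfamiliar with the shape convention in the updated docstring, here is a minimal NumPy sketch of the flattening it describes: the two spatial dimensions (input_length and input_height) are collapsed into one before the encoder operates on the tensor. Variable names are illustrative and this is not tensor2tensor's actual implementation.

```python
import numpy as np

# Illustrative shape arithmetic only; not the library's code.
batch_size, input_length, input_height, hidden_dim = 2, 5, 3, 8

# 4-D inputs as documented: [batch_size, input_length, input_height, hidden_dim]
inputs = np.zeros((batch_size, input_length, input_height, hidden_dim))

# Flatten the two spatial dimensions into a single length axis,
# yielding the 3-D shape [batch_size, input_length * input_height, hidden_dim].
flattened = inputs.reshape(batch_size, input_length * input_height, hidden_dim)

print(flattened.shape)  # (2, 15, 8)
```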
