Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit b39d152

Browse files
author
Ryan Sepassi
committed
Update comment on shape in SymbolModality
PiperOrigin-RevId: 191759697
1 parent fc9335c commit b39d152

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tensor2tensor/layers/modalities.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,12 @@ def _get_weights(self, hidden_dim=None):
9595

9696
def bottom_simple(self, x, name, reuse):
9797
with tf.variable_scope(name, reuse=reuse):
98-
# Squeeze out the channels dimension.
98+
# Ensure the inputs are 3-D
9999
if len(x.get_shape()) == 4:
100100
x = tf.squeeze(x, axis=3)
101101
while len(x.get_shape()) < 3:
102102
x = tf.expand_dims(x, axis=-1)
103+
103104
var = self._get_weights()
104105
x = common_layers.dropout_no_scaling(
105106
x, 1.0 - self._model_hparams.symbol_dropout)

0 commit comments

Comments
 (0)