Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 1e18474

Browse files
Ashish VaswaniRyan Sepassi
authored andcommitted
internal.
PiperOrigin-RevId: 161878428
1 parent 56d65f0 commit 1e18474

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

tensor2tensor/models/modalities.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,3 +465,31 @@ def bottom(self, x):
465465

466466
def top(self, body_output, _):
467467
return body_output
468+
469+
470+
@registry.register_image_modality("identity_no_pad")
471+
class IdentityModalityNoPad(modality.Modality):
472+
"""Does nothing except making sure that there is no padding in cross-ent."""
473+
474+
@property
475+
def targets_dimensionality(self):
476+
return self._vocab_size
477+
478+
def bottom(self, x):
479+
return tf.to_float(x)
480+
481+
def top(self, body_output, _):
482+
return body_output
483+
484+
def top_sharded(self,
485+
sharded_body_output,
486+
sharded_targets,
487+
data_parallelism,
488+
weights_fn=common_layers.weights_all):
489+
# Call the default implementation, but weight 1.0 on 0s by default.
490+
# (Since we're processing images and so have no padding and some pixel 0s.)
491+
return super(IdentityModalityNoPad, self).top_sharded(
492+
sharded_body_output,
493+
sharded_targets,
494+
data_parallelism,
495+
weights_fn=weights_fn)

0 commit comments

Comments
 (0)