Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 5a06e7a

Browse files
Ashish Vaswani authored and lukaszkaiser committed
internal.
PiperOrigin-RevId: 161130093
1 parent fbb6f9a commit 5a06e7a

File tree

4 files changed

+28
-12
lines changed

4 files changed

+28
-12
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
setup(
77
name='tensor2tensor',
8-
version='1.0.11',
8+
version='1.0.10',
99
description='Tensor2Tensor',
1010
author='Google Inc.',
1111
author_email='no-reply@google.com',

tensor2tensor/data_generators/generator_utils.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,10 @@ def to_example(dictionary):
4646
elif isinstance(v[0], float):
4747
features[k] = tf.train.Feature(float_list=tf.train.FloatList(value=v))
4848
elif isinstance(v[0], six.string_types):
49-
v = [bytes(x, 'utf-8') for x in v]
50-
features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=v))
51-
elif isinstance(v[0], bytes):
5249
features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=v))
5350
else:
54-
raise ValueError("Value for %s is neither an int nor a float; v: %s type: %s" %
55-
(k, str(v[0]), str(type(v[0]))))
51+
raise ValueError("Value is neither an int nor a float; v: %s type: %s" %
52+
(str(v[0]), str(type(v[0]))))
5653
return tf.train.Example(features=tf.train.Features(feature=features))
5754

5855

tensor2tensor/data_generators/image.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def image_generator(images, labels):
6868
yield {
6969
"image/encoded": [enc_string],
7070
"image/format": ["png"],
71-
"image/class/label": [int(label)],
71+
"image/class/label": [label],
7272
"image/height": [height],
7373
"image/width": [width]
7474
}

tensor2tensor/models/modalities.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,10 @@ def top(self, body_output, targets):
124124
class SmallImageModality(modality.Modality):
125125
"""Performs strided conv compressions for small image data."""
126126

127+
def __init__(self, model_hparams, vocab_size):
128+
super(SmallImageModality, self).__init__(model_hparams, vocab_size)
129+
self._channels = 3
130+
127131
@property
128132
def top_dimensionality(self):
129133
return 256
@@ -161,15 +165,30 @@ def targets_bottom(self, inputs):
161165

162166
def top(self, body_output, _):
163167
with tf.variable_scope("rgb_softmax"):
164-
var = tf.get_variable(
168+
# separate embedding for each channel
169+
# assuming the body output returns a tensor of shape
170+
# [batch_size, rows, cols, channels, self._body_input_depth]
171+
body_output_split = tf.split(body_output, self._channels, axis=3)
172+
output_rgb_embedding_var = tf.get_variable(
165173
"output_rgb_embedding",
166-
[self.top_dimensionality, self._body_input_depth],
174+
[self._channels, self.top_dimensionality, self._body_input_depth],
167175
initializer=tf.random_normal_initializer(0.0, self._body_input_depth
168176
**-0.5))
169-
body_output = tf.reshape(body_output, [-1, self._body_input_depth])
170-
logits = tf.matmul(body_output, var, transpose_b=True)
177+
# compute logits separately for each channel
178+
rgb_channel_logits = []
179+
for i in self._channels:
180+
shape = tf.shape(body_output_split[i])[:-1]
181+
body_output = tf.reshape(body_output_split[i],
182+
[-1, self._body_input_depth])
183+
channel_logits = tf.matmul(body_output,
184+
output_rgb_embedding_var[i],
185+
transpose_b=True)
186+
rgb_channel_logits.append(tf.reshape(
187+
channel_logits, tf.concat([shape, [self.top_dimensionality]],
188+
0)))
189+
190+
logits = tf.concat(rgb_channel_logits, axis=3)
171191
# Reshape logits to conform to CIFAR image shapes (32 by 32 by 3)
172-
logits = tf.reshape(logits, [-1, 32, 32, 3, 256])
173192

174193
return logits
175194

0 commit comments

Comments
 (0)