
Commit 93b325f

Author: Ryan Sepassi

Baseline model for GeneExpression problem

PiperOrigin-RevId: 163286026

1 parent: 5242ac6

10 files changed: +279 additions, -42 deletions

tensor2tensor/data_generators/all_problems.py

Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@
 # pylint: disable=g-import-not-at-top
 try:
   # Requires h5py
-  from tensor2tensor.data_generators import genetics
+  from tensor2tensor.data_generators import gene_expression
 except ImportError:
   pass
 # pylint: enable=g-import-not-at-top
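
The module is imported inside try/except because the gene expression data generators require h5py, an optional dependency; when h5py is absent, the rest of the problem registry still loads. A minimal sketch of the same guard pattern (the HAVE_H5PY flag is hypothetical, not part of the commit):

try:
  import h5py  # optional dependency for the gene expression problems
  HAVE_H5PY = True
except ImportError:
  HAVE_H5PY = False

if not HAVE_H5PY:
  print("h5py not installed; gene expression problems unavailable")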

tensor2tensor/data_generators/genetics.py renamed to tensor2tensor/data_generators/gene_expression.py

Lines changed: 23 additions & 16 deletions

@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""Genetics problems.
+"""Gene expression problems.

 Inputs are bases ACTG (with indices assigned in that order).

@@ -82,7 +82,7 @@ def chunk_size(self):
   def feature_encoders(self, data_dir):
     del data_dir
     return {
-        "inputs": GeneticBaseEncoder(chunk_size=self.chunk_size),
+        "inputs": DNAEncoder(chunk_size=self.chunk_size),
         # TODO(rsepassi): RealEncoder?
         "targets": text_encoder.TextEncoder()
     }
@@ -166,17 +166,24 @@ def example_reading_spec(self):
   def preprocess_examples(self, examples, mode):
     del mode

+    # Reshape targets
     examples["targets"] = tf.reshape(examples["targets"],
                                      [-1, 1, self.num_output_predictions])
+    examples["targets_mask"] = tf.reshape(examples["targets_mask"], [-1, 1, 1])
+
+    # Set masked targets to 0 (i.e. pad) so that loss and metrics ignore them.
+    # Add epsilon because some unmasked labels are actually 0.
+    examples["targets"] += 1e-6
+    examples["targets"] *= examples["targets_mask"]

     return examples

   def eval_metrics(self):
     return [metrics.Metrics.RMSE]


-@registry.register_problem("genetics_cage10")
-class GeneticsCAGE10(GeneExpressionProblem):
+@registry.register_problem("gene_expression_cage10")
+class GeneExpressionCAGE10(GeneExpressionProblem):

   @property
   def download_url(self):
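
The new masking logic in preprocess_examples works because id 0 is the pad value that padded losses and metrics skip: multiplying by the mask zeroes out masked positions, while the 1e-6 epsilon keeps genuinely-zero labels from being mistaken for padding. A small NumPy sketch of the same arithmetic:

import numpy as np

targets = np.array([[0.0], [2.5], [1.3]])  # first label is genuinely 0
mask = np.array([[1.0], [1.0], [0.0]])     # last position is masked

targets += 1e-6  # lift real zeros just above the pad value
targets *= mask  # masked entries become exactly 0 (i.e. pad)

print(targets.ravel())  # [1.0e-06 2.500001e+00 0.0e+00]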
@@ -187,8 +194,8 @@ def h5_file(self):
     return "cage10.h5"


-@registry.register_problem("genetics_gm12878")
-class GeneticsGM12878(GeneExpressionProblem):
+@registry.register_problem("gene_expression_gm12878")
+class GeneExpressionGM12878(GeneExpressionProblem):

   @property
   def download_url(self):
@@ -199,8 +206,8 @@ def h5_file(self):
     return "gm12878.h5"


-@registry.register_problem("genetics_l262k")
-class GeneticsL262k(GeneExpressionProblem):
+@registry.register_problem("gene_expression_l262k")
+class GeneExpressionL262k(GeneExpressionProblem):

   @property
   def h5_file(self):
@@ -236,7 +243,7 @@ def dataset_generator(filepath,
                       chunk_size=1,
                       start_idx=None,
                       end_idx=None):
-  encoder = GeneticBaseEncoder(chunk_size=chunk_size)
+  encoder = DNAEncoder(chunk_size=chunk_size)
   with h5py.File(filepath, "r") as h5_file:
     # Get input keys from h5_file
     src_keys = [s % dataset for s in ["%s_in", "%s_na", "%s_out"]]
@@ -291,7 +298,7 @@ def to_example_dict(encoder, inputs, mask, outputs):
   return ex_dict


-class GeneticBaseEncoder(text_encoder.TextEncoder):
+class DNAEncoder(text_encoder.TextEncoder):
   """ACTG strings to ints and back. Optionally chunks bases into single ids.

   Uses 'X' as an unknown base.
@@ -302,14 +309,14 @@ class GeneticBaseEncoder(text_encoder.TextEncoder):
   def __init__(self,
                chunk_size=1,
                num_reserved_ids=text_encoder.NUM_RESERVED_TOKENS):
-    super(GeneticBaseEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
+    super(DNAEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
     # Build a vocabulary of chunks of size chunk_size
     self._chunk_size = chunk_size
     chunks = []
     for size in range(1, chunk_size + 1):
-      c = itertools.product(_bases + [GeneticBaseEncoder.UNK], repeat=size)
+      c = itertools.product(_bases + [DNAEncoder.UNK], repeat=size)
       num_pad = chunk_size - size
-      padding = (GeneticBaseEncoder.PAD,) * num_pad
+      padding = (DNAEncoder.PAD,) * num_pad
       c = [el + padding for el in c]
       chunks.extend(c)
     chunks.sort()
@@ -323,7 +330,7 @@ def vocab_size(self):

   def encode(self, s):
     bases = list(s)
-    pad = [GeneticBaseEncoder.PAD] * (len(bases) % self._chunk_size)
+    pad = [DNAEncoder.PAD] * (len(bases) % self._chunk_size)
     bases.extend(pad)
     assert (len(bases) % self._chunk_size) == 0
     num_chunks = len(bases) // self._chunk_size
@@ -342,8 +349,8 @@ def decode(self, ids):
     for idx in ids:
       if idx >= self._num_reserved_ids:
         chunk = self._ids_to_chunk[idx]
-        if GeneticBaseEncoder.PAD in chunk:
-          chunk = chunk[:chunk.index(GeneticBaseEncoder.PAD)]
+        if DNAEncoder.PAD in chunk:
+          chunk = chunk[:chunk.index(DNAEncoder.PAD)]
       else:
         chunk = [text_encoder.RESERVED_TOKENS[idx]]
       bases.extend(chunk)
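
For context on the renamed encoder, a hedged usage sketch (assuming the module path above; the exact ids produced depend on the reserved tokens and the sorted chunk vocabulary):

from tensor2tensor.data_generators import gene_expression

encoder = gene_expression.DNAEncoder(chunk_size=2)

# With chunk_size=2, the vocabulary holds every 1- and 2-base combination of
# A, C, G, T, X (shorter chunks right-padded), plus the reserved ids.
ids = encoder.encode("ACTGC")  # odd length: the last chunk gets one pad base
print(ids)                     # 3 ids, one per 2-base chunk
print(encoder.decode(ids))     # padding is stripped, bases are recovered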

tensor2tensor/data_generators/genetics_test.py renamed to tensor2tensor/data_generators/gene_expression_test.py

Lines changed: 4 additions & 4 deletions

@@ -22,7 +22,7 @@

 import numpy as np

-from tensor2tensor.data_generators import genetics
+from tensor2tensor.data_generators import gene_expression

 import tensorflow as tf

@@ -40,15 +40,15 @@ def _oneHotBases(self, bases):
     return np.array(one_hots)

   def testRecordToExample(self):
-    encoder = genetics.GeneticBaseEncoder(chunk_size=2)
+    encoder = gene_expression.DNAEncoder(chunk_size=2)
     raw_inputs = ["A", "C", "G", "X", "C", "T"]

     # Put in numpy arrays in the same format as in the h5 file
     inputs = self._oneHotBases(raw_inputs)
     mask = np.array([True, False, True])
     outputs = np.array([[1.0, 2.0, 3.0], [5.0, 1.0, 0.2], [5.1, 2.3, 2.3]])
     # Convert to example dict
-    ex_dict = genetics.to_example_dict(encoder, inputs, mask, outputs)
+    ex_dict = gene_expression.to_example_dict(encoder, inputs, mask, outputs)

     self.assertEqual(len(raw_inputs) // 2 + 1, len(ex_dict["inputs"]))
     self.assertAllEqual(encoder.encode(raw_inputs) + [1], ex_dict["inputs"])
@@ -61,7 +61,7 @@ def testGenerateShardArgs(self):
     num_examples = 37
     num_shards = 4
     outfiles = [str(i) for i in range(num_shards)]
-    shard_args = genetics.generate_shard_args(outfiles, num_examples)
+    shard_args = gene_expression.generate_shard_args(outfiles, num_examples)

     starts, ends, fnames = zip(*shard_args)
     self.assertAllEqual([0, 9, 18, 27], starts)
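
The expected shard starts [0, 9, 18, 27] follow from an even integer split of 37 examples across 4 shards, with the remainder landing in the last shard. A hedged re-derivation of the test's expectations (shard_bounds is a stand-in, not the real generate_shard_args):

def shard_bounds(num_examples, num_shards):
  """Stand-in for generate_shard_args' index math (assumption)."""
  per_shard = num_examples // num_shards  # 37 // 4 == 9
  starts = [i * per_shard for i in range(num_shards)]
  ends = starts[1:] + [num_examples]      # last shard absorbs the remainder
  return starts, ends

print(shard_bounds(37, 4))  # ([0, 9, 18, 27], [9, 18, 27, 37])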

tensor2tensor/models/common_layers.py

Lines changed: 17 additions & 14 deletions

@@ -469,7 +469,10 @@ def get_norm(norm_type):
                      "'noam', 'none'.")


-def residual_fn(x, y, norm_type, residual_dropout,
+def residual_fn(x,
+                y,
+                norm_type,
+                residual_dropout,
                 filters=None,
                 epsilon=1e-16,
                 name="residual"):
@@ -559,11 +562,17 @@ def conv_block_internal(conv_fn,


 def conv_block(inputs, filters, dilation_rates_and_kernel_sizes, **kwargs):
-  """A block of standard convolutions."""
+  """A block of standard 2d convolutions."""
   return conv_block_internal(conv, inputs, filters,
                              dilation_rates_and_kernel_sizes, **kwargs)


+def conv1d_block(inputs, filters, dilation_rates_and_kernel_sizes, **kwargs):
+  """A block of standard 1d convolutions."""
+  return conv_block_internal(conv1d, inputs, filters,
+                             dilation_rates_and_kernel_sizes, **kwargs)
+
+
 def separable_conv_block(inputs, filters, dilation_rates_and_kernel_sizes,
                          **kwargs):
   """A block of separable convolutions."""
@@ -858,10 +867,7 @@ def multiscale_conv_sum(inputs, output_size, dilation_rates_and_kernel_sizes,
   return tf.add_n(results) * (len(results)**-0.5)


-def multiscale_conv_and_attention(x,
-                                  padding,
-                                  hparams,
-                                  source=None):
+def multiscale_conv_and_attention(x, padding, hparams, source=None):
   """A common part of t2t layers.

   First, do a linear multiscale convolution
@@ -925,10 +931,7 @@ def conv_with_pools(inputs, output_size, kernel_size, pool_sizes, pooling_type,
   return tf.add_n(results) * (len(results)**-0.5)


-def conv_with_pools_and_attention(x,
-                                  padding,
-                                  hparams,
-                                  source=None):
+def conv_with_pools_and_attention(x, padding, hparams, source=None):
   """A common part of t2t layers.

   First, do conv_with_pools
@@ -1389,8 +1392,8 @@ def padded_cross_entropy(logits,
   vocab_size = tf.shape(logits)[-1]
   with tf.name_scope("padded_cross_entropy", [logits, labels]):
     pad_logits, pad_labels = pad_with_zeros(logits, labels)
-    xent = smoothing_cross_entropy(pad_logits, pad_labels,
-                                   vocab_size, confidence)
+    xent = smoothing_cross_entropy(pad_logits, pad_labels, vocab_size,
+                                   confidence)
     weights = weights_fn(pad_labels)
     if not reduce_sum:
       return xent * weights, weights
@@ -1493,8 +1496,8 @@ def linear_set_layer(layer_size,
     # Unfortunately tf doesn't support broadcasting via concat, but we can
     # simply add the transformed context to get the same effect.
     context = tf.expand_dims(context, axis=1)
-    cont_tfm = conv1d(context, layer_size, 1,
-                      activation=None, name="cont_conv")
+    cont_tfm = conv1d(
+        context, layer_size, 1, activation=None, name="cont_conv")
     outputs += cont_tfm

   if activation_fn is not None:
New file

Lines changed: 132 additions & 0 deletions

@@ -0,0 +1,132 @@
+# coding=utf-8
+# Copyright 2017 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Models for gene expression from DNA."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Dependency imports
+
+from six.moves import xrange  # pylint: disable=redefined-builtin
+
+from tensor2tensor.models import common_hparams
+from tensor2tensor.models import common_layers
+from tensor2tensor.utils import registry
+from tensor2tensor.utils import t2t_model
+
+import tensorflow as tf
+
+
+@registry.register_model
+class GeneExpressionConv(t2t_model.T2TModel):
+  """Gene expression conv net.
+
+  Based on "Basenji" model from
+  http://www.biorxiv.org/content/early/2017/07/10/161851
+
+  Uses layer_norm instead of batch_norm.
+  """
+
+  def model_fn_body(self, features):
+    inputs = features["inputs"]
+    inputs.get_shape().assert_has_rank(4)
+
+    hp = self._hparams
+
+    out = inputs
+    out = common_layers.flatten4d3d(out)
+
+    # Conv layers
+    for i in xrange(hp.num_conv_layers):
+      out = conv_layer(
+          out,
+          hp.hidden_size,
+          hp.kernel_width,
+          hp.stride,
+          hp.pooling_windows[i],
+          hp.dropout,
+          1,
+          name="conv_%d" % (i + 1))
+
+    # Dense dilated conv layers
+    for i in xrange(hp.num_dconv_layers):
+      dilation_rate = 2**(i + 1)
+      dconv_out = conv_layer(
+          out,
+          hp.hidden_size,
+          hp.kernel_width,
+          1,
+          0,
+          hp.dropout,
+          dilation_rate,
+          name="dconv_%d" % (i + 1))
+      out = tf.concat([out, dconv_out], axis=2)
+
+    # Fully connected layer
+    out = fc_layer(out, hp.hidden_size, hp.dropout, name="fc")
+
+    out.get_shape().assert_has_rank(3)
+    out = tf.expand_dims(out, 2)
+    return out
+
+
+def conv_layer(x,
+               hidden_size,
+               kernel_size,
+               stride,
+               pooling_window,
+               dropout_rate,
+               dilation_rate,
+               name="conv"):
+  with tf.variable_scope(name):
+    out = x
+    out = common_layers.conv1d_block(
+        out,
+        hidden_size, [(dilation_rate, kernel_size)],
+        strides=stride,
+        first_relu=False,
+        padding="same")
+    out = tf.nn.relu(out)
+    if pooling_window:
+      out = tf.layers.max_pooling1d(
+          out, pooling_window, pooling_window, padding="same")
+    out = tf.layers.dropout(out, dropout_rate)
+    return out
+
+
+def fc_layer(x, num_out, dropout_rate, name="fc"):
+  with tf.variable_scope(name):
+    out = x
+    out = tf.layers.dense(out, num_out)
+    out = tf.contrib.layers.layer_norm(out)
+    out = tf.nn.relu(out)
+    out = tf.layers.dropout(out, dropout_rate)
+    return out
+
+
+@registry.register_hparams
+def gene_expression_conv_base():
+  """Hparams for GeneExpressionConv model."""
+  hparams = common_hparams.basic_params1()
+  hparams.add_hparam("num_conv_layers", 4)
+  hparams.add_hparam("num_dconv_layers", 7)
+  hparams.add_hparam("pooling_windows", [2, 4, 4, 4])
+
+  # TODO(rsepassi): Correct the values of these hyperparameters
+  hparams.hidden_size = 128
+  hparams.kernel_width = 128
+  hparams.add_hparam("stride", 1)
+  return hparams
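
Under these defaults, the four pooled conv layers shrink the sequence by a factor of 2 * 4 * 4 * 4 = 128, the dilated layers sweep rates 2 through 128, and the dense concatenation grows the channel count to 1024 before the final fully connected layer. A small sketch of that bookkeeping, with values taken from the hparams above:

pooling_windows = [2, 4, 4, 4]
num_dconv_layers = 7
hidden_size = 128

downsample = 1
for w in pooling_windows:
  downsample *= w
print(downsample)  # 128: each output position summarizes ~128 input bases

print([2**(i + 1) for i in range(num_dconv_layers)])  # [2, 4, 8, ..., 128]

# Each dconv layer concatenates hidden_size channels onto its input.
print(hidden_size + num_dconv_layers * hidden_size)  # 1024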
