Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 9dc2826

Browse files
Iceparser adaptations
1 parent 30887b8 commit 9dc2826

File tree

4 files changed

+41
-24
lines changed

4 files changed

+41
-24
lines changed

tensor2tensor/data_generators/generator_utils.py

Lines changed: 2 additions & 1 deletion
Original file line number · Diff line number · Diff line change
```diff
@@ -300,7 +300,8 @@ def get_or_generate_tabbed_vocab(tmp_dir, source_filename, index, vocab_filename
     _ = tokenizer.encode(text_encoder.native_to_unicode(part))

   vocab = text_encoder.SubwordTextEncoder.build_to_target_size(
-      vocab_size, tokenizer.token_counts, 1, 1e3)
+      vocab_size, tokenizer.token_counts, 1,
+      min(1e3, vocab_size + text_encoder.NUM_RESERVED_TOKENS))
   vocab.store_to_file(vocab_filepath)
   return vocab
```

tensor2tensor/data_generators/problem_hparams.py

Lines changed: 33 additions & 15 deletions
Original file line number · Diff line number · Diff line change
```diff
@@ -178,6 +178,7 @@ def default_problem_hparams():
       # 15: Parse tokens
       # 16: Icelandic characters
       # 17: Icelandic tokens
+      # 18: Icelandic parse tokens
       # Add more above if needed.
       input_space_id=0,
       target_space_id=0,
```
```diff
@@ -550,20 +551,6 @@ def wmt_parsing_tokens(model_hparams, wrong_vocab_size):
   return p


-def wmt_tabbed_parsing_characters(model_hparams):
-  p = default_problem_hparams()
-  p.input_modality = {"inputs": (registry.Modalities.SYMBOL, 256)}
-  p.target_modality = (registry.Modalities.SYMBOL, 256)
-  p.vocabulary = {
-      "inputs": text_encoder.ByteTextEncoder(),
-      "targets": text_encoder.ByteTextEncoder(),
-  }
-  p.loss_multiplier = 2.0
-  p.input_space_id = 2
-  p.target_space_id = 14
-  return p
-
-
 def wsj_parsing_tokens(model_hparams, prefix,
                        wrong_source_vocab_size,
                        wrong_target_vocab_size):
```
```diff
@@ -604,6 +591,37 @@ def wsj_parsing_tokens(model_hparams, prefix,
   return p


+def ice_parsing_tokens(model_hparams, wrong_source_vocab_size):
+  """Icelandic to parse tree translation benchmark.
+
+  Args:
+    model_hparams: a tf.contrib.training.HParams
+  Returns:
+    a tf.contrib.training.HParams
+  """
+  p = default_problem_hparams()
+  # This vocab file must be present within the data directory.
+  source_vocab_filename = os.path.join(
+      model_hparams.data_dir,
+      "ice_source.tokens.vocab.%d" % wrong_source_vocab_size)
+  target_vocab_filename = os.path.join(
+      model_hparams.data_dir,
+      "ice_target.tokens.vocab.256")
+  source_subtokenizer = text_encoder.SubwordTextEncoder(source_vocab_filename)
+  target_subtokenizer = text_encoder.SubwordTextEncoder(target_vocab_filename)
+  p.input_modality = {
+      "inputs": (registry.Modalities.SYMBOL, source_subtokenizer.vocab_size)
+  }
+  p.target_modality = (registry.Modalities.SYMBOL, 256)
+  p.vocabulary = {
+      "inputs": source_subtokenizer,
+      "targets": target_subtokenizer,
+  }
+  p.input_space_id = 17  # Icelandic tokens
+  p.target_space_id = 18  # Icelandic parse tokens
+  return p
+
+
 def image_cifar10(unused_model_hparams):
   """CIFAR-10."""
   p = default_problem_hparams()
```
```diff
@@ -723,7 +741,7 @@ def img2img_imagenet(unused_model_hparams):
     "lmptb_10k": lmptb_10k,
     "wmt_parsing_characters": wmt_parsing_characters,
     "ice_parsing_characters": wmt_parsing_characters,
-    "ice_parsing_tokens": lambda p: wsj_parsing_tokens(p, "ice", 2**13, 2**8),
+    "ice_parsing_tokens": lambda p: ice_parsing_tokens(p, 2**13),
     "wmt_parsing_tokens_8k": lambda p: wmt_parsing_tokens(p, 2**13),
     "wsj_parsing_tokens_16k": lambda p: wsj_parsing_tokens(p, "wsj", 2**14, 2**9),
     "wsj_parsing_tokens_32k": lambda p: wsj_parsing_tokens(p, "wsj", 2**15, 2**9),
```

tensor2tensor/data_generators/text_encoder.py

Lines changed: 4 additions & 3 deletions
Original file line number · Diff line number · Diff line change
```diff
@@ -44,6 +44,7 @@
 PAD = "<pad>"
 EOS = "<EOS>"
 RESERVED_TOKENS = [PAD, EOS]
+NUM_RESERVED_TOKENS = len(RESERVED_TOKENS)
 PAD_TOKEN = RESERVED_TOKENS.index(PAD)  # Normally 0
 EOS_TOKEN = RESERVED_TOKENS.index(EOS)  # Normally 1
```

```diff
@@ -55,7 +56,7 @@
 class TextEncoder(object):
   """Base class for converting from ints to/from human readable strings."""

-  def __init__(self, num_reserved_ids=2):
+  def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS):
     self._num_reserved_ids = num_reserved_ids

   def encode(self, s):
```
```diff
@@ -130,7 +131,7 @@ def vocab_size(self):
 class TokenTextEncoder(TextEncoder):
   """Encoder based on a user-supplied vocabulary."""

-  def __init__(self, vocab_filename, reverse=False, num_reserved_ids=2):
+  def __init__(self, vocab_filename, reverse=False, num_reserved_ids=NUM_RESERVED_TOKENS):
     """Initialize from a file, one token per line."""
     super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
     self._reverse = reverse
```
```diff
@@ -203,7 +204,7 @@ class SubwordTextEncoder(TextEncoder):

   """

-  def __init__(self, filename=None, num_reserved_ids=2):
+  def __init__(self, filename=None, num_reserved_ids=NUM_RESERVED_TOKENS):
     """Initialize and read from a file, if provided."""
     self._tokenizer = tokenizer.Tokenizer()
     if filename is not None:
```

tensor2tensor/models/transformer.py

Lines changed: 2 additions & 5 deletions
Original file line number · Diff line number · Diff line change
```diff
@@ -356,12 +356,9 @@ def transformer_parsing_base():
 @registry.register_hparams
 def transformer_parsing_ice():
   """Hparams for parsing Icelandic text."""
-  hparams = transformer_parsing_base()
+  hparams = transformer_base_single_gpu()
   hparams.batch_size = 4096
-  hparams.batching_mantissa_bits = 2
-  hparams.hidden_size = 512
-  #hparams.max_length = 256
-  #hparams.hidden_size = 128
+  hparams.shared_embedding_and_softmax_weights = int(False)
   return hparams
```

367364

0 commit comments

Comments (0)