Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 5db92b5

Browse files
author
Ryan Sepassi
committed
Add t2t_usr_dir functionality to t2t-datagen
PiperOrigin-RevId: 162253918
1 parent 930cb33 commit 5db92b5

File tree

3 files changed

+51
-20
lines changed

3 files changed

+51
-20
lines changed

tensor2tensor/bin/t2t-datagen

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ from tensor2tensor.data_generators import wiki
4848
from tensor2tensor.data_generators import wmt
4949
from tensor2tensor.data_generators import wsj_parsing
5050
from tensor2tensor.utils import registry
51+
from tensor2tensor.utils import registry_utils
5152

5253
import tensorflow as tf
5354

@@ -63,6 +64,12 @@ flags.DEFINE_integer("num_shards", 10, "How many shards to use.")
6364
flags.DEFINE_integer("max_cases", 0,
6465
"Maximum number of cases to generate (unbounded if 0).")
6566
flags.DEFINE_integer("random_seed", 429459, "Random seed to use.")
67+
flags.DEFINE_string("t2t_usr_dir", "",
68+
"Path to a Python module that will be imported. The "
69+
"__init__.py file should include the necessary imports. "
70+
"The imported files should contain registrations, "
71+
"e.g. @registry.register_problem calls, that will then be "
72+
"available to t2t-datagen.")
6673

6774
# Mapping from problems that we can generate data for to their generators.
6875
# pylint: disable=g-long-lambda
@@ -273,6 +280,7 @@ def set_random_seed():
273280

274281
def main(_):
275282
tf.logging.set_verbosity(tf.logging.INFO)
283+
registry_utils.import_usr_dir(FLAGS.t2t_usr_dir)
276284

277285
# Calculate the list of problems to generate.
278286
problems = sorted(

tensor2tensor/bin/t2t-trainer

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,9 @@ from __future__ import absolute_import
2929
from __future__ import division
3030
from __future__ import print_function
3131

32-
import importlib
33-
import os
34-
import sys
35-
3632
# Dependency imports
3733

34+
from tensor2tensor.utils import registry_utils
3835
from tensor2tensor.utils import trainer_utils as utils
3936

4037
import tensorflow as tf
@@ -50,24 +47,9 @@ flags.DEFINE_string("t2t_usr_dir", "",
5047
"available to the t2t-trainer.")
5148

5249

53-
def import_usr_dir():
54-
"""Import module at FLAGS.t2t_usr_dir, if provided."""
55-
if not FLAGS.t2t_usr_dir:
56-
return
57-
dir_path = os.path.expanduser(FLAGS.t2t_usr_dir)
58-
if dir_path[-1] == "/":
59-
dir_path = dir_path[:-1]
60-
containing_dir, module_name = os.path.split(dir_path)
61-
tf.logging.info("Importing user module %s from path %s", module_name,
62-
containing_dir)
63-
sys.path.insert(0, containing_dir)
64-
importlib.import_module(module_name)
65-
sys.path.pop(0)
66-
67-
6850
def main(_):
6951
tf.logging.set_verbosity(tf.logging.INFO)
70-
import_usr_dir()
52+
registry_utils.import_usr_dir(FLAGS.t2t_usr_dir)
7153
utils.log_registry()
7254
utils.validate_flags()
7355
utils.run(
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright 2017 The Tensor2Tensor Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Utilities for the t2t registry."""
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
import importlib
21+
import os
22+
import sys
23+
24+
# Dependency imports
25+
26+
import tensorflow as tf
27+
28+
29+
def import_usr_dir(usr_dir):
30+
"""Import module at usr_dir, if provided."""
31+
if not usr_dir:
32+
return
33+
dir_path = os.path.expanduser(usr_dir)
34+
if dir_path[-1] == "/":
35+
dir_path = dir_path[:-1]
36+
containing_dir, module_name = os.path.split(dir_path)
37+
tf.logging.info("Importing user module %s from path %s", module_name,
38+
containing_dir)
39+
sys.path.insert(0, containing_dir)
40+
importlib.import_module(module_name)
41+
sys.path.pop(0)

0 commit comments

Comments
 (0)