Skip to content

Commit ea8d5b3

Browse files
authored
[converter] Fix bazel BUILD rules for wizard (#1953)
1 parent ca35b08 commit ea8d5b3

File tree

6 files changed

+32
-27
lines changed

6 files changed

+32
-27
lines changed

tfjs-converter/python/tensorflowjs/BUILD

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -123,29 +123,6 @@ py_test(
123123
],
124124
)
125125

126-
py_test(
127-
name = "wizard_test",
128-
srcs = ["wizard_test.py"],
129-
srcs_version = "PY2AND3",
130-
deps = [
131-
":expect_numpy_installed",
132-
":wizard",
133-
],
134-
)
135-
136-
py_binary(
137-
name = "wizard",
138-
srcs = ["wizard.py"],
139-
srcs_version = "PY2AND3",
140-
deps = [
141-
":converters/common",
142-
":converters/converter",
143-
"//tensorflowjs:expect_h5py_installed",
144-
"//tensorflowjs:expect_keras_installed",
145-
"//tensorflowjs:expect_tensorflow_installed",
146-
],
147-
)
148-
149126
# A filegroup BUILD target that includes all the op list json files in the
150127
# the op_list/ folder. The op_list folder itself is a symbolic link to the
151128
# actual op_list folder under src/.

tfjs-converter/python/tensorflowjs/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,5 @@
2121
from tensorflowjs import converters
2222
from tensorflowjs import quantization
2323
from tensorflowjs import version
24-
from tensorflowjs import wizard
2524

2625
__version__ = version.version

tfjs-converter/python/tensorflowjs/converters/BUILD

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ py_test(
7777
srcs_version = "PY2AND3",
7878
deps = [
7979
":fuse_prelu",
80+
":tf_saved_model_conversion_v2",
8081
"//tensorflowjs:expect_numpy_installed",
8182
"//tensorflowjs:expect_tensorflow_installed",
8283
],
@@ -121,13 +122,37 @@ py_library(
121122
data = ["//tensorflowjs:op_list_jsons"],
122123
srcs_version = "PY2AND3",
123124
deps = [
125+
":common",
126+
":fold_batch_norms",
127+
":fuse_prelu",
124128
"//tensorflowjs:expect_numpy_installed",
125129
"//tensorflowjs:expect_tensorflow_installed",
126130
"//tensorflowjs:expect_tensorflow_hub_installed",
127131
"//tensorflowjs:version",
128132
"//tensorflowjs:write_weights",
129-
"//tensorflowjs/converters:common",
130-
"//tensorflowjs/converters:fold_batch_norms",
133+
],
134+
)
135+
136+
py_test(
137+
name = "wizard_test",
138+
srcs = ["wizard_test.py"],
139+
srcs_version = "PY2AND3",
140+
deps = [
141+
":wizard",
142+
"//tensorflowjs:expect_numpy_installed",
143+
],
144+
)
145+
146+
py_binary(
147+
name = "wizard",
148+
srcs = ["wizard.py"],
149+
srcs_version = "PY2AND3",
150+
deps = [
151+
":common",
152+
":converter",
153+
":fuse_prelu",
154+
"//tensorflowjs:expect_h5py_installed",
155+
"//tensorflowjs:expect_tensorflow_installed",
131156
],
132157
)
133158

tfjs-converter/python/tensorflowjs/converters/converter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,5 +652,6 @@ def pip_main():
652652
def main(argv):
653653
convert(argv[0].split(' '))
654654

655+
655656
if __name__ == '__main__':
656657
tf.app.run(main=main, argv=[' '.join(sys.argv[1:])])

tfjs-converter/python/tensorflowjs/wizard.py renamed to tfjs-converter/python/tensorflowjs/converters/wizard.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
PyInquirer.Token.Question: '',
4848
})
4949

50+
5051
def value_in_list(answers, key, values):
5152
"""Determine user's answer for the key is in the value list.
5253
Args:
@@ -70,6 +71,7 @@ def get_tfjs_model_type(model_file):
7071
else: # Default to layers model
7172
return common.TFJS_LAYERS_MODEL_FORMAT
7273

74+
7375
def detect_saved_model(input_path):
7476
if os.path.exists(os.path.join(input_path, 'assets', 'saved_model.json')):
7577
return common.KERAS_SAVED_MODEL
@@ -80,6 +82,7 @@ def detect_saved_model(input_path):
8082
return common.KERAS_SAVED_MODEL
8183
return common.TF_SAVED_MODEL
8284

85+
8386
def detect_input_format(input_path):
8487
"""Determine the input format from model's input path or file.
8588
Args:

tfjs-converter/python/tensorflowjs/wizard_test.py renamed to tfjs-converter/python/tensorflowjs/converters/wizard_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from tensorflow.python.training.tracking import tracking
2929
from tensorflow.python.saved_model import save
3030

31-
from tensorflowjs import wizard
31+
from tensorflowjs.converters import wizard
3232

3333
SAVED_MODEL_DIR = 'saved_model'
3434
SAVED_MODEL_NAME = 'saved_model.pb'

0 commit comments

Comments
 (0)