Skip to content

Commit e8f5f00

Browse files
add local tokenizer option for automated testing without hf token (#192)
1 parent 5704b30 commit e8f5f00

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

jetstream_pt/cli.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,11 @@
3333
"",
3434
"if set, then save the result to the given file name",
3535
)
36-
36+
flags.DEFINE_bool(
37+
"internal_use_local_tokenizer",
38+
0,
39+
"Use local tokenizer if set to True"
40+
)
3741

3842
def shard_weights(env, weights, weight_shardings):
3943
"""Shard weights according to weight_shardings"""
@@ -57,8 +61,11 @@ def create_engine(devices):
5761
FLAGS.max_input_length,
5862
FLAGS.max_output_length,
5963
)
60-
tokenizer = AutoTokenizer.from_pretrained(FLAGS.model_id)
6164
env = environment.JetEngineEnvironment(env_data)
65+
if FLAGS.internal_use_local_tokenizer:
66+
tokenizer = AutoTokenizer.from_pretrained(env_data.checkpoint_path)
67+
else:
68+
tokenizer = AutoTokenizer.from_pretrained(FLAGS.model_id)
6269
env.hf_tokenizer = tokenizer
6370
model = fetch_models.instantiate_model_from_repo_id(FLAGS.model_id, env)
6471
# NOTE: this is assigned later because, the model should be constructed

0 commit comments

Comments
 (0)