Skip to content

Commit a1c9d49

Browse files
committed
Import fix for E2E tests
1 parent c196dd2 commit a1c9d49

File tree

4 files changed

+6
-8
lines changed

4 files changed

+6
-8
lines changed

end_to_end/tpu/gemma2/2b/test_gemma2_to_mt.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/Max
6464
# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
6565
export DATASET_PATH=gs://maxtext-dataset
6666
# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own, this bucket will store all the files generated by MaxText during a run
67-
export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs
67+
export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/gemm2-2b
6868

6969
# We can also run finetuning by using the scanned converted checkpoint.
7070
# Note that scanned checkpoint helps with efficient finetuning

end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ fi
7575
# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
7676
export DATASET_PATH=gs://maxtext-dataset
7777
# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own, this bucket will store all the files generated by MaxText during a run
78-
export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs
78+
export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/gemma3-4b
7979

8080
# We can also run finetuning by using the scanned converted checkpoint.
8181
# Note that scanned checkpoint helps with efficient finetuning

src/MaxText/utils/ckpt_scripts/llama_mistral_mixtral_orbax_to_hf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
from transformers import LlamaForCausalLM, MistralForCausalLM, AutoModelForCausalLM, AutoConfig
4949

5050
from MaxText import checkpointing
51-
from MaxText import llama_or_mistral_ckpt
51+
from MaxText.utils.ckpt_scripts import llama_or_mistral_ckpt
5252
from MaxText import max_logging
5353
from MaxText import maxtext_utils
5454
from MaxText import pyconfig

src/MaxText/vllm_decode.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,6 @@
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
5-
<<<<<<< HEAD
6-
=======
7-
8-
>>>>>>> c6a7412e (test)
95
# You may obtain a copy of the License at
106
#
117
# https://www.apache.org/licenses/LICENSE-2.0
@@ -113,7 +109,9 @@ def main(argv: Sequence[str]) -> None:
113109
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
114110
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
115111
if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""):
116-
os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
112+
os.environ["LIBTPU_INIT_ARGS"] = (
113+
os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
114+
)
117115

118116
config = pyconfig.initialize(argv)
119117
maxtext_model, mesh = model_creation_utils.create_nnx_model(config)

0 commit comments

Comments
 (0)