Skip to content

Commit c196dd2

Browse files
committed
test
1 parent 188c426 commit c196dd2

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

src/MaxText/vllm_decode.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
5+
<<<<<<< HEAD
6+
=======
7+
8+
>>>>>>> c6a7412e (test)
59
# You may obtain a copy of the License at
610
#
711
# https://www.apache.org/licenses/LICENSE-2.0
@@ -109,9 +113,7 @@ def main(argv: Sequence[str]) -> None:
109113
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
110114
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
111115
if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""):
112-
os.environ["LIBTPU_INIT_ARGS"] = (
113-
os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
114-
)
116+
os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
115117

116118
config = pyconfig.initialize(argv)
117119
maxtext_model, mesh = model_creation_utils.create_nnx_model(config)

0 commit comments

Comments
 (0)