File tree Expand file tree Collapse file tree 1 file changed +5
-3
lines changed Expand file tree Collapse file tree 1 file changed +5
-3
lines changed Original file line number Diff line number Diff line change 22#
33# Licensed under the Apache License, Version 2.0 (the "License");
44# you may not use this file except in compliance with the License.
5+ < << << << HEAD
6+ == == == =
7+
8+ >> >> >> > c6a7412e (test )
59# You may obtain a copy of the License at
610#
711# https://www.apache.org/licenses/LICENSE-2.0
@@ -109,9 +113,7 @@ def main(argv: Sequence[str]) -> None:
109113 jax .config .update ("jax_default_prng_impl" , "unsafe_rbg" )
110114 os .environ ["TF_CPP_MIN_LOG_LEVEL" ] = "0"
111115 if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os .environ .get ("LIBTPU_INIT_ARGS" , "" ):
112- os .environ ["LIBTPU_INIT_ARGS" ] = (
113- os .environ .get ("LIBTPU_INIT_ARGS" , "" ) + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
114- )
116+ os .environ ["LIBTPU_INIT_ARGS" ] = os .environ .get ("LIBTPU_INIT_ARGS" , "" ) + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
115117
116118 config = pyconfig .initialize (argv )
117119 maxtext_model , mesh = model_creation_utils .create_nnx_model (config )
You can’t perform that action at this time.
0 commit comments