Skip to content

Commit 7c73fd9

Browse files
committed
Fix linter
Signed-off-by: Vladimir Suvorov <suvorovv@google.com>
1 parent 86bcb8b commit 7c73fd9

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

tests/pipeline_parallelism_test.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,15 @@ def assert_pipeline_same_output_and_grad(self, config, single_pipeline_stage_cla
7171
rngs = nnx.Rngs(params=0)
7272
single_pipeline_stage = single_pipeline_stage_class(config=config, mesh=mesh, model_mode=model_mode, rngs=rngs)
7373
else:
74-
single_pipeline_stage = single_pipeline_stage_class(config=config, mesh=mesh, model_mode=model_mode)
74+
try:
75+
single_pipeline_stage = single_pipeline_stage_class(config=config, mesh=mesh, model_mode=model_mode)
76+
except TypeError as exc:
77+
if "rngs" not in str(exc):
78+
raise
79+
rngs = nnx.Rngs(params=0)
80+
single_pipeline_stage = single_pipeline_stage_class(
81+
config=config, mesh=mesh, model_mode=model_mode, rngs=rngs
82+
)
7583

7684
def get_inputs(batch_size, sequence, features):
7785
"""Get random inputs, and random dummy targets

0 commit comments

Comments
 (0)