File tree Expand file tree Collapse file tree 1 file changed +2
-3
lines changed Expand file tree Collapse file tree 1 file changed +2
-3
lines changed Original file line number Diff line number Diff line change @@ -357,7 +357,7 @@ def _call_model_generate(
357357 )
358358 for k , v in torchjax .to_torch (caches )
359359 ]
360- mask = jnp .expand_dims (mask , (1 , 2 ))
360+ mask = jnp .expand_dims (new_mask , (1 , 2 ))
361361
362362 args = (tokens , input_pos , caches_obj , mask )
363363 paramst , argst = torchjax .to_torch ((weights , args ))
@@ -371,7 +371,6 @@ def _call_model_generate(
371371 new_current_position = (
372372 current_position + 1
373373 ) % self .env .cache_sequence_length
374-
375374 return torchjax .from_torch (
376375 (
377376 res ,
@@ -816,7 +815,7 @@ def generate(
816815 length_idx = (2 * length , 2 * length + 1 ),
817816 samples_per_slot = 1 ,
818817 )
819-
818+ next_token = jax . lax . with_sharding_constraint ( next_token , self . replicated )
820819 new_decode_state = DecodeState (
821820 next_token ,
822821 new_caches ,
You can’t perform that action at this time.
0 commit comments