@@ -33,6 +33,7 @@
 
 from jetstream_pt import cache_manager
 from jetstream_pt import quantize
+from jetstream_pt import torchjax
 from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData
 from jetstream_pt.third_party.llama import model_exportable, model_args
 from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model
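
Note: the new jetstream_pt.torchjax module replaces the direct torch_xla2.tensor calls removed below (wrap, unwrap, XLADispatchMode). A minimal sketch of such a shim, assuming it is a thin alias layer over the torch_xla2 helpers the old code already used; the real module may differ, for example by targeting newer torch_xla2 entry points:

    import torch_xla2

    # jax arrays -> torch "view" tensors (no copy)
    to_torch = torch_xla2.tensor.wrap
    # torch "view" tensors -> jax arrays
    from_torch = torch_xla2.tensor.unwrap
    # dispatch mode that routes torch ops (torch.ones, ...) onto jax
    jax_mode = torch_xla2.tensor.XLADispatchMode()
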
@@ -86,8 +87,11 @@ def __init__(
     self.y_sharding = env.sharding_by_axis(1)
     self.x_sharding = env.sharding_by_axis(0)
     self.replicated = env.sharding_by_axis(-1)  # replicated
+
     self.cache_sharding = self.env.cache_sharding
 
+    jax.config.update("jax_enable_x64", False)
+
     self.prefill = jax.jit(
         self.prefill, out_shardings=self.get_prefix_destination_sharding()
     )
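
The added jax.config.update line keeps JAX on its 32-bit defaults, presumably so arrays created on the JAX side line up with torch's default float32/int32 dtypes. A quick illustration:

    import jax
    import jax.numpy as jnp

    jax.config.update("jax_enable_x64", False)
    print(jnp.asarray(1.0).dtype)  # float32 -- matches torch.get_default_dtype()
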
@@ -147,7 +151,7 @@ def _call_model_generate(
     if self.env.enable_kv_quantization:
       caches_obj = [
           cache_manager.Int8KVCacheGenerate(k, v, ks, vs, input_indexes)
-          for (k, v), (ks, vs) in torch_xla2.tensor.wrap(
+          for (k, v), (ks, vs) in torchjax.to_torch(
               list(zip(caches, cache_scales))
           )
       ]
@@ -156,20 +160,22 @@ def _call_model_generate(
           cache_manager.KVCacheGenerate(
               k, v, input_indexes, self.cache_sharding
           )
-          for k, v in torch_xla2.tensor.wrap(caches)
+          for k, v in torchjax.to_torch(caches)
       ]
     mask = jnp.expand_dims(mask, (1, 2))
 
     args = (tokens, input_pos, caches_obj, mask)
-    paramst, argst = torch_xla2.tensor.wrap((weights, args))
+    paramst, argst = torchjax.to_torch((weights, args))
     with self._lock:
-      with torch_xla2.tensor.XLADispatchMode():
+      with torchjax.jax_mode:
+        # The mode is needed so that tensors created inside of
+        # the model (such as via torch.ones etc) also have the right type
         res = torch.func.functional_call(self.pt_model, paramst, argst)
     updated_caches = [c.state() for c in caches_obj]
     scales = []
     if self.env.enable_kv_quantization:
       scales = [c.scalers() for c in caches_obj]
-    return torch_xla2.tensor.unwrap((res, updated_caches, scales))
+    return torchjax.from_torch((res, updated_caches, scales))
 
   @functools.partial(
       jax.jit,
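
The hunk above is the full round trip for one decode step: jax arrays are wrapped as torch views with torchjax.to_torch, the model is run functionally under torchjax.jax_mode so that tensors the model creates internally (torch.ones, etc.) are also jax-backed, and the results are unwrapped with torchjax.from_torch. A stripped-down sketch of the same pattern with a toy module, assuming to_torch/from_torch map pytrees of jax arrays to torch views and back as their use above suggests:

    import jax.numpy as jnp
    import torch
    from jetstream_pt import torchjax

    class Toy(torch.nn.Module):
      def forward(self, x):
        # torch.ones must yield a jax-backed tensor here, hence jax_mode below
        return x + torch.ones(4)

    model = Toy()
    x = jnp.arange(4, dtype=jnp.float32)
    (xt,) = torchjax.to_torch((x,))
    with torchjax.jax_mode:
      out = torch.func.functional_call(model, {}, (xt,))
    result = torchjax.from_torch(out)  # back to a jax array
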
@@ -188,12 +194,12 @@ def _call_model_prefill(self, weights, tokens, input_indexes):
     mask = jnp.triu(mask, k=1)
     args = (tokens, input_indexes, caches, mask)
 
-    paramst, argst = torch_xla2.tensor.wrap((weights, args))
+    paramst, argst = torchjax.to_torch((weights, args))
     with self._lock:
-      with torch_xla2.tensor.XLADispatchMode():
+      with torchjax.jax_mode:
         res = torch.func.functional_call(self.pt_model, paramst, argst)[0]
     caches_res = [c.state() for c in caches]
-    return torch_xla2.tensor.unwrap((res, caches_res))
+    return torchjax.from_torch((res, caches_res))
 
   def _sampling(self, logits: Any, batch_size: int) -> jnp.ndarray:
     if len(logits.shape) == 2:
@@ -287,20 +293,20 @@ def insert(cache, new_entry):
       @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True)
       def insert(cache, scaler, new_entry):
         reduce_axis = (1, 3)
-        vals, scales = torch_xla2.extra.call_torch(
+        vals, scales = torch_xla2.interop.call_torch(
             quantize.quantize_torch_int8, new_entry, reduce_axis
         )
         new_scaler = jax.lax.dynamic_update_slice(
             scaler,
-            scales,
+            scales.jax(),
             [slot, 0, pos, 0],
         )
         new_scaler = jax.lax.with_sharding_constraint(
            new_scaler, self.replicated
         )
         res = jax.lax.dynamic_update_slice(
             cache,
-            vals,
+            vals.jax(),
             [slot, 0, pos, 0],
         )
         res = jax.lax.with_sharding_constraint(res, self.cache_sharding)
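
Here call_torch runs the torch quantization helper on jax inputs and evidently now returns torch-view tensors, hence the added .jax() calls before jax.lax.dynamic_update_slice. For reference, a hedged sketch of what a quantize_torch_int8-style helper computes (the real quantize.quantize_torch_int8 may differ in details such as scale shape or zero handling): symmetric int8 with one scale per slice along the reduced axes.

    import torch

    def quantize_int8(x: torch.Tensor, reduce_axis):
      # one scale per remaining slice: max |x| over the reduced axes
      amax = x.abs().amax(dim=reduce_axis, keepdim=True)
      scale = amax / 127.0
      scale = torch.where(scale == 0, torch.ones_like(scale), scale)  # avoid division by zero
      vals = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
      return vals, scale

    # e.g. a (batch, heads, seq, head_dim) KV entry, reduced over heads and head_dim
    vals, scales = quantize_int8(torch.randn(1, 8, 16, 4), reduce_axis=(1, 3))
    approx = vals.to(torch.float32) * scales  # dequantized approximation
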
@@ -386,7 +392,7 @@ def insert(cache, new_entry):
       def insert(cache, scaler, new_entry):
         new_entry = jnp.transpose(new_entry.squeeze(0), (1, 0, 2))
         reduce_axis = (1, 2)
-        vals, scales = torch_xla2.extra.call_torch(
+        vals, scales = torch_xla2.interop.call_torch(
            quantize.quantize_torch_int8, new_entry, reduce_axis
        )
        new_scaler = scaler.at[slot, :, update_indexes, :].set(scales)
@@ -559,7 +565,7 @@ def _load_from_state_dict(self, path):
     for key, model_weights in self.pt_model.state_dict().items():
       assert key in state_dict, f"key: {key} not found"
       arr = jax.device_put(
-          torch_xla2.tensor.t2j(state_dict[key]), self.env.sharding_by_name(key)
+          torchjax.from_torch(state_dict[key]), self.env.sharding_by_name(key)
       )
       assert tuple(model_weights.shape) == tuple(
           arr.shape
@@ -602,14 +608,14 @@ def get_prefix_destination_sharding(self) -> Prefix:
     """Returns the shardings necessary to transfer data between engines."""
     return Prefix(
         self.replicated,
-        self.cache_sharding,
+        self.replicated if self.env.shard_on_batch else self.cache_sharding,
         self.replicated,
     )
 
   def get_decode_state_sharding(self) -> DecodeState:
     """Gets the shardings corresponding to the decode state."""
     return DecodeState(
-        self.replicated,
+        self.x_sharding if self.env.shard_on_batch else self.replicated,
         self.cache_sharding,
         self.replicated,
         self.replicated,
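
With shard_on_batch enabled, the prefix cache sharding is replaced by replication and the decode tokens are sharded along the batch axis (x_sharding, i.e. axis 0). As a rough illustration in plain JAX, assuming env.sharding_by_axis builds something equivalent to these NamedShardings over a one-dimensional device mesh:

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(jax.devices(), axis_names=("x",))
    batch_sharding = NamedSharding(mesh, P("x"))  # split axis 0 (the batch) across devices
    replicated = NamedSharding(mesh, P())         # full copy on every device

    tokens = jax.device_put(jnp.zeros((8, 1), jnp.int32), batch_sharding)
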
@@ -663,6 +669,7 @@ def create_pytorch_engine(
     quantize_kv=False,
     max_cache_length=1024,
     sharding_config=None,
+    shard_on_batch=False,
 ) -> PyTorchEngine:
   """Returns: The pytorch engine."""
 
@@ -718,8 +725,12 @@ def create_pytorch_engine(
       cache_sequence_length=max_cache_length,
       bf16_enable=bf16_enable,
       sharding_config_path=sharding_config,
+      shard_on_batch=shard_on_batch,
   )
 
+  if shard_on_batch and sharding_config:
+    print("WARNING: with sharding_on_batch sharding config is ignored.")
+
   if model_name.startswith("llama"):
 
     args = model_args.get_model_args(
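
The new shard_on_batch flag is threaded through to JetEngineEnvironmentData, and when it is set together with a sharding config the config is ignored (see the warning above). A hypothetical call site; every keyword other than quantize_kv, max_cache_length, sharding_config and shard_on_batch is inferred from the function body and may not match the real signature:

    from jetstream_pt import engine as pt_engine  # module path assumed

    engine = pt_engine.create_pytorch_engine(
        model_name="llama-2-7b",  # value assumed; must start with "llama" to hit the branch above
        bf16_enable=True,
        quantize_kv=False,
        max_cache_length=1024,
        shard_on_batch=True,      # batch-axis sharding; any sharding_config would be ignored
    )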