@@ -166,7 +166,6 @@ def __init__(
166166 bf16_enable = bf16_enable ,
167167 sharding_config_path = sharding_config ,
168168 )
169- env = JetEngineEnvironment (env_data )
170169
171170 if model_name .startswith ("llama" ):
172171
@@ -353,7 +352,7 @@ def _call_model_generate(
353352 args = (tokens , input_pos , caches_obj , mask )
354353 paramst , argst = torchjax .to_torch ((weights , args ))
355354 with self ._lock :
356- with torchjax . jax_mode ():
355+ with torch_xla2 . default_env ():
357356 res = torch .func .functional_call (self .pt_model , paramst , argst )
358357 updated_caches = [c .state () for c in caches_obj ]
359358 scales = []
@@ -396,7 +395,7 @@ def _call_model_prefill(self, weights, tokens, input_indexes):
396395
397396 paramst , argst = torchjax .to_torch ((weights , args ))
398397 with self ._lock :
399- with torchjax . jax_mode :
398+ with torch_xla2 . default_env () :
400399 res = torch .func .functional_call (self .pt_model , paramst , argst )[0 ]
401400 caches_res = [c .state () for c in caches ]
402401 return torchjax .from_torch ((res , caches_res ))