Skip to content

Commit 2880904

Browse files
authored
Fix ray conflict changes (#100)
* Fixed multiple host bugs * Removed shard_on_batch and ragged_mha * Lint fixes
1 parent 517d847 commit 2880904

File tree

5 files changed

+9
-13
lines changed

5 files changed

+9
-13
lines changed

jetstream_pt/engine.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -179,7 +179,7 @@ def _call_model_generate(
179179
)
180180
paramst, argst = torchjax.to_torch((weights, args))
181181
with self._lock:
182-
with torchjax.jax_mode:
182+
with torch_xla2.default_env():
183183
# The mode is needed so that tensors created inside of
184184
# the model (such as via torch.ones etc) also have the right type
185185
res = torch.func.functional_call(self.pt_model, paramst, argst)
@@ -210,7 +210,7 @@ def _call_model_prefill(self, weights, tokens, input_indexes):
210210

211211
paramst, argst = torchjax.to_torch((weights, args))
212212
with self._lock:
213-
with torchjax.jax_mode:
213+
with torch_xla2.default_env():
214214
res = torch.func.functional_call(self.pt_model, paramst, argst)[0]
215215
caches_res = [c.state() for c in caches]
216216
return torchjax.from_torch((res, caches_res))

jetstream_pt/ray_worker.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -166,7 +166,6 @@ def __init__(
166166
bf16_enable=bf16_enable,
167167
sharding_config_path=sharding_config,
168168
)
169-
env = JetEngineEnvironment(env_data)
170169

171170
if model_name.startswith("llama"):
172171

@@ -353,7 +352,7 @@ def _call_model_generate(
353352
args = (tokens, input_pos, caches_obj, mask)
354353
paramst, argst = torchjax.to_torch((weights, args))
355354
with self._lock:
356-
with torchjax.jax_mode():
355+
with torch_xla2.default_env():
357356
res = torch.func.functional_call(self.pt_model, paramst, argst)
358357
updated_caches = [c.state() for c in caches_obj]
359358
scales = []
@@ -396,7 +395,7 @@ def _call_model_prefill(self, weights, tokens, input_indexes):
396395

397396
paramst, argst = torchjax.to_torch((weights, args))
398397
with self._lock:
399-
with torchjax.jax_mode:
398+
with torch_xla2.default_env():
400399
res = torch.func.functional_call(self.pt_model, paramst, argst)[0]
401400
caches_res = [c.state() for c in caches]
402401
return torchjax.from_torch((res, caches_res))

jetstream_pt/torchjax.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -14,17 +14,15 @@
1414
import torch_xla2
1515
import torch_xla2.interop
1616

17-
jax_mode = torch_xla2.default_env()
18-
1917
call_jax = torch_xla2.interop.call_jax
2018
call_torch = torch_xla2.interop.call_torch
2119

2220

2321
def to_torch(tensors):
2422
"""Wrap a jax Array into XLATensor."""
25-
return jax_mode.j2t_iso(tensors)
23+
return torch_xla2.default_env().j2t_iso(tensors)
2624

2725

2826
def from_torch(tensors):
2927
"""Unwrap a XLATensor into jax Array."""
30-
return jax_mode.t2j_iso(tensors)
28+
return torch_xla2.default_env().t2j_iso(tensors)

run_interactive_multiple_host.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -43,8 +43,6 @@ def create_engine():
4343
quantize_kv=FLAGS.quantize_kv_cache,
4444
max_cache_length=FLAGS.max_cache_length,
4545
sharding_config=FLAGS.sharding_config,
46-
shard_on_batch=FLAGS.shard_on_batch,
47-
ragged_mha=FLAGS.ragged_mha,
4846
)
4947

5048
print("Initialize engine", time.perf_counter() - start)
@@ -54,7 +52,7 @@ def create_engine():
5452
# pylint: disable-next=all
5553
def main(argv):
5654

57-
engine = create_engine_from_config_flags()
55+
engine = create_engine()
5856

5957
start = time.perf_counter()
6058
engine.load_params()
@@ -99,6 +97,7 @@ def main(argv):
9997
while True:
10098
# pylint: disable-next=all
10199
decode_state, result_tokens = engine.generate(None, decode_state)
100+
result_tokens = result_tokens.convert_to_numpy()
102101

103102
slot_data = result_tokens.get_result_at_slot(slot)
104103
slot_tokens = slot_data.tokens

tests/test_quantization.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -237,7 +237,7 @@ def test_blockwise_quantized_linear_sharding(self):
237237
)
238238
def f(layer, weights, args):
239239
paramst, argst = torchjax.to_torch((weights, args))
240-
with torchjax.jax_mode:
240+
with torch_xla2.default_env():
241241
res = torch.func.functional_call(layer, paramst, argst)
242242
return torchjax.from_torch(res)
243243

0 commit comments

Comments (0)