Skip to content

Commit 8c75b51

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Fix memory space propagation with call_exported_p. Fixes #33471
PiperOrigin-RevId: 835253850
1 parent 0169884 commit 8c75b51

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

jax/_src/export/_export.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1405,7 +1405,8 @@ def pp_arg_dim(dim_idx: int | None) -> str:
14051405
out_avals = tuple(
14061406
core.ShapedArray(core.evaluate_shape(out_aval.shape, exported_dim_vars,
14071407
*exported_dim_values),
1408-
dtype=out_aval.dtype, weak_type=out_aval.weak_type)
1408+
dtype=out_aval.dtype, weak_type=out_aval.weak_type,
1409+
memory_space=out_aval.memory_space)
14091410
for out_aval in exported.out_avals)
14101411
return out_avals, set(exported.ordered_effects + exported.unordered_effects)
14111412

tests/memories_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,23 @@ def test_host_to_device_transfer(self):
743743
self.assertEqual(d.sharding.memory_kind, 'device')
744744
self.assertArraysEqual(d, orig)
745745

746+
def test_memory_space_propagated_identity_jit(self):
747+
shd = jax.sharding.SingleDeviceSharding(
748+
jax.devices()[0], memory_kind='pinned_host')
749+
a = jax.device_put(1, shd)
750+
751+
f = jax.jit(lambda x: x, out_shardings=shd)
752+
b = f(a)
753+
self.assertEqual(b.sharding, a.sharding)
754+
755+
f = jax.jit(lambda x: x)
756+
b = f(a)
757+
self.assertEqual(b.sharding, a.sharding)
758+
759+
exported = jax.export.export(f)(a)
760+
b = exported.call(a)
761+
self.assertEqual(b.sharding, a.sharding)
762+
746763

747764
class ComputeOffload(jtu.BufferDonationTestCase):
748765

0 commit comments

Comments
 (0)