File tree Expand file tree Collapse file tree 2 files changed +19
-1
lines changed
Expand file tree Collapse file tree 2 files changed +19
-1
lines changed Original file line number Diff line number Diff line change @@ -1405,7 +1405,8 @@ def pp_arg_dim(dim_idx: int | None) -> str:
14051405 out_avals = tuple (
14061406 core .ShapedArray (core .evaluate_shape (out_aval .shape , exported_dim_vars ,
14071407 * exported_dim_values ),
1408- dtype = out_aval .dtype , weak_type = out_aval .weak_type )
1408+ dtype = out_aval .dtype , weak_type = out_aval .weak_type ,
1409+ memory_space = out_aval .memory_space )
14091410 for out_aval in exported .out_avals )
14101411 return out_avals , set (exported .ordered_effects + exported .unordered_effects )
14111412
Original file line number Diff line number Diff line change @@ -743,6 +743,23 @@ def test_host_to_device_transfer(self):
743743 self .assertEqual (d .sharding .memory_kind , 'device' )
744744 self .assertArraysEqual (d , orig )
745745
746+ def test_memory_space_propagated_identity_jit (self ):
747+ shd = jax .sharding .SingleDeviceSharding (
748+ jax .devices ()[0 ], memory_kind = 'pinned_host' )
749+ a = jax .device_put (1 , shd )
750+
751+ f = jax .jit (lambda x : x , out_shardings = shd )
752+ b = f (a )
753+ self .assertEqual (b .sharding , a .sharding )
754+
755+ f = jax .jit (lambda x : x )
756+ b = f (a )
757+ self .assertEqual (b .sharding , a .sharding )
758+
759+ exported = jax .export .export (f )(a )
760+ b = exported .call (a )
761+ self .assertEqual (b .sharding , a .sharding )
762+
746763
747764class ComputeOffload (jtu .BufferDonationTestCase ):
748765
You can’t perform that action at this time.
0 commit comments