Skip to content

Commit 25aa26e

Browse files
yashk2810Google-ML-Automation
authored andcommitted
default mem kind is always device at HEAD so just hardcode to that for all device types
PiperOrigin-RevId: 835431225
1 parent 41ac9b0 commit 25aa26e

File tree

1 file changed

+2
-12
lines changed

1 file changed

+2
-12
lines changed

jax/_src/interpreters/pxla.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2060,24 +2060,15 @@ def jaxpr_transfer_mem_kinds(jaxpr: core.Jaxpr):
20602060
return out
20612061

20622062

2063-
def are_all_shardings_default_mem_kind(
2064-
device_list: xc.DeviceList | None, shardings
2065-
):
2066-
if device_list is None:
2067-
return True
2068-
try:
2069-
default_mem_kind = device_list.default_memory_kind
2070-
except:
2071-
return True
2072-
2063+
def are_all_shardings_default_mem_kind(shardings):
20732064
for i in shardings:
20742065
if isinstance(i, (UnspecifiedValue, AUTO)):
20752066
continue
20762067
mem_kind = (core.mem_space_to_kind(i) if isinstance(i, core.MemorySpace)
20772068
else i.memory_kind)
20782069
if mem_kind is None:
20792070
continue
2080-
if mem_kind != default_mem_kind:
2071+
if mem_kind != 'device':
20812072
return False
20822073
return True
20832074

@@ -2369,7 +2360,6 @@ def lower_sharding_computation(
23692360
)
23702361

23712362
all_default_mem_kind = are_all_shardings_default_mem_kind(
2372-
device_list,
23732363
it.chain(unique_in_shardings, unique_out_shardings,
23742364
unique_intermediate_shardings, transfer_mem_kind_in_jaxpr)) # pytype: disable=wrong-arg-types
23752365

0 commit comments

Comments
 (0)