File tree Expand file tree Collapse file tree 1 file changed +2
-12
lines changed
Expand file tree Collapse file tree 1 file changed +2
-12
lines changed Original file line number Diff line number Diff line change @@ -2060,24 +2060,15 @@ def jaxpr_transfer_mem_kinds(jaxpr: core.Jaxpr):
20602060 return out
20612061
20622062
2063- def are_all_shardings_default_mem_kind (
2064- device_list : xc .DeviceList | None , shardings
2065- ):
2066- if device_list is None :
2067- return True
2068- try :
2069- default_mem_kind = device_list .default_memory_kind
2070- except :
2071- return True
2072-
2063+ def are_all_shardings_default_mem_kind (shardings ):
20732064 for i in shardings :
20742065 if isinstance (i , (UnspecifiedValue , AUTO )):
20752066 continue
20762067 mem_kind = (core .mem_space_to_kind (i ) if isinstance (i , core .MemorySpace )
20772068 else i .memory_kind )
20782069 if mem_kind is None :
20792070 continue
2080- if mem_kind != default_mem_kind :
2071+ if mem_kind != 'device' :
20812072 return False
20822073 return True
20832074
@@ -2369,7 +2360,6 @@ def lower_sharding_computation(
23692360 )
23702361
23712362 all_default_mem_kind = are_all_shardings_default_mem_kind (
2372- device_list ,
23732363 it .chain (unique_in_shardings , unique_out_shardings ,
23742364 unique_intermediate_shardings , transfer_mem_kind_in_jaxpr )) # pytype: disable=wrong-arg-types
23752365
You can’t perform that action at this time.
0 commit comments