@@ -3,6 +3,7 @@ module Compiler
33using Reactant_jll
44using Libdl: dlsym
55using LinearAlgebra: BlasInt
6+ using Functors: Functors
67
78import .. Reactant:
89 Reactant,
@@ -272,7 +273,7 @@ function create_result(
272273end
273274
274275function create_result (
275- tocopy:: ConcretePJRTNumber{T,D,S } ,
276+ tocopy:: ConcretePJRTNumber{T,D} ,
276277 path,
277278 result_stores,
278279 path_to_shard_info,
@@ -283,7 +284,7 @@ function create_result(
283284 result_cache,
284285 var_idx,
285286 resultgen_code,
286- ) where {T,D,S }
287+ ) where {T,D}
287288 if ! haskey (result_cache, tocopy)
288289 sym = Symbol (" result" , var_idx[])
289290 var_idx[] += 1
@@ -314,7 +315,7 @@ function create_result(
314315end
315316
316317function create_result (
317- tocopy:: ConcreteIFRTNumber{T,S } ,
318+ tocopy:: ConcreteIFRTNumber{T} ,
318319 path,
319320 result_stores,
320321 path_to_shard_info,
@@ -325,7 +326,7 @@ function create_result(
325326 result_cache,
326327 var_idx,
327328 resultgen_code,
328- ) where {T,S }
329+ ) where {T}
329330 if ! haskey (result_cache, tocopy)
330331 sym = Symbol (" result" , var_idx[])
331332 var_idx[] += 1
@@ -356,7 +357,7 @@ function create_result(
356357end
357358
358359function create_result (
359- tocopy:: ConcretePJRTArray{T,N,D,S } ,
360+ tocopy:: ConcretePJRTArray{T,N,D} ,
360361 path,
361362 result_stores,
362363 path_to_shard_info,
@@ -367,7 +368,7 @@ function create_result(
367368 result_cache,
368369 var_idx,
369370 resultgen_code,
370- ) where {T,N,D,S }
371+ ) where {T,N,D}
371372 if ! haskey (result_cache, tocopy)
372373 sym = Symbol (" result" , var_idx[])
373374 var_idx[] += 1
@@ -399,7 +400,7 @@ function create_result(
399400end
400401
401402function create_result (
402- tocopy:: ConcreteIFRTArray{T,N,S } ,
403+ tocopy:: ConcreteIFRTArray{T,N} ,
403404 path,
404405 result_stores,
405406 path_to_shard_info,
@@ -410,7 +411,7 @@ function create_result(
410411 result_cache,
411412 var_idx,
412413 resultgen_code,
413- ) where {T,N,S }
414+ ) where {T,N}
414415 if ! haskey (result_cache, tocopy)
415416 sym = Symbol (" result" , var_idx[])
416417 var_idx[] += 1
@@ -1647,12 +1648,6 @@ function compile_mlir!(
16471648 raise isa Bool && (raise = true )
16481649 end
16491650
1650- concrete_seen = OrderedIdDict ()
1651-
1652- concrete_result = make_tracer (
1653- concrete_seen, traced_result, (" result" ,), TracedToConcrete; runtime
1654- )
1655-
16561651 toolkit = XLA. CUDA_DATA_DIR[]
16571652
16581653 if backend == " cpu" || backend == " tpu"
@@ -2307,6 +2302,17 @@ function compile_mlir!(
23072302 ]
23082303 end
23092304
2305+ if result_shardings != = missing
2306+ result_shardings_after_masking = eltype (result_shardings)[]
2307+ for (i, present) in enumerate (results_mask)
2308+ if present
2309+ push! (result_shardings_after_masking, result_shardings[i])
2310+ end
2311+ end
2312+ else
2313+ result_shardings_after_masking = missing
2314+ end
2315+
23102316 func3 = MLIR. Dialects. func. func_ (;
23112317 sym_name= " main" ,
23122318 function_type= MLIR. IR. FunctionType (in_tys, out_tys2),
@@ -2384,6 +2390,10 @@ function compile_mlir!(
23842390 end
23852391 end
23862392
2393+ concrete_result = make_tracer (
2394+ OrderedIdDict (), traced_result, (" result" ,), TracedToConcrete; runtime
2395+ )
2396+
23872397 return Reactant. TracedUtils. CompiledMlirFnResult (
23882398 fnwrapped,
23892399 func3,
@@ -2404,7 +2414,7 @@ function compile_mlir!(
24042414 mlir_fn_res. unique_meshes,
24052415 mlir_fn_res. mutated_args,
24062416 use_shardy_partitioner,
2407- result_shardings ,
2417+ result_shardings_after_masking ,
24082418 mlir_fn_res. global_device_ids,
24092419 donated_args_mask,
24102420 Reactant. TracedUtils. is_pure (func3),
@@ -3151,7 +3161,6 @@ function codegen_unflatten!(
31513161 p in Reactant. TracedUtils. get_paths (arg) if length (p) > 0 && p[1 ] == :args
31523162 ))
31533163
3154- res = :result
31553164 path = path[2 : end ]
31563165
31573166 if in (path, keys (result_stores))
@@ -3246,10 +3255,15 @@ function codegen_unflatten!(
32463255 end
32473256
32483257 # generate return object which stores the concrete results in some arbitrary way
3249- return Expr[
3250- unresharded_code... , resultgen_code... , :(result = $ result_code), unflatten_code...
3251- ],
3252- used_shardinfo
3258+ return (
3259+ Expr[
3260+ unresharded_code... ,
3261+ resultgen_code... ,
3262+ :(result = $ result_code),
3263+ unflatten_code... ,
3264+ ],
3265+ used_shardinfo,
3266+ )
32533267end
32543268
32553269"""
0 commit comments