Skip to content

Commit 837e752

Browse files
authored
fix: despecialize sharding fields of ConcreteTypes (#1801)
* fix: despecialize sharding fields of ConcreteTypes
* fix: sharding of structs now works
* fix: dispatches
* fix: missing
* test: incorrect type
* fix: preserve sharding
* test: add #1227 as a testcase
1 parent 6aa98ad commit 837e752

File tree

9 files changed

+186
-192
lines changed

9 files changed

+186
-192
lines changed

docs/src/tutorials/partial-evaluation.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ addxy(x, y)
3636
3737
# output
3838
39-
ConcretePJRTNumber{Int64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(7)
39+
ConcretePJRTNumber{Int64, 1}(7)
4040
```
4141

4242
returns a result that depends on both arguments `x` and `y`:
@@ -46,7 +46,7 @@ addxy(ConcreteRNumber(7), ConcreteRNumber(8))
4646
4747
# output
4848
49-
ConcretePJRTNumber{Int64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(15)
49+
ConcretePJRTNumber{Int64, 1}(15)
5050
```
5151

5252
The StableHLO IR code generated here is:
@@ -76,7 +76,7 @@ addx4(x, 4)
7676
7777
# output
7878
79-
ConcretePJRTNumber{Int64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(7)
79+
ConcretePJRTNumber{Int64, 1}(7)
8080
```
8181

8282
will only change based on `x`, not on the non-Reactant argument `y`, we get
@@ -87,7 +87,7 @@ addx4(ConcreteRNumber(7), 8)
8787
8888
# output
8989
90-
ConcretePJRTNumber{Int64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(11)
90+
ConcretePJRTNumber{Int64, 1}(11)
9191
```
9292

9393
The StableHLO code shows that the second argument has been replaced by a

ext/ReactantCUDAExt.jl

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1379,16 +1379,9 @@ Base.@nospecializeinfer function Reactant.traced_type_inner(
13791379
N = ndims(A)
13801380
if mode == Reactant.ArrayToConcrete && T <: Reactant.ReactantPrimitive
13811381
if runtime isa Val{:PJRT}
1382-
return Reactant.ConcretePJRTArray{
1383-
T,
1384-
N,
1385-
Reactant.Sharding.ndevices(sharding),
1386-
Reactant.Sharding.shard_type(typeof(sharding), N),
1387-
}
1382+
return Reactant.ConcretePJRTArray{T,N,Reactant.Sharding.ndevices(sharding)}
13881383
elseif runtime isa Val{:IFRT}
1389-
return Reactant.ConcreteIFRTArray{
1390-
T,N,Reactant.Sharding.shard_type(typeof(sharding), N)
1391-
}
1384+
return Reactant.ConcreteIFRTArray{T,N}
13921385
end
13931386
error("Unsupported runtime $runtime")
13941387
else

src/Compiler.jl

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module Compiler
33
using Reactant_jll
44
using Libdl: dlsym
55
using LinearAlgebra: BlasInt
6+
using Functors: Functors
67

78
import ..Reactant:
89
Reactant,
@@ -272,7 +273,7 @@ function create_result(
272273
end
273274

274275
function create_result(
275-
tocopy::ConcretePJRTNumber{T,D,S},
276+
tocopy::ConcretePJRTNumber{T,D},
276277
path,
277278
result_stores,
278279
path_to_shard_info,
@@ -283,7 +284,7 @@ function create_result(
283284
result_cache,
284285
var_idx,
285286
resultgen_code,
286-
) where {T,D,S}
287+
) where {T,D}
287288
if !haskey(result_cache, tocopy)
288289
sym = Symbol("result", var_idx[])
289290
var_idx[] += 1
@@ -314,7 +315,7 @@ function create_result(
314315
end
315316

316317
function create_result(
317-
tocopy::ConcreteIFRTNumber{T,S},
318+
tocopy::ConcreteIFRTNumber{T},
318319
path,
319320
result_stores,
320321
path_to_shard_info,
@@ -325,7 +326,7 @@ function create_result(
325326
result_cache,
326327
var_idx,
327328
resultgen_code,
328-
) where {T,S}
329+
) where {T}
329330
if !haskey(result_cache, tocopy)
330331
sym = Symbol("result", var_idx[])
331332
var_idx[] += 1
@@ -356,7 +357,7 @@ function create_result(
356357
end
357358

358359
function create_result(
359-
tocopy::ConcretePJRTArray{T,N,D,S},
360+
tocopy::ConcretePJRTArray{T,N,D},
360361
path,
361362
result_stores,
362363
path_to_shard_info,
@@ -367,7 +368,7 @@ function create_result(
367368
result_cache,
368369
var_idx,
369370
resultgen_code,
370-
) where {T,N,D,S}
371+
) where {T,N,D}
371372
if !haskey(result_cache, tocopy)
372373
sym = Symbol("result", var_idx[])
373374
var_idx[] += 1
@@ -399,7 +400,7 @@ function create_result(
399400
end
400401

401402
function create_result(
402-
tocopy::ConcreteIFRTArray{T,N,S},
403+
tocopy::ConcreteIFRTArray{T,N},
403404
path,
404405
result_stores,
405406
path_to_shard_info,
@@ -410,7 +411,7 @@ function create_result(
410411
result_cache,
411412
var_idx,
412413
resultgen_code,
413-
) where {T,N,S}
414+
) where {T,N}
414415
if !haskey(result_cache, tocopy)
415416
sym = Symbol("result", var_idx[])
416417
var_idx[] += 1
@@ -1647,12 +1648,6 @@ function compile_mlir!(
16471648
raise isa Bool && (raise = true)
16481649
end
16491650

1650-
concrete_seen = OrderedIdDict()
1651-
1652-
concrete_result = make_tracer(
1653-
concrete_seen, traced_result, ("result",), TracedToConcrete; runtime
1654-
)
1655-
16561651
toolkit = XLA.CUDA_DATA_DIR[]
16571652

16581653
if backend == "cpu" || backend == "tpu"
@@ -2307,6 +2302,17 @@ function compile_mlir!(
23072302
]
23082303
end
23092304

2305+
if result_shardings !== missing
2306+
result_shardings_after_masking = eltype(result_shardings)[]
2307+
for (i, present) in enumerate(results_mask)
2308+
if present
2309+
push!(result_shardings_after_masking, result_shardings[i])
2310+
end
2311+
end
2312+
else
2313+
result_shardings_after_masking = missing
2314+
end
2315+
23102316
func3 = MLIR.Dialects.func.func_(;
23112317
sym_name="main",
23122318
function_type=MLIR.IR.FunctionType(in_tys, out_tys2),
@@ -2384,6 +2390,10 @@ function compile_mlir!(
23842390
end
23852391
end
23862392

2393+
concrete_result = make_tracer(
2394+
OrderedIdDict(), traced_result, ("result",), TracedToConcrete; runtime
2395+
)
2396+
23872397
return Reactant.TracedUtils.CompiledMlirFnResult(
23882398
fnwrapped,
23892399
func3,
@@ -2404,7 +2414,7 @@ function compile_mlir!(
24042414
mlir_fn_res.unique_meshes,
24052415
mlir_fn_res.mutated_args,
24062416
use_shardy_partitioner,
2407-
result_shardings,
2417+
result_shardings_after_masking,
24082418
mlir_fn_res.global_device_ids,
24092419
donated_args_mask,
24102420
Reactant.TracedUtils.is_pure(func3),
@@ -3151,7 +3161,6 @@ function codegen_unflatten!(
31513161
p in Reactant.TracedUtils.get_paths(arg) if length(p) > 0 && p[1] == :args
31523162
))
31533163

3154-
res = :result
31553164
path = path[2:end]
31563165

31573166
if in(path, keys(result_stores))
@@ -3246,10 +3255,15 @@ function codegen_unflatten!(
32463255
end
32473256

32483257
# generate return object which stores the concrete results in some arbitrary way
3249-
return Expr[
3250-
unresharded_code..., resultgen_code..., :(result = $result_code), unflatten_code...
3251-
],
3252-
used_shardinfo
3258+
return (
3259+
Expr[
3260+
unresharded_code...,
3261+
resultgen_code...,
3262+
:(result = $result_code),
3263+
unflatten_code...,
3264+
],
3265+
used_shardinfo,
3266+
)
32533267
end
32543268

32553269
"""

src/ConcreteRArray.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ function Base.copy(x::Union{AbstractConcreteArray,AbstractConcreteNumber})
2323
return fn(x)
2424
end
2525

26-
function Base.copy(X::ConcreteIFRTArray{T,D,S,P}) where {T,D,S,P}
27-
return ConcreteIFRTArray{T,D,S}(Base.copy(X.data), X.shape, X.sharding, X.padding)
26+
function Base.copy(X::ConcreteIFRTArray{T,D,P}) where {T,D,P}
27+
return ConcreteIFRTArray{T,D}(Base.copy(X.data), X.shape, X.sharding, X.padding)
2828
end
2929

3030
function Base.copy(X::ConcretePJRTArray)
@@ -422,8 +422,8 @@ function Base.similar(
422422
end
423423

424424
function Base.similar(
425-
a::ConcretePJRTArray{T,N,D,Sh}, ::Type{S}=T, dims::Dims=size(a)
426-
) where {S,T,Sh,N,D}
425+
a::ConcretePJRTArray{T,N,D}, ::Type{S}=T, dims::Dims=size(a)
426+
) where {S,T,N,D}
427427
device_to_array_slices, sharding = Sharding.sharding_to_array_slices(
428428
a.sharding, dims; return_updated_sharding=Val(true), client=XLA.client(a)
429429
)
@@ -432,7 +432,7 @@ function Base.similar(
432432
Base.@_inline_meta
433433
similar(a.data[i], S, Dims(length.(device_to_array_slices[i])))
434434
end
435-
return ConcretePJRTArray{S,length(dims),D,Sh}(sdata, dims, a.sharding)
435+
return ConcretePJRTArray{S,length(dims),D}(sdata, dims, a.sharding)
436436
end
437437

438438
Base.similar(a::ConcretePJRTArray, dims::Dims) = similar(a, eltype(a), dims)

src/TracedUtils.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ using ..Reactant:
1717
using ..Ops: @opcall
1818
using ReactantCore: ReactantCore
1919
using ReactantCore: MissingTracedValue, is_traced, materialize_traced_array
20-
using Functors: Functors
2120

2221
ReactantCore.materialize_traced_array(x::AbstractArray) = x
2322

0 commit comments

Comments
 (0)