Skip to content

Commit 5676f37

Browse files
committed
WIP FixedSizeArrays support
1 parent 6b9f46b commit 5676f37

File tree

4 files changed

+43
-0
lines changed

4 files changed

+43
-0
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ Random123 = "74087812-796a-5b5d-8853-05524746bad3"
4646
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
4747
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
4848
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
49+
FixedSizeArrays = "3821ddf9-e5b5-40d5-8e25-6813ab96b5e2"
4950

5051
[sources]
5152
ReactantCore = {path = "lib/ReactantCore"}
@@ -56,6 +57,7 @@ ReactantArrayInterfaceExt = "ArrayInterface"
5657
ReactantCUDAExt = ["CUDA", "GPUCompiler", "KernelAbstractions", "LLVM"]
5758
ReactantDLFP8TypesExt = "DLFP8Types"
5859
ReactantFillArraysExt = "FillArrays"
60+
ReactantFixedSizeArraysExt = "FixedSizeArrays"
5961
ReactantFloat8sExt = "Float8s"
6062
ReactantKernelAbstractionsExt = "KernelAbstractions"
6163
ReactantMPIExt = "MPI"
@@ -80,6 +82,7 @@ EnumX = "1"
8082
Enzyme = "0.13.72"
8183
EnzymeCore = "0.8.11"
8284
FillArrays = "1.13"
85+
FixedSizeArrays = "1.2.0"
8386
Float8s = "0.1"
8487
Functors = "0.5"
8588
GPUArraysCore = "0.2"

ext/ReactantFixedSizeArraysExt.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
module ReactantFixedSizeArraysExt
2+
3+
using FixedSizeArrays
4+
using Reactant
5+
using Reactant: TracedRArray, TracedRNumber, Ops
6+
using ReactantCore: ReactantCore
7+
8+
function Reactant.traced_type_inner(
9+
@nospecialize(_::Type{FixedSizeArrays.FixedSizeArray{T, N, Memory{I}}}),
10+
seen,
11+
@nospecialize(mode::Reactant.TraceMode),
12+
@nospecialize(track_numbers::Type),
13+
@nospecialize(sharding),
14+
@nospecialize(runtime)
15+
) where {T, N, I}
16+
T2 = Reactant.TracedRNumber{T}
17+
I2 = Reactant.TracedRNumber{I}
18+
return FixedSizeArrays.FixedSizeArray{T2, N, Memory{I2}}
19+
end
20+
21+
Base.@nospecializeinfer function Reactant.make_tracer(
22+
seen,
23+
@nospecialize(prev::FixedSizeArrays.FixedSizeArray{T, N, Memory{I}}),
24+
@nospecialize(path),
25+
mode; kwargs...
26+
) where {T, N, I}
27+
return FixedSizeArrays.FixedSizeArray(
28+
Reactant.make_tracer(
29+
seen, parent(prev), (path..., 1), mode; kwargs..., track_numbers=Number
30+
)
31+
)
32+
end
33+
34+
end

src/Tracing.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,6 +1143,10 @@ function make_tracer(
11431143
@nospecialize(runtime = nothing),
11441144
kwargs...,
11451145
)
1146+
@show prev
1147+
@show path
1148+
@show seen
1149+
@show typeof(prev)
11461150
return make_tracer_unknown(
11471151
seen, prev, path, mode; track_numbers, sharding, runtime, kwargs...
11481152
)

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
77
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
88
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
99
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
10+
FixedSizeArrays = "3821ddf9-e5b5-40d5-8e25-6813ab96b5e2"
1011
Float8s = "81dfefd7-55b0-40c6-a251-db853704e186"
1112
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1213
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
@@ -25,6 +26,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250"
2526
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
2627
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2728
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
29+
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
2830
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
2931
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
3032
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

0 commit comments

Comments
 (0)