Skip to content

Commit e0906bc

Browse files
committed
WIP FixedSizeArrays support
1 parent 670b480 commit e0906bc

File tree

4 files changed

+43
-0
lines changed

4 files changed

+43
-0
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
4747
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
4848
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
4949
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
50+
FixedSizeArrays = "3821ddf9-e5b5-40d5-8e25-6813ab96b5e2"
5051

5152
[sources]
5253
ReactantCore = {path = "lib/ReactantCore"}
@@ -57,6 +58,7 @@ ReactantArrayInterfaceExt = "ArrayInterface"
5758
ReactantCUDAExt = ["CUDA", "GPUCompiler", "KernelAbstractions", "LLVM"]
5859
ReactantDLFP8TypesExt = "DLFP8Types"
5960
ReactantFillArraysExt = "FillArrays"
61+
ReactantFixedSizeArraysExt = "FixedSizeArrays"
6062
ReactantFloat8sExt = "Float8s"
6163
ReactantKernelAbstractionsExt = "KernelAbstractions"
6264
ReactantMPIExt = "MPI"
@@ -82,6 +84,7 @@ EnumX = "1"
8284
Enzyme = "0.13.74"
8385
EnzymeCore = "0.8.13"
8486
FillArrays = "1.13"
87+
FixedSizeArrays = "1.2.0"
8588
Float8s = "0.1"
8689
Functors = "0.5"
8790
GPUArraysCore = "0.2"

ext/ReactantFixedSizeArraysExt.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
module ReactantFixedSizeArraysExt
2+
3+
using FixedSizeArrays
4+
using Reactant
5+
using Reactant: TracedRArray, TracedRNumber, Ops
6+
using ReactantCore: ReactantCore
7+
8+
function Reactant.traced_type_inner(
9+
@nospecialize(_::Type{FixedSizeArrays.FixedSizeArray{T, N, Memory{I}}}),
10+
seen,
11+
@nospecialize(mode::Reactant.TraceMode),
12+
@nospecialize(track_numbers::Type),
13+
@nospecialize(sharding),
14+
@nospecialize(runtime)
15+
) where {T, N, I}
16+
T2 = Reactant.TracedRNumber{T}
17+
I2 = Reactant.TracedRNumber{I}
18+
return FixedSizeArrays.FixedSizeArray{T2, N, Memory{I2}}
19+
end
20+
21+
Base.@nospecializeinfer function Reactant.make_tracer(
22+
seen,
23+
@nospecialize(prev::FixedSizeArrays.FixedSizeArray{T, N, Memory{I}}),
24+
@nospecialize(path),
25+
mode; kwargs...
26+
) where {T, N, I}
27+
return FixedSizeArrays.FixedSizeArray(
28+
Reactant.make_tracer(
29+
seen, parent(prev), (path..., 1), mode; kwargs..., track_numbers=Number
30+
)
31+
)
32+
end
33+
34+
end

src/Tracing.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1148,6 +1148,10 @@ function make_tracer(
11481148
@nospecialize(runtime = nothing),
11491149
kwargs...,
11501150
)
1151+
@show prev
1152+
@show path
1153+
@show seen
1154+
@show typeof(prev)
11511155
return make_tracer_unknown(
11521156
seen, prev, path, mode; track_numbers, sharding, runtime, kwargs...
11531157
)

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
99
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
1010
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
1111
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
12+
FixedSizeArrays = "3821ddf9-e5b5-40d5-8e25-6813ab96b5e2"
1213
Float8s = "81dfefd7-55b0-40c6-a251-db853704e186"
1314
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1415
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
@@ -28,6 +29,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250"
2829
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
2930
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3031
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
32+
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
3133
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
3234
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
3335
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

0 commit comments

Comments
 (0)