11module ReactantOneHotArraysExt
22
3- using OneHotArrays: OneHotArray
4- using Reactant: Reactant, TracedRArray, TracedRNumber, Ops
3+ using GPUArraysCore: @allowscalar
4+ using OneHotArrays: OneHotArrays, OneHotArray
5+ using Reactant: Reactant, AnyTracedRArray, TracedRArray, TracedRNumber
56using ReactantCore: ReactantCore
67using Reactant. Ops: @opcall
78
9+ __compatible_eltype (:: Type{T} , :: Type{U} ) where {T,U} = T
10+ function __compatible_eltype (:: Type{TracedRNumber{T}} , :: Type{TracedRNumber{U}} ) where {T,U}
11+ return TracedRNumber{T}
12+ end
13+ __compatible_eltype (:: Type{TracedRNumber{T}} , :: Type{U} ) where {T,U} = T
14+ __compatible_eltype (:: Type{T} , :: Type{TracedRNumber{U}} ) where {T,U} = TracedRNumber{T}
15+
816function Reactant. traced_type_inner (
917 @nospecialize (_:: Type{OneHotArray{T,N,Np1,I}} ),
1018 seen,
@@ -14,12 +22,7 @@ function Reactant.traced_type_inner(
1422 @nospecialize (runtime)
1523) where {T,N,Np1,I}
1624 I2 = Reactant. traced_type_inner (I, seen, mode, track_numbers, sharding, runtime)
17- T2 = if eltype (I2) <: Reactant.TracedRNumber && ! (T <: Reactant.TracedRNumber )
18- Reactant. TracedRNumber{T}
19- else
20- T
21- end
22- return OneHotArray{T2,N,Np1,I2}
25+ return OneHotArray{__compatible_eltype (T, eltype (I2)),N,Np1,I2}
2326end
2427
2528function ReactantCore. materialize_traced_array (r:: OneHotArray )
@@ -45,4 +48,69 @@ function Base.Array(
4548 return Array (reshape (Array (r. indices), 1 , size (r. indices)... ) .== 1 : (r. nlabels))
4649end
4750
51+ function OneHotArrays. onehotbatch (data:: AnyTracedRArray{<:Any,N} , labels) where {N}
52+ # TODO : add checkbounds once we support that with TracedRNumber
53+ labels_expanded = @opcall broadcast_in_dim (
54+ Reactant. promote_to (
55+ TracedRArray{Reactant. unwrapped_eltype (labels),1 },
56+ ReactantCore. materialize_traced_array (vec (labels)),
57+ ),
58+ Int64[1 ],
59+ [length (labels), size (data)... ],
60+ )
61+ data = ReactantCore. materialize_traced_array (reshape (data, 1 , size (data)... ))
62+ indices = UInt32 .(@opcall (findfirst (data .== labels_expanded; dimension= 1 )))
63+ return OneHotArray {TracedRNumber{UInt32},N,N + 1,typeof(indices)} (
64+ indices, length (labels)
65+ )
66+ end
67+
68+ function OneHotArrays. onehotbatch (
69+ data:: AnyTracedRArray{<:Integer,N} , labels:: AbstractUnitRange{<:Integer}
70+ ) where {N}
71+ # TODO : add checkbounds once we support that with TracedRNumber
72+ indices = map (
73+ TracedRNumber{UInt32} ∘ Base. Fix2 (+ , 1 - first (labels)),
74+ ReactantCore. materialize_traced_array (data),
75+ )
76+ return OneHotArray {TracedRNumber{UInt32},N,N + 1,typeof(indices)} (
77+ indices, length (labels)
78+ )
79+ end
80+
81+ function OneHotArrays. onecold (y:: AnyTracedRArray{T,1} , labels= 1 : length (y)) where {T}
82+ nl = length (labels)
83+ ny = length (y)
84+ nl == ny || throw (
85+ DimensionMismatch (
86+ " onecold got $nl labels for a vector of length $ny , these must agree"
87+ ),
88+ )
89+ imax = argmax (y)
90+ # TODO : error if ymax is nan
91+ labels_arr = Reactant. promote_to (
92+ TracedRArray{Reactant. unwrapped_eltype (labels),1 }, labels
93+ )
94+ return @allowscalar labels_arr[imax]
95+ end
96+
97+ function OneHotArrays. onecold (y:: AnyTracedRArray{T} , labels= 1 : size (y, 1 )) where {T}
98+ nl = length (labels)
99+ ny = size (y, 1 )
100+ nl == ny || throw (
101+ DimensionMismatch (
102+ " onecold got $nl labels for an array with first dimension of size $ny , these must agree" ,
103+ ),
104+ )
105+ labels_arr = Reactant. promote_to (
106+ TracedRArray{Reactant. unwrapped_eltype (labels),1 }, labels
107+ )
108+ labels_expanded = @opcall broadcast_in_dim (
109+ labels_arr, Int64[1 ], Int64[nl, size (y)[2 : end ]. .. ]
110+ )
111+ return ReactantCore. materialize_traced_array (
112+ vec (getindex (labels_expanded, argmax (y; dims= 1 )))
113+ )
114+ end
115+
48116end
0 commit comments