Skip to content

Commit 2da6fe3

Browse files
feat: support onehotbatch/onecold (#1794)
* fix: try always downloading libtpu on ci * feat: version check for libtpu * feat: support onehotbatch/onecold * fix: onehotbatch * Apply suggestion from @github-actions[bot] Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix: use batched findfirst * Apply suggestion from @avik-pal --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent b5e3b5a commit 2da6fe3

File tree

2 files changed

+106
-8
lines changed

2 files changed

+106
-8
lines changed

ext/ReactantOneHotArraysExt.jl

Lines changed: 76 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
11
module ReactantOneHotArraysExt
22

3-
using OneHotArrays: OneHotArray
4-
using Reactant: Reactant, TracedRArray, TracedRNumber, Ops
3+
using GPUArraysCore: @allowscalar
4+
using OneHotArrays: OneHotArrays, OneHotArray
5+
using Reactant: Reactant, AnyTracedRArray, TracedRArray, TracedRNumber
56
using ReactantCore: ReactantCore
67
using Reactant.Ops: @opcall
78

9+
# `__compatible_eltype(T, U)` selects the element type for a converted `OneHotArray` from
# the array's element type `T` and a second type `U`. The visible caller in
# `traced_type_inner` passes the element type of the converted index storage as `U`.
# Generic fallback: used when no more specific method below applies; returns `T` as is.
__compatible_eltype(::Type{T}, ::Type{U}) where {T,U} = T
10+
# Both arguments traced: keep the `TracedRNumber` wrapper around the first argument's
# inner type. This method is also more specific than the two mixed methods below.
function __compatible_eltype(::Type{TracedRNumber{T}}, ::Type{TracedRNumber{U}}) where {T,U}
11+
return TracedRNumber{T}
12+
end
13+
# Traced element type with an untraced `U`: unwrap to the plain inner type `T`.
__compatible_eltype(::Type{TracedRNumber{T}}, ::Type{U}) where {T,U} = T
14+
# Untraced element type with a traced `U`: wrap `T` in `TracedRNumber`.
__compatible_eltype(::Type{T}, ::Type{TracedRNumber{U}}) where {T,U} = TracedRNumber{T}
15+
816
function Reactant.traced_type_inner(
917
@nospecialize(_::Type{OneHotArray{T,N,Np1,I}}),
1018
seen,
@@ -14,12 +22,7 @@ function Reactant.traced_type_inner(
1422
@nospecialize(runtime)
1523
) where {T,N,Np1,I}
1624
I2 = Reactant.traced_type_inner(I, seen, mode, track_numbers, sharding, runtime)
17-
T2 = if eltype(I2) <: Reactant.TracedRNumber && !(T <: Reactant.TracedRNumber)
18-
Reactant.TracedRNumber{T}
19-
else
20-
T
21-
end
22-
return OneHotArray{T2,N,Np1,I2}
25+
return OneHotArray{__compatible_eltype(T, eltype(I2)),N,Np1,I2}
2326
end
2427

2528
function ReactantCore.materialize_traced_array(r::OneHotArray)
@@ -45,4 +48,69 @@ function Base.Array(
4548
return Array(reshape(Array(r.indices), 1, size(r.indices)...) .== 1:(r.nlabels))
4649
end
4750

51+
# Traced `onehotbatch` for arbitrary `labels`: every element of `data` is compared with
# all labels, and the position found along the label axis is stored as a `UInt32` index.
# Returns a `OneHotArray` with `length(labels)` labels backed by that index array.
function OneHotArrays.onehotbatch(data::AnyTracedRArray{<:Any,N}, labels) where {N}
52+
# TODO: add checkbounds once we support that with TracedRNumber
53+
# Flatten `labels`, promote it to a 1-D traced array, and broadcast it along dimension 1
# to the shape `(length(labels), size(data)...)`.
labels_expanded = @opcall broadcast_in_dim(
54+
Reactant.promote_to(
55+
TracedRArray{Reactant.unwrapped_eltype(labels),1},
56+
ReactantCore.materialize_traced_array(vec(labels)),
57+
),
58+
Int64[1],
59+
[length(labels), size(data)...],
60+
)
61+
# Prepend a singleton dimension so `data` broadcasts against `labels_expanded`.
data = ReactantCore.materialize_traced_array(reshape(data, 1, size(data)...))
62+
# `findfirst` along dimension 1 of the equality mask, converted to `UInt32`.
# NOTE(review): the index produced for a value that is absent from `labels` depends on
# the `findfirst` op and is not validated here (see the TODO above) — confirm.
indices = UInt32.(@opcall(findfirst(data .== labels_expanded; dimension=1)))
63+
# The index array has `N` dimensions; the one-hot array type uses `N + 1`.
return OneHotArray{TracedRNumber{UInt32},N,N + 1,typeof(indices)}(
64+
indices, length(labels)
65+
)
66+
end
67+
68+
# Traced `onehotbatch` specialised for integer data with unit-range labels: the index of
# a value `x` is computed arithmetically as `x + 1 - first(labels)` and converted to
# `TracedRNumber{UInt32}`; no comparison against the labels is performed.
# Returns a `OneHotArray` with `length(labels)` labels backed by that index array.
function OneHotArrays.onehotbatch(
69+
data::AnyTracedRArray{<:Integer,N}, labels::AbstractUnitRange{<:Integer}
70+
) where {N}
71+
# TODO: add checkbounds once we support that with TracedRNumber
72+
# Compose the `UInt32` conversion with the offset shift. The composition operator `∘`
# is required here: without it the two callables are juxtaposed, which does not parse.
indices = map(
73+
TracedRNumber{UInt32} ∘ Base.Fix2(+, 1 - first(labels)),
74+
ReactantCore.materialize_traced_array(data),
75+
)
76+
# The index array has `N` dimensions; the one-hot array type uses `N + 1`.
return OneHotArray{TracedRNumber{UInt32},N,N + 1,typeof(indices)}(
77+
indices, length(labels)
78+
)
79+
end
80+
81+
# Traced `onecold` for a vector: returns the entry of `labels` located at `argmax(y)`.
# `labels` defaults to `1:length(y)`.
# Throws `DimensionMismatch` when `length(labels)` differs from `length(y)`.
function OneHotArrays.onecold(y::AnyTracedRArray{T,1}, labels=1:length(y)) where {T}
82+
nl = length(labels)
83+
ny = length(y)
84+
# Lengths are checked eagerly, before any traced computation is built.
nl == ny || throw(
85+
DimensionMismatch(
86+
"onecold got $nl labels for a vector of length $ny, these must agree"
87+
),
88+
)
89+
imax = argmax(y)
90+
# TODO: error if ymax is nan
91+
# Promote `labels` to a 1-D traced array so it can be indexed with `imax`
# (presumably a traced index, since `y` is traced — confirm).
labels_arr = Reactant.promote_to(
92+
TracedRArray{Reactant.unwrapped_eltype(labels),1}, labels
93+
)
94+
# Scalar indexing is explicitly permitted for this single lookup.
return @allowscalar labels_arr[imax]
95+
end
96+
97+
# Traced `onecold` for arrays: selects, for each slice along the first dimension, the
# entry of `labels` at the position given by `argmax(y; dims=1)`.
# `labels` defaults to `1:size(y, 1)`.
# Throws `DimensionMismatch` when `length(labels)` differs from `size(y, 1)`.
function OneHotArrays.onecold(y::AnyTracedRArray{T}, labels=1:size(y, 1)) where {T}
98+
nl = length(labels)
99+
ny = size(y, 1)
100+
# The first-dimension size is checked eagerly, before any traced computation is built.
nl == ny || throw(
101+
DimensionMismatch(
102+
"onecold got $nl labels for an array with first dimension of size $ny, these must agree",
103+
),
104+
)
105+
labels_arr = Reactant.promote_to(
106+
TracedRArray{Reactant.unwrapped_eltype(labels),1}, labels
107+
)
108+
# Broadcast the labels along dimension 1 to `(nl, size(y)[2:end]...)`, which equals
# `size(y)` because `nl == size(y, 1)` was verified above.
labels_expanded = @opcall broadcast_in_dim(
109+
labels_arr, Int64[1], Int64[nl, size(y)[2:end]...]
110+
)
111+
# NOTE(review): the selected labels are flattened with `vec`, so for inputs with more
# than two dimensions the trailing shape of `y` is not kept — confirm this is intended.
return ReactantCore.materialize_traced_array(
112+
vec(getindex(labels_expanded, argmax(y; dims=1)))
113+
)
114+
end
115+
48116
end

test/integration/onehotarrays.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,33 @@ end
3131
@test res_ra ≈ res
3232
end
3333
end
34+
35+
# Compares the Reactant-compiled (`@jit`) `onehotbatch`/`onecold` results with the plain
# OneHotArrays results computed on the equivalent host arrays.
@testset "onehotbatch/onecold" begin
36+
# Case 1: stepped (non-unit) range of labels.
x = Int32[10, 20, 30, 10, 10]
37+
x_ra = Reactant.to_rarray(x)
38+
labels = Int32(10):Int32(10):Int32(40)
39+
res_ra = @jit onehotbatch(x_ra, labels)
40+
res = onehotbatch(x, labels)
41+
@test Array(res_ra) ≈ res
42+
43+
# Case 2: multi-dimensional data with labels given as a matrix.
x = rand(10:10:40, 2, 3, 5)
44+
x_ra = Reactant.to_rarray(x)
45+
labels = reshape([10, 20, 30, 40], 2, 2)
46+
res = onehotbatch(x, labels)
47+
res_ra = @jit onehotbatch(x_ra, labels)
48+
@test Array(res_ra) ≈ res
49+
50+
# Case 3: unit-range labels.
x = Int32[1, 2, 3, 1, 1]
51+
x_ra = Reactant.to_rarray(x)
52+
labels = Int32(1):Int32(4)
53+
res_ra = @jit onehotbatch(x_ra, labels)
54+
res = onehotbatch(x, labels)
55+
@test Array(res_ra) ≈ res
56+
57+
# `onecold` on a vector: the maximum 0.5 is at position 3.
vec_ra = Reactant.to_rarray(Float32[0.3, 0.2, 0.5])
58+
@test @jit(onecold(vec_ra)) == 3
59+
60+
# `onecold` on a dense array built from the last one-hot result.
dense_ra = Reactant.to_rarray(Array(res))
61+
oc_res = onecold(res)
62+
@test @jit(onecold(dense_ra)) == oc_res
63+
end

0 commit comments

Comments
 (0)