@@ -114,7 +114,7 @@ and [`onecold`](@ref) to reverse either of these, as well as to generalise `argm
114114
115115# Examples
116116```jldoctest
117- julia> β = Flux.onehot(:b, [ :a, :b, :c] )
117+ julia> β = Flux.onehot(:b, ( :a, :b, :c) )
1181183-element OneHotVector(::UInt32) with eltype Bool:
119119 ⋅
120120 1
@@ -131,17 +131,24 @@ julia> hcat(αβγ...) # preserves sparsity
131131```
132132"""
function onehot(x, labels)
    # Position of `x` within `labels`; `nothing` signals that it is absent.
    idx = _findval(x, labels)
    if isnothing(idx)
        error(" Value $x is not in labels")
    end
    # The number of labels is stored as a type parameter of the result.
    return OneHotVector{UInt32, length(labels)}(idx)
end
138138
function onehot(x, labels, default)
    idx = _findval(x, labels)
    if isnothing(idx)
        # `x` is not a known label: encode `default` instead. The two-argument method
        # raises an error if `default` is not among `labels` either.
        return onehot(default, labels)
    else
        return OneHotVector{UInt32, length(labels)}(idx)
    end
end
144144
# Locate `val` among `labels`: index of the first label `isequal` to it, else `nothing`.
# Generic fallback, delegating to `findfirst`.
function _findval(val, labels)
    return findfirst(isequal(val), labels)
end

# Fast unrolled method for tuples: each call compares `val` with the head of the tuple
# and recurses on `Base.tail`, carrying the position `i` of the current head.
# `ifelse` evaluates both of its branches, so the recursion always runs down to the
# empty tuple; a match at an earlier position takes precedence over any later one.
function _findval(val, labels::Tuple, i::Integer=1)
    hit = isequal(val, first(labels))
    rest = _findval(val, Base.tail(labels), i + 1)
    return ifelse(hit, i, rest)
end

# Recursion terminator: no labels left to compare, so `val` was not found.
function _findval(val, labels::Tuple{}, i::Integer)
    return nothing
end
151+
145152"""
146153 onehotbatch(xs, labels, [default])
147154
@@ -156,9 +163,12 @@ If `xs` has more dimensions, `M = ndims(xs) > 1`, then the result is an
156163`AbstractArray{Bool, M+1}` which is one-hot along the first dimension,
157164i.e. `result[:, k...] == onehot(xs[k...], labels)`.
158165
166+ Note that `xs` can be any iterable, such as a string. Using a tuple
167+ for `labels` will often speed up construction, certainly for fewer than 32 classes.
168+
159169# Examples
160170```jldoctest
161- julia> oh = Flux.onehotbatch(collect( "abracadabra") , 'a':'e', 'e')
171+ julia> oh = Flux.onehotbatch("abracadabra", 'a':'e', 'e')
1621725×11 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
163173 1 ⋅ ⋅ 1 ⋅ 1 ⋅ 1 ⋅ ⋅ 1
164174 ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅
@@ -173,7 +183,9 @@ julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficientl
173183 3 6 15 3 9 3 12 3 6 15 3
174184```
175185"""
function onehotbatch(ls, labels, default...)
    # Label sets shorter than 32 are converted to a tuple; longer ones are passed as-is.
    # The type of `chosen` therefore depends on a runtime length.
    chosen = length(labels) < 32 ? Tuple(labels) : labels
    return _onehotbatch(ls, chosen, default...)
end

# NB function barrier: the loop over `ls` lives in its own method, so it is compiled
# for whichever label container `onehotbatch` selected.
function _onehotbatch(ls, labels, default...)
    return batch([onehot(l, labels, default...) for l in ls])
end
177189
178190"""
179191 onecold(y::AbstractArray, labels = 1:size(y,1))
@@ -190,7 +202,7 @@ the same operation as `argmax(y, dims=1)` but sometimes a different return type.
190202julia> Flux.onecold([false, true, false])
1912032
192204
193- julia> Flux.onecold([0.3, 0.2, 0.5], [ :a, :b, :c] )
205+ julia> Flux.onecold([0.3, 0.2, 0.5], ( :a, :b, :c) )
194206:c
195207
196208julia> Flux.onecold([ 1 0 0 1 0 1 0 1 0 0 1
0 commit comments