Skip to content

Commit 55572ab

Browse files
committed
using something instead of calling findval twice; fixing case where findval would return nothing
1 parent 5e5dc7d commit 55572ab

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

src/onehot.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -185,11 +185,10 @@ julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficientl
185185
"""
186186
onehotbatch(data, labels, default...) = _onehotbatch(data, length(labels) < 32 ? Tuple(labels) : labels, default...)
187187

188-
# NB function barrier:
189188
function _onehotbatch(data, labels)
190-
indices = UInt32[_findval(i, labels) for i in data]
191-
if nothing in indices
192-
unexpected_values = unique(data[indices .== nothing])
189+
indices = UInt32[something(_findval(i, labels), 0) for i in data]
190+
if 0 in indices
191+
unexpected_values = unique(data[indices .== 0])
193192
error("Values $unexpected_values are not in labels")
194193
end
195194
return OneHotArray(indices, length(labels))
@@ -198,7 +197,7 @@ end
198197
function _onehotbatch(data, labels, default)
199198
default_index = _findval(default, labels)
200199
isnothing(default_index) && error("Default value $default is not in labels")
201-
indices = UInt32[isnothing(_findval(i, labels)) ? default_index : _findval(i, labels) for i in data]
200+
indices = UInt32[something(_findval(i, labels), default_index) for i in data]
202201
return OneHotArray(indices, length(labels))
203202
end
204203

test/onehot.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ using Test
2424
@test_throws Exception onehotbatch([:a, :d], (:a, :b, :c))
2525
@test_throws Exception onehotbatch([:a, :d], [:a, :b, :c], :e)
2626
@test_throws Exception onehotbatch([:a, :d], (:a, :b, :c), :e)
27-
@test_throws Exception onehotbatch([:a, :e], (:a, :b, :c), :d)
27+
@test_throws Exception onehotbatch([:a, :b], (:a, :b, :c), :d)
2828

2929
floats = (0.0, -0.0, NaN, -NaN, Inf, -Inf)
3030
@test onecold(onehot(0.0, floats)) == 1

0 commit comments

Comments
 (0)