Skip to content

Commit eb15e85

Browse files
committed
Removed string method; changed maps to a comprehensions
1 parent d041014 commit eb15e85

File tree

1 file changed

+6
-14
lines changed

1 file changed

+6
-14
lines changed

src/onehot.jl

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -184,30 +184,22 @@ julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficientl
184184
```
185185
"""
186186
onehotbatch(data, labels, default...) = _onehotbatch(data, length(labels) < 32 ? Tuple(labels) : labels, default...)
187-
onehotbatch(data::AbstractString, labels, default...) = onehotbatch([_ for _ in data], labels, default...)
188187

189188
# NB function barrier:
190189
function _onehotbatch(data, labels)
191-
n_labels = length(labels)
192-
indices = map(x -> _findval(x, labels), data)
190+
indices = [_findval(i, labels) for i in data]
193191
if nothing in indices
194192
unexpected_values = unique(data[indices .== nothing])
195193
error("Values $unexpected_values are not in labels")
196194
end
197-
return OneHotArray(indices, n_labels)
195+
return OneHotArray(indices, length(labels))
198196
end
199197

200198
function _onehotbatch(data, labels, default)
201-
n_labels = length(labels)
202-
indices = map(x -> _findval(x, labels), data)
203-
if nothing in indices
204-
default_index = _findval(default, labels)
205-
isnothing(default_index) && error("Default value $default is not in labels")
206-
replaced_indices = replace(indices, nothing => default_index)
207-
return OneHotArray(replaced_indices, n_labels)
208-
else
209-
return OneHotArray(indices, n_labels)
210-
end
199+
default_index = _findval(default, labels)
200+
isnothing(default_index) && error("Default value $default is not in labels")
201+
indices = [isnothing(_findval(i, labels)) ? default_index : _findval(i, labels) for i in data]
202+
return OneHotArray(indices, length(labels))
211203
end
212204

213205
"""

0 commit comments

Comments
 (0)