@@ -184,30 +184,22 @@ julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficientl
184184```
185185"""
186186onehotbatch (data, labels, default... ) = _onehotbatch (data, length (labels) < 32 ? Tuple (labels) : labels, default... )
187- onehotbatch (data:: AbstractString , labels, default... ) = onehotbatch ([_ for _ in data], labels, default... )
188187
189188# NB function barrier:
190189function _onehotbatch (data, labels)
191- n_labels = length (labels)
192- indices = map (x -> _findval (x, labels), data)
190+ indices = [_findval (i, labels) for i in data]
193191 if nothing in indices
194192 unexpected_values = unique (data[indices .== nothing ])
195193 error (" Values $unexpected_values are not in labels" )
196194 end
197- return OneHotArray (indices, n_labels )
195+ return OneHotArray (indices, length (labels) )
198196end
199197
200198function _onehotbatch (data, labels, default)
201- n_labels = length (labels)
202- indices = map (x -> _findval (x, labels), data)
203- if nothing in indices
204- default_index = _findval (default, labels)
205- isnothing (default_index) && error (" Default value $default is not in labels" )
206- replaced_indices = replace (indices, nothing => default_index)
207- return OneHotArray (replaced_indices, n_labels)
208- else
209- return OneHotArray (indices, n_labels)
210- end
199+ default_index = _findval (default, labels)
200+ isnothing (default_index) && error (" Default value $default is not in labels" )
201+ indices = [isnothing (_findval (i, labels)) ? default_index : _findval (i, labels) for i in data]
202+ return OneHotArray (indices, length (labels))
211203end
212204
213205"""
0 commit comments