@@ -183,9 +183,24 @@ julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficientl
183183 3 6 15 3 9 3 12 3 6 15 3
184184```
185185"""
186- onehotbatch (ls, labels, default... ) = _onehotbatch (ls, length (labels) < 32 ? Tuple (labels) : labels, default... )
187- # NB function barier:
188- _onehotbatch (ls, labels, default... ) = batch ([onehot (l, labels, default... ) for l in ls])
186+ onehotbatch (data, labels, default... ) = _onehotbatch (data, length (labels) < 32 ? Tuple (labels) : labels, default... )
187+
188+ function _onehotbatch (data, labels)
189+ indices = UInt32[something (_findval (i, labels), 0 ) for i in data]
190+ if 0 in indices
191+ for x in data
192+ isnothing (_findval (x, labels)) && error (" Value $x not found in labels" )
193+ end
194+ end
195+ return OneHotArray (indices, length (labels))
196+ end
197+
198+ function _onehotbatch (data, labels, default)
199+ default_index = _findval (default, labels)
200+ isnothing (default_index) && error (" Default value $default is not in labels" )
201+ indices = UInt32[something (_findval (i, labels), default_index) for i in data]
202+ return OneHotArray (indices, length (labels))
203+ end
189204
190205"""
191206 onecold(y::AbstractArray, labels = 1:size(y,1))
0 commit comments