@@ -644,14 +644,17 @@ function Base.show(io::IO, m::PairwiseFusion)
 end
 
 """
-    Embedding(in => out; init=randn)
+    Embedding(in => out; init=randn32)
 
 A lookup table that stores embeddings of dimension `out`
-for a vocabulary of size `in`.
+for a vocabulary of size `in`, as a trainable matrix.
 
 This layer is often used to store word embeddings and retrieve them using indices.
-The input to the layer can be either a vector of indexes
-or the corresponding [`onehot encoding`](@ref OneHotArrays.onehotbatch).
+The input to the layer can be a vocabulary index in `1:in`, an array of indices,
+or the corresponding [`onehot encoding`](@ref OneHotArrays.onehotbatch).
+
+For indices `x`, the result is of size `(out, size(x)...)`, allowing several batch dimensions.
+For one-hot `x`, the result is of size `(out, size(x)[2:end]...)`.
 
 # Examples
 ```jldoctest
@@ -684,10 +687,9 @@ Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(ini
 (m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
 (m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
 
-function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T,L}
-  size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))
-  return m(onecold(x))
-end
+(m::Embedding)(x::AbstractVector{Bool}) = m.weight * x  # usually OneHotVector
+(m::Embedding)(x::AbstractMatrix{Bool}) = m.weight * x  # usually OneHotMatrix
+(m::Embedding)(x::AbstractArray{Bool}) = reshape(m(reshape(x, size(x,1), :)), :, size(x)[2:end]...)
 
 function Base.show(io::IO, m::Embedding)
   print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
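
As a rough usage sketch of the behaviour documented above (not part of the diff; it assumes this version of Flux, which re-exports `onehotbatch` from OneHotArrays):

```julia
using Flux  # assumes `onehotbatch` is available via Flux's re-export of OneHotArrays

emb = Embedding(10 => 4)      # vocabulary of size 10, embeddings of dimension 4

x = [3, 7, 7]                 # a vector of vocabulary indices in 1:10
size(emb(x))                  # (4, 3) == (out, size(x)...)

xs = rand(1:10, 3, 5)         # indices with an extra batch dimension
size(emb(xs))                 # (4, 3, 5)

oh = onehotbatch(x, 1:10)     # 10×3 one-hot matrix for the same indices
size(emb(oh))                 # (4, 3) == (out, size(oh)[2:end]...)
emb(oh) == emb(x)             # true here: both dispatches select the same weight columns
```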