@@ -644,35 +644,45 @@ function Base.show(io::IO, m::PairwiseFusion)
644644end
645645
"""
    Embedding(in => out; init=randn32)

A lookup table that stores embeddings of dimension `out`
for a vocabulary of size `in`, as a trainable matrix.

This layer is often used to store word embeddings and retrieve them using indices.
The input to the layer can be a vocabulary index in `1:in`, an array of indices,
or the corresponding [`onehot encoding`](@ref OneHotArrays.onehotbatch).

For indices `x`, the result is of size `(out, size(x)...)`, allowing several batch dimensions.
For one-hot `ohx`, the result is of size `(out, size(ohx)[2:end]...)`.

# Examples
```jldoctest
julia> emb = Embedding(26 => 4, init=Flux.identity_init(gain=22))
Embedding(26 => 4)  # 104 parameters

julia> emb(2)  # one column of emb.weight (here not random!)
4-element Vector{Float32}:
  0.0
 22.0
  0.0
  0.0

julia> emb([3, 1, 20, 14, 4, 15, 7])  # vocabulary indices, in 1:26
4×7 Matrix{Float32}:
  0.0  22.0  0.0  0.0   0.0  0.0  0.0
  0.0   0.0  0.0  0.0   0.0  0.0  0.0
 22.0   0.0  0.0  0.0   0.0  0.0  0.0
  0.0   0.0  0.0  0.0  22.0  0.0  0.0

julia> ans == emb(Flux.onehotbatch("cat&dog", 'a':'z', 'n'))
true

julia> emb(rand(1:26, (10, 1, 12))) |> size  # three batch dimensions
(4, 10, 1, 12)
```
"""
struct Embedding{W<:AbstractMatrix}
    # Trainable lookup table: one embedding per column, so size is (out, in).
    weight::W
end
678688
@@ -684,10 +694,9 @@ Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(ini
# Look up embeddings by vocabulary index: each index selects a column of `weight`.
(m::Embedding)(indices::AbstractVector) = NNlib.gather(m.weight, indices)

# Higher-rank index arrays: flatten, look up, then restore the batch dimensions,
# giving a result of size (out, size(indices)...).
function (m::Embedding)(indices::AbstractArray)
    columns = m(vec(indices))
    return reshape(columns, :, size(indices)...)
end
686696
# Boolean input is treated as a one-hot (or multi-hot) encoding, for which the
# embedding lookup is simply a matrix product with the weight matrix.
(m::Embedding)(onehot::AbstractVector{Bool}) = m.weight * onehot   # usually OneHotVector
(m::Embedding)(onehots::AbstractMatrix{Bool}) = m.weight * onehots # usually OneHotMatrix
function (m::Embedding)(onehots::AbstractArray{Bool})
    # Collapse trailing batch dimensions to a matrix, multiply, then restore them:
    # result has size (out, size(onehots)[2:end]...).
    flat = reshape(onehots, size(onehots, 1), :)
    return reshape(m(flat), :, size(onehots)[2:end]...)
end
691700
692701function Base. show (io:: IO , m:: Embedding )
693702 print (io, " Embedding(" , size (m. weight, 2 ), " => " , size (m. weight, 1 ), " )" )
0 commit comments