@@ -8,6 +8,7 @@ on a given input.
 `m[1:3](x)` will calculate the output of the first three layers.
 
 # Examples
+
 ```jldoctest
 julia> m = Chain(x -> x^2, x -> x+1);
 
@@ -428,3 +429,55 @@ function Base.show(io::IO, m::Parallel)
   join(io, m.layers, ", ")
   print(io, ")")
 end
+
+"""
+    Embedding(in, out; init=randn)
+
+A lookup table that stores embeddings of dimension `out`
+for a vocabulary of size `in`.
+
+This layer is often used to store word embeddings and retrieve them using indices.
+The input to the layer can be either a vector of indices
+or the corresponding [one-hot encoding](@ref Flux.OneHotArray).
+
+# Examples
+
+```julia-repl
+julia> vocab_size, embed_size = 1000, 4;
+
+julia> model = Embedding(vocab_size, embed_size)
+Embedding(1000, 4)
+
+julia> vocab_idxs = [1, 722, 53, 220, 3];
+
+julia> x = OneHotMatrix(vocab_idxs, vocab_size);
+
+julia> model(x)
+4×5 Matrix{Float32}:
+  0.91139    0.670462    0.463217   0.670462    0.110932
+  0.247225  -0.0823874   0.698694  -0.0823874   0.945958
+ -0.393626  -0.590136   -0.545422  -0.590136    0.77743
+ -0.497621   0.87595    -0.870251   0.87595    -0.772696
+
+julia> model(vocab_idxs) == model(x)
+true
+```
+"""
+struct Embedding{W}
+  weight::W
+end
+
+@functor Embedding
+
+function Embedding(in::Integer, out::Integer;
+                   init = (i...) -> randn(Float32, i...))
+  return Embedding(init(out, in))
+end
+
+(m::Embedding)(x::Union{OneHotVector, OneHotMatrix}) = m.weight * x  # equivalent to m.weight[:, onecold(x)]
+(m::Embedding)(x::Union{Int, AbstractVector}) = m.weight[:, x]
+(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
+
+function Base.show(io::IO, m::Embedding)
+  print(io, "Embedding($(size(m.weight, 2)), $(size(m.weight, 1)))")
+end
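
For context on how the new layer composes with the rest of Flux, here is a minimal usage sketch (not part of this patch): the vocabulary and embedding sizes, the mean-pooling step, and the `Dense` classification head are illustrative assumptions, not anything the diff prescribes.

```julia
using Flux

# Hypothetical sizes, chosen only for illustration.
vocab_size, embed_size = 26, 8

model = Chain(
  Embedding(vocab_size, embed_size),  # indices -> (embed_size, seq_len) matrix
  x -> sum(x, dims=2) / size(x, 2),   # mean-pool across the sequence dimension
  Dense(embed_size, 2),               # hypothetical two-class head
)

tokens = rand(1:vocab_size, 10)  # a sequence of 10 token indices
model(tokens)                    # 2×1 output
```

Integer-vector input like `tokens` hits the `m.weight[:, x]` method above, so the common lookup path never materializes a one-hot matrix.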