diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 8aaf4e7df0..e76efea60a 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -172,31 +172,40 @@ function Base.show(io::IO, l::Dense)
 end
 
 """
-    Diagonal(α, β)
-    Diagonal(size::Integer...)
+    Diagonal(size::Integer...; bias=true, init=ones32)
+    Diagonal(scale::AbstractArray, [bias])
 
 Create an element-wise linear layer, which performs
 
-    y = α .* x .+ β
+    y = scale .* x .+ bias
 
-The learnable arrays are initialised `α = ones(Float32, size)` and
-`β = zeros(Float32, size)`.
+with no activation function.
+
+The learnable scale & bias are initialised `init(size...)` and `zeros32(size...)`,
+with `init=ones32` by default. You may specify the function `init`,
+turn off trainable bias with `bias=false`, or provide the array(s) explicitly.
 
 Used by [`LayerNorm`](@ref).
 """
-struct Diagonal{T}
-  α::T
-  β::T
+struct Diagonal{A<:AbstractArray, B}
+  scale::A
+  bias::B
+  function Diagonal(W::M, bias = true) where M<:AbstractArray
+    b = create_bias(W, bias, size(W)...)
+    new{M, typeof(b)}(W, b)
+  end
 end
 
-Diagonal(sz::Integer...) = Diagonal(ones32(sz...), zeros32(sz...))
+Diagonal(sz::Integer...; bias = true, init = ones32) = Diagonal(init(sz...), bias)
 
 @functor Diagonal
 
-(a::Diagonal)(x) = a.α .* x .+ a.β
+(a::Diagonal)(x) = a.scale .* x .+ a.bias
 
 function Base.show(io::IO, l::Diagonal)
-  print(io, "Diagonal(", join(size(l.α), ", "), ")")
+  print(io, "Diagonal(", join(size(l.scale), ", "))
+  l.bias == false && print(io, "; bias=false")
+  print(io, ")")
 end
 
 """
diff --git a/test/layers/basic.jl b/test/layers/basic.jl
index 5befed57b4..0c12b22d11 100644
--- a/test/layers/basic.jl
+++ b/test/layers/basic.jl
@@ -91,16 +91,17 @@ import Flux: activations
     @test length(Flux.Diagonal(10)(randn(10))) == 10
     @test length(Flux.Diagonal(10)(1)) == 10
     @test length(Flux.Diagonal(10)(randn(1))) == 10
+    @test length(Flux.Diagonal(10; bias = false)(randn(10))) == 10
     @test_throws DimensionMismatch Flux.Diagonal(10)(randn(2))
 
     @test Flux.Diagonal(2)([1 2]) == [1 2; 1 2]
     @test Flux.Diagonal(2)([1,2]) == [1,2]
-    @test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4]
+    @test Flux.Diagonal(2; bias = false)([1 2; 3 4]) == [1 2; 3 4]
 
     @test Flux.Diagonal(2)(rand(2,3,4)) |> size == (2, 3, 4)
     @test Flux.Diagonal(2,3)(rand(2,3,4)) |> size == (2, 3, 4)
-    @test Flux.Diagonal(2,3,4)(rand(2,3,4)) |> size == (2, 3, 4)
-    @test Flux.Diagonal(2,3)(rand(2,1,4)) |> size == (2, 3, 4)
+    @test Flux.Diagonal(2, 3, 4; bias = false)(rand(2,3,4)) |> size == (2, 3, 4)
+    @test Flux.Diagonal(2, 3; bias = false)(rand(2,1,4)) |> size == (2, 3, 4)
   end
 
   @testset "Maxout" begin