@@ -171,40 +171,69 @@ function Base.show(io::IO, l::Dense)
171171end
172172
"""
    Scale(size::Integer..., σ=identity; bias=true, init=ones32)
    Scale(scale::AbstractArray, [bias, σ])

Construct an element-wise layer. Its forward pass computes

    y = σ.(scale .* x .+ bias)

that is, it broadcasts `.*` where [`Dense`](@ref) would use matrix multiplication `*`.

The learnable `scale` and `bias` arrays are initialised as `init(size...)` and
`zeros32(size...)` respectively, with `init=ones32` unless another function is given.
Pass `bias=false` to disable the trainable bias, or supply the array(s) directly
using the second form.

Used by [`LayerNorm`](@ref) with `affine=true`.

# Examples
```jldoctest
julia> a = Flux.Scale(2)
Scale(2)  # 4 parameters

julia> Flux.params(a)
Params([Float32[1.0, 1.0], Float32[0.0, 0.0]])

julia> a([1 2 3])
2×3 Matrix{Float32}:
 1.0  2.0  3.0
 1.0  2.0  3.0

julia> b = Flux.Scale([1 2 3 4], false, abs2)
Scale(1, 4, abs2; bias=false)  # 4 parameters

julia> b([1, 10])
2×4 Matrix{Int64}:
   1    4    9    16
 100  400  900  1600

julia> Flux.params(b)
Params([[1 2 3 4]])
```
"""
struct Scale{F, A<:AbstractArray, B}
  scale::A
  bias::B
  σ::F
  function Scale(scale::A, bias::B = true, σ::F = identity) where {A<:AbstractArray, B<:Union{Bool, AbstractArray}, F}
    # `create_bias` turns the user-facing `bias` argument (a Bool flag or an
    # explicit array) into the value that is actually stored, so the stored
    # type may differ from `B`; hence `typeof(bias_param)` below.
    bias_param = create_bias(scale, bias, size(scale)...)
    return new{F, A, typeof(bias_param)}(scale, bias_param, σ)
  end
end
196223
# Size-only constructor: the scale array is built as `init(s1, s23...)`.
# `_act` is an internal keyword through which the method below forwards the
# activation function; users pass the activation positionally instead.
Scale(s1::Integer, s23::Integer...; bias = true, init = ones32, _act = identity) =
  Scale(init(s1, s23...), bias, _act)

# Sizes followed by an activation, e.g. `Scale(2, 3, tanh)`: everything but the
# last positional argument is a size, the last one is the activation.
# The sizes are validated up front so that malformed calls such as `Scale()`,
# `Scale(tanh)` or `Scale(2.0, tanh)` raise a clear `ArgumentError` rather than
# a `BoundsError` on the empty tuple or a `MethodError` about the private
# `_act` keyword.
function Scale(size_act...; bias = true, init = ones32)
  dims = size_act[1:end-1]
  if isempty(dims) || !all(d -> d isa Integer, dims)
    throw(ArgumentError(
      "Scale expects one or more integer sizes, optionally followed by an " *
      "activation function, e.g. Scale(2) or Scale(2, 3, tanh); got $(size_act)"))
  end
  return Scale(dims...; bias, init, _act = size_act[end])
end
198226
@functor Scale

# Forward pass: apply the element-wise affine map and then the activation,
# all inside one fused broadcast.
function (a::Scale)(x::AbstractArray)
  act = NNlib.fast_act(a.σ, x)  # replaces tanh => tanh_fast, etc
  return act.(a.scale .* x .+ a.bias)
end
205233
# Compact one-line display, e.g. `Scale(1, 4, abs2; bias=false)`.
# The activation is printed only when it is not `identity`, and the bias
# keyword only when the bias has been disabled.
function Base.show(io::IO, l::Scale)
  dims = join(size(l.scale), ", ")
  print(io, "Scale(", dims)
  if l.σ != identity
    print(io, ", ", l.σ)
  end
  if l.bias == false
    print(io, "; bias=false")
  end
  print(io, ")")
  return nothing
end
0 commit comments