@@ -6,6 +6,10 @@ _paddims(x::Tuple, y::Tuple) = (x..., y[(end - (length(y) - length(x) - 1)):end]
66expand (N, i:: Tuple ) = i
77expand (N, i:: Integer ) = ntuple (_ -> i, N)
88
9+ conv_reshape_bias (c) = c. bias isa AbstractVector ?
10+ reshape (c. bias, map (_-> 1 , c. stride)... , :, 1 ) :
11+ c. bias
12+
913"""
1014 SamePad()
1115
6165
6266Keywords to control initialization of the layer:
6367* `init` - Function used to generate initial weights. Defaults to `glorot_uniform`.
64- * `bias` - Initial bias is zero by default, this can be disabled entirely by setting it to
65- `false`, or another vector explicitly as `bias = randn(Float32, out)`.
68+ * `bias` - The initial bias vector is all zero by default. Trainable bias can be disabled entirely
69+ by setting this to `false`, or another vector can be provided such as `bias = randn(Float32, out)`.
6670
6771See also [`ConvTranspose`](@ref), [`DepthwiseConv`](@ref), [`CrossCor`](@ref).
6872
159163@functor Conv
160164
161165function (c:: Conv )(x:: AbstractArray )
162- b = reshape (c. bias, map (_-> 1 , c. stride)... , :, 1 )
163166 σ = NNlib. fast_act (c. σ, x)
164167 cdims = DenseConvDims (x, c. weight; stride = c. stride, padding = c. pad, dilation = c. dilation, groups = c. groups)
165- σ .(conv (x, c. weight, cdims) .+ b )
168+ σ .(conv (x, c. weight, cdims) .+ conv_reshape_bias (c) )
166169end
167170
168171_channels_in (l :: Conv ) = size (l. weight, ndims (l. weight)- 1 ) * l. groups
@@ -183,7 +186,7 @@ function _print_conv_opt(io::IO, l)
183186 if hasproperty (l, :groups )
184187 (l. groups == 1 ) || print (io, " , groups=" , l. groups)
185188 end
186- (l. bias isa Zeros ) && print (io, " , bias=false" )
189+ (l. bias === false ) && print (io, " , bias=false" )
187190end
188191
189192"""
277280@nograd conv_transpose_dims
278281
279282function (c:: ConvTranspose )(x:: AbstractArray )
280- b = reshape (c. bias, map (_-> 1 , c. stride)... , :, 1 )
281283 σ = NNlib. fast_act (c. σ, x)
282284 cdims = conv_transpose_dims (c, x)
283- σ .(∇conv_data (x, c. weight, cdims) .+ b )
285+ σ .(∇conv_data (x, c. weight, cdims) .+ conv_reshape_bias (c) )
284286end
285287
286288function Base. show (io:: IO , l:: ConvTranspose )
@@ -372,10 +374,9 @@ depthwiseconvfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
372374 init = glorot_uniform) where N = init (filter... , div (ch[2 ], ch[1 ]), ch[1 ])
373375
374376function (c:: DepthwiseConv )(x)
375- b = reshape (c. bias, map (_-> 1 , c. stride)... , :, 1 )
376377 σ = NNlib. fast_act (c. σ, x)
377378 cdims = DepthwiseConvDims (x, c. weight; stride= c. stride, padding= c. pad, dilation= c. dilation)
378- σ .(depthwiseconv (x, c. weight, cdims) .+ b )
379+ σ .(depthwiseconv (x, c. weight, cdims) .+ conv_reshape_bias (c) )
379380end
380381
381382function Base. show (io:: IO , l:: DepthwiseConv )
@@ -453,10 +454,9 @@ function crosscor(x, w, ddims::DenseConvDims)
453454end
454455
455456function (c:: CrossCor )(x:: AbstractArray )
456- b = reshape (c. bias, map (_-> 1 , c. stride)... , :, 1 )
457457 σ = NNlib. fast_act (c. σ, x)
458458 cdims = DenseConvDims (x, c. weight; stride= c. stride, padding= c. pad, dilation= c. dilation)
459- σ .(crosscor (x, c. weight, cdims) .+ b )
459+ σ .(crosscor (x, c. weight, cdims) .+ conv_reshape_bias (c) )
460460end
461461
462462function Base. show (io:: IO , l:: CrossCor )
0 commit comments