|
1 | | -istraining() = false |
2 | 1 |
|
3 | | -ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),) |
4 | | - |
5 | | -_isactive(m) = isnothing(m.active) ? istraining() : m.active |
# Decide whether layer `m` should run in training mode: an explicitly set
# `active` flag takes priority; when it is `nothing`, auto-detect by asking
# NNlib whether `x` is currently being differentiated.
function _isactive(m, x)
    if m.active === nothing
        return NNlib.within_gradient(x)
    else
        return m.active
    end
end
6 | 3 |
|
# Shape of the dropout noise mask: full extent along the dims being dropped,
# size 1 (broadcast) along every other dimension. `:` means drop over all dims,
# i.e. an elementwise mask with the same size as the input.
_dropout_shape(x, ::Colon) = size(x)
_dropout_shape(x, dims) = ntuple(d -> d in dims ? size(x, d) : 1, ndims(x))
|
# Dropout holds no trainable parameters (p, dims, active, rng are all
# configuration), so report an empty NamedTuple to the optimiser machinery.
trainable(a::Dropout) = NamedTuple()
108 | 105 |
|
# Forward pass for Dropout: identity when the layer is inactive (test mode),
# otherwise delegate to `dropout` with the layer's stored rng, rate and dims.
function (a::Dropout)(x)
    if _isactive(a, x)
        return dropout(a.rng, x, a.p; dims = a.dims, active = true)
    else
        return x
    end
end
113 | 110 |
|
@@ -162,7 +159,7 @@ AlphaDropout(p; rng = default_rng_value()) = AlphaDropout(p, nothing, rng) |
# AlphaDropout, like Dropout, carries only configuration fields — nothing for
# an optimiser to update — so expose an empty NamedTuple of trainables.
trainable(a::AlphaDropout) = NamedTuple()
163 | 160 |
|
164 | 161 | function (a::AlphaDropout)(x::AbstractArray{T}) where T |
165 | | - _isactive(a) || return x |
| 162 | + _isactive(a, x) || return x |
166 | 163 | p = a.p |
167 | 164 | iszero(p) && return x |
168 | 165 | isone(p) && return sign.(x) .* T(0) |
|
242 | 239 | function _norm_layer_forward( |
243 | 240 | l, x::AbstractArray{T, N}; reduce_dims, affine_shape, |
244 | 241 | ) where {T, N} |
245 | | - if !_isactive(l) && l.track_stats # testmode with tracked stats |
| 242 | + if !_isactive(l, x) && l.track_stats # testmode with tracked stats |
246 | 243 | stats_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N) |
247 | 244 | μ = reshape(l.μ, stats_shape) |
248 | 245 | σ² = reshape(l.σ², stats_shape) |
|
0 commit comments