@@ -147,8 +147,12 @@ outputsize(m::AbstractVector, input::Tuple...; padbatch=false) = outputsize(Chai
147147
148148# # bypass statistics in normalization layers
149149
150- for layer in (:LayerNorm , :BatchNorm , :InstanceNorm , :GroupNorm )
151- @eval (l:: $layer )(x:: AbstractArray{Nil} ) = x
150+ for layer in (:BatchNorm , :InstanceNorm , :GroupNorm ) # LayerNorm works fine
151+ @eval function (l:: $layer )(x:: AbstractArray{Nil} )
152+ l. chs == size (x, ndims (x)- 1 ) || throw (DimensionMismatch (
153+ string ($ layer, " expected " , l. chs, " channels, but got size(x) == " , size (x))))
154+ x
155+ end
152156end
153157
154158# # fixes for layers that don't work out of the box
@@ -168,3 +172,162 @@ for (fn, Dims) in ((:conv, DenseConvDims),)
168172 end
169173 end
170174end
175+
176+
177+ """
178+ @autosize (size...,) Chain(Layer(_ => 2), Layer(_), ...)
179+
180+ Returns the specified model, with each `_` replaced by an inferred number,
181+ for input of the given `size`.
182+
183+ The unknown sizes are usually the second-last dimension of that layer's input,
184+ which Flux regards as the channel dimension.
185+ (A few layers, `Dense` & [`LayerNorm`](@ref), instead always use the first dimension.)
186+ The underscore may appear as an argument of a layer, or inside a `=>`.
187+ It may be used in further calculations, such as `Dense(_ => _÷4)`.
188+
189+ # Examples
190+ ```
191+ julia> @autosize (3, 1) Chain(Dense(_ => 2, sigmoid), BatchNorm(_, affine=false))
192+ Chain(
193+ Dense(3 => 2, σ), # 8 parameters
194+ BatchNorm(2, affine=false),
195+ )
196+
197+ julia> img = [28, 28];
198+
199+ julia> @autosize (img..., 1, 32) Chain( # size is only needed at runtime
200+ Chain(c = Conv((3,3), _ => 5; stride=2, pad=SamePad()),
201+ p = MeanPool((3,3)),
202+ b = BatchNorm(_),
203+ f = Flux.flatten),
204+ Dense(_ => _÷4, relu, init=Flux.rand32), # can calculate output size _÷4
205+ SkipConnection(Dense(_ => _, relu), +),
206+ Dense(_ => 10),
207+ ) |> gpu # moves to GPU after initialisation
208+ Chain(
209+ Chain(
210+ c = Conv((3, 3), 1 => 5, pad=1, stride=2), # 50 parameters
211+ p = MeanPool((3, 3)),
212+ b = BatchNorm(5), # 10 parameters, plus 10
213+ f = Flux.flatten,
214+ ),
215+ Dense(80 => 20, relu), # 1_620 parameters
216+ SkipConnection(
217+ Dense(20 => 20, relu), # 420 parameters
218+ +,
219+ ),
220+ Dense(20 => 10), # 210 parameters
221+ ) # Total: 10 trainable arrays, 2_310 parameters,
222+ # plus 2 non-trainable, 10 parameters, summarysize 10.469 KiB.
223+
224+ julia> outputsize(ans, (28, 28, 1, 32))
225+ (10, 32)
226+ ```
227+
228+ Limitations:
229+ * While `@autosize (5, 32) Flux.Bilinear(_ => 7)` is OK, something like `Bilinear((_, _) => 7)` will fail.
230+ * While `Scale(_)` and `LayerNorm(_)` are fine (and use the first dimension), `Scale(_,_)` and `LayerNorm(_,_)`
231+ will fail if `size(x,1) != size(x,2)`.
232+ * RNNs won't work: `@autosize (7, 11) LSTM(_ => 5)` fails, because `outputsize(RNN(3=>7), (3,))` also fails, a known issue.
233+ """
234+ macro autosize (size, model)
235+ Meta. isexpr (size, :tuple ) || error (" @autosize's first argument must be a tuple, the size of the input" )
236+ Meta. isexpr (model, :call ) || error (" @autosize's second argument must be something like Chain(layers...)" )
237+ ex = _makelazy (model)
238+ @gensym m
239+ quote
240+ $ m = $ ex
241+ $ outputsize ($ m, $ size)
242+ $ striplazy ($ m)
243+ end |> esc
244+ end
245+
246+ function _makelazy (ex:: Expr )
247+ n = _underscoredepth (ex)
248+ n == 0 && return ex
249+ n == 1 && error (" @autosize doesn't expect an underscore here: $ex " )
250+ n == 2 && return :($ LazyLayer ($ (string (ex)), $ (_makefun (ex)), nothing ))
251+ n > 2 && return Expr (ex. head, ex. args[1 ], map (_makelazy, ex. args[2 : end ])... )
252+ end
253+ _makelazy (x) = x
254+
255+ function _underscoredepth (ex:: Expr )
256+ # Meta.isexpr(ex, :tuple) && :_ in ex.args && return 10
257+ ex. head in (:call , :kw , :(-> ), :block ) || return 0
258+ ex. args[1 ] === :(=> ) && ex. args[2 ] === :_ && return 1
259+ m = maximum (_underscoredepth, ex. args)
260+ m == 0 ? 0 : m+ 1
261+ end
262+ _underscoredepth (ex) = Int (ex === :_ )
263+
264+ function _makefun (ex)
265+ T = Meta. isexpr (ex, :call ) ? ex. args[1 ] : Type
266+ @gensym x s
267+ Expr (:(-> ), x, Expr (:block , :($ s = $ autosizefor ($ T, $ x)), _replaceunderscore (ex, s)))
268+ end
269+
270+ """
271+ autosizefor(::Type, x)
272+
273+ If an `_` in your layer's constructor, used within `@autosize`, should
274+ *not* mean the 2nd-last dimension, then you can overload this.
275+
276+ For instance `autosizefor(::Type{<:Dense}, x::AbstractArray) = size(x, 1)`
277+ is needed to make `@autosize (2,3,4) Dense(_ => 5)` return
278+ `Dense(2 => 5)` rather than `Dense(3 => 5)`.
279+ """
280+ autosizefor (:: Type , x:: AbstractArray ) = size (x, max (1 , ndims (x)- 1 ))
281+ autosizefor (:: Type{<:Dense} , x:: AbstractArray ) = size (x, 1 )
282+ autosizefor (:: Type{<:LayerNorm} , x:: AbstractArray ) = size (x, 1 )
283+
284+ _replaceunderscore (e, s) = e === :_ ? s : e
285+ _replaceunderscore (ex:: Expr , s) = Expr (ex. head, map (a -> _replaceunderscore (a, s), ex. args)... )
286+
287+ mutable struct LazyLayer
288+ str:: String
289+ make:: Function
290+ layer
291+ end
292+
293+ @functor LazyLayer
294+
295+ function (l:: LazyLayer )(x:: AbstractArray , ys:: AbstractArray... )
296+ l. layer === nothing || return l. layer (x, ys... )
297+ made = l. make (x) # for something like `Bilinear((_,__) => 7)`, perhaps need `make(xy...)`, later.
298+ y = made (x, ys... )
299+ l. layer = made # mutate after we know that call worked
300+ return y
301+ end
302+
303+ function striplazy (m)
304+ fs, re = functor (m)
305+ re (map (striplazy, fs))
306+ end
307+ function striplazy (l:: LazyLayer )
308+ l. layer === nothing || return l. layer
309+ error (" LazyLayer should be initialised, e.g. by outputsize(model, size), before using stiplazy" )
310+ end
311+
312+ # Could make LazyLayer usable outside of @autosize, for instance allow Chain(@lazy Dense(_ => 2))?
313+ # But then it will survive to produce weird structural gradients etc.
314+
315+ function ChainRulesCore. rrule (l:: LazyLayer , x)
316+ l (x), _ -> error (" LazyLayer should never be used within a gradient. Call striplazy(model) first to remove all." )
317+ end
318+ function ChainRulesCore. rrule (:: typeof (striplazy), m)
319+ striplazy (m), _ -> error (" striplazy should never be used within a gradient" )
320+ end
321+
322+ params! (p:: Params , x:: LazyLayer , seen = IdSet ()) = error (" LazyLayer should never be used within params(m). Call striplazy(m) first." )
323+ function Base. show (io:: IO , l:: LazyLayer )
324+ printstyled (io, " LazyLayer(" , color= :light_black )
325+ if l. layer == nothing
326+ printstyled (io, l. str, color= :magenta )
327+ else
328+ printstyled (io, l. layer, color= :cyan )
329+ end
330+ printstyled (io, " )" , color= :light_black )
331+ end
332+
333+ _big_show (io:: IO , l:: LazyLayer , indent:: Int = 0 , name= nothing ) = _layer_show (io, l, indent, name)
0 commit comments