@@ -28,29 +28,30 @@ julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x)
 true
 ```
 """
-struct Chain{T}
+struct Chain{T<:Union{Tuple, NamedTuple}}
   layers::T
-  Chain(xs...) = new{typeof(xs)}(xs)
-  function Chain(; kw...)
-    :layers in Base.keys(kw) && throw(ArgumentError("a Chain cannot have a named layer called `layers`"))
-    isempty(kw) && return new{Tuple{}}(())
-    new{typeof(values(kw))}(values(kw))
-  end
+end
+
+Chain(xs...) = Chain(xs)
+function Chain(; kw...)
+  :layers in Base.keys(kw) && throw(ArgumentError("a Chain cannot have a named layer called `layers`"))
+  isempty(kw) && return Chain(())
+  Chain(values(kw))
 end
 
 @forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
   Base.iterate, Base.lastindex, Base.keys
 
-functor(::Type{<:Chain}, c) = c.layers, ls -> Chain(ls...)
+@functor Chain
 
 applychain(::Tuple{}, x) = x
 applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
 
 (c::Chain)(x) = applychain(Tuple(c.layers), x)
 
-Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
-Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
-  Chain(; NamedTuple{Base.keys(c)[i]}(Tuple(c.layers)[i])...)
+Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i])
+Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
  Chain(NamedTuple{Base.keys(c)[i]}(Tuple(c.layers)[i]))
 
 function Base.show(io::IO, c::Chain)
   print(io, "Chain(")
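To illustrate the behaviour this hunk settles on, a minimal sketch of the keyword constructor and the name-preserving `getindex` (plain closures stand in for real layers; not part of the diff):

```julia
using Flux

m = Chain(enc = x -> 2 .* x, dec = x -> x .+ 1)  # layers held as a NamedTuple

m[:enc] === m[1]                 # true: named and positional access agree
m[1:2] isa Chain                 # true: slicing rebuilds a Chain and keeps the names

x = [1.0, 2.0]
m(x) == (m[:dec] ∘ m[:enc])(x)   # true, matching the doctest above
```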
@@ -246,29 +247,23 @@ julia> Flux.outputsize(m3, (5, 11))
 (7, 11)
 ```
 """
-struct Maxout{FS<:Tuple}
-  over::FS
-  Maxout(layers...) = new{typeof(layers)}(layers)
-end
-
-function Maxout(f::Function, n_alts::Integer)
-  over = Tuple(f() for _ in 1:n_alts)
-  return Maxout(over...)
+struct Maxout{T<:Tuple}
+  layers::T
 end
+Maxout(layers...) = Maxout(layers)
+Maxout(f::Function, n_alts::Integer) = Maxout((f() for _ in 1:n_alts)...)
 
 @functor Maxout
 
 function (mo::Maxout)(input::AbstractArray)
   # Perhaps surprisingly, pairwise max broadcast is often faster,
   # even with Zygote. See #698 and #1794
-  mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
+  mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.layers)
 end
 
-trainable(mo::Maxout) = mo.over
-
 function Base.show(io::IO, mo::Maxout)
   print(io, "Maxout(")
-  _show_layers(io, mo.over)
+  _show_layers(io, mo.layers)
   print(io, ")")
 end
 
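A quick sanity check of the renamed field (`over` becomes `layers`) and of the constructors kept by this hunk; the toy branches are illustrative only, and `Dense(5, 7)` assumes this Flux version's integer-argument constructor:

```julia
using Flux

mo = Maxout(x -> x, x -> -x)     # elementwise max over the two branches
mo([1.0, -2.0]) == [1.0, 2.0]    # true: max.(x, -x) is abs.(x)

mo2 = Maxout(() -> Dense(5, 7), 4)  # 4 independently initialised Dense layers
length(mo2.layers) == 4             # true: field is now `layers`, not `over`
```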
@@ -415,8 +410,8 @@
 Create a `Parallel` layer that passes an input array to each path in
 `layers`, before reducing the output with `connection`.
 
-Called with one input `x`, this is equivalent to `reduce(connection, [l(x) for l in layers])`.
-If called with multiple inputs, they are `zip`ped with the layers, thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.
+Called with one input `x`, this is equivalent to `connection([l(x) for l in layers]...)`.
+If called with multiple inputs, one is passed to each layer, thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.
 
 Like [`Chain`](@ref), its sub-layers may be given names using the keyword constructor.
 These can be accessed by indexing: `m[1] == m[:name]` is the first layer.
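The corrected wording can be checked directly; a small sketch with plain functions (not from the diff):

```julia
using Flux

p = Parallel(+, x -> 10 * x, x -> x / 2)

p(4.0) == 42.0                            # one input: every layer sees x, then +
Parallel(+, inv, sqrt)(2.0, 16.0) == 4.5  # two inputs: one per layer, f(x) + g(y)
```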
@@ -451,7 +446,7 @@ julia> model2[:β] == model2[2]
 true
 ```
 """
-struct Parallel{F, T}
+struct Parallel{F, T<:Union{Tuple, NamedTuple}}
   connection::F
   layers::T
 end
@@ -461,25 +456,31 @@ function Parallel(connection; kw...)
   layers = NamedTuple(kw)
   if :layers in Base.keys(layers) || :connection in Base.keys(layers)
     throw(ArgumentError("a Parallel layer cannot have a named sub-layer called `connection` or `layers`"))
-  elseif isempty(layers)
-    Parallel(connection, ())
   end
+  isempty(layers) && return Parallel(connection, ())
   Parallel(connection, layers)
 end
 
 @functor Parallel
 
-(m::Parallel)(x) = mapreduce(f -> f(x), m.connection, Tuple(m.layers))
-(m::Parallel)(xs...) = mapreduce((f, x) -> f(x), m.connection, Tuple(m.layers), xs)
+(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...)
 (m::Parallel)(xs::Tuple) = m(xs...)
+function (m::Parallel)(xs...)
+  nl = length(m.layers)
+  nx = length(xs)
+  if nl != nx
+    throw(ArgumentError("Parallel with $nl sub-layers can take one input or $nl inputs, but got $nx inputs"))
+  end
+  m.connection(map(|>, xs, Tuple(m.layers))...)
+end
 
 Base.getindex(m::Parallel, i) = m.layers[i]
-Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i]...)
+Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i])
+Base.getindex(m::Parallel{<:Any, <:NamedTuple}, i::AbstractVector) =
+  Parallel(m.connection, NamedTuple{Base.keys(m)[i]}(Tuple(m.layers)[i]))
 
 Base.keys(m::Parallel) = Base.keys(getfield(m, :layers))
 
-trainable(m::Parallel) = (m.connection, m.layers...)
-
 function Base.show(io::IO, m::Parallel)
   print(io, "Parallel(", m.connection, ", ")
   _show_layers(io, m.layers)
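To see the new multi-input method and name-preserving slicing in action, a hedged sketch (the layer names and inputs are made up for illustration):

```julia
using Flux

p = Parallel(vcat, a = x -> x .+ 1, b = x -> x .* 2, c = x -> -x)

p([1.0], [2.0], [3.0]) == [2.0, 4.0, -3.0]  # inputs paired with layers via |>
p[[1, 3]] isa Parallel                      # true: vector indexing keeps sub-layer names

try
  p([1.0], [2.0])           # 3 sub-layers but only 2 inputs
catch err
  err isa ArgumentError     # true: explicit arity check replaces the old silent zip truncation
end
```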