3838
3939Chain (xs... ) = Chain (xs)
4040function Chain (; kw... )
41- :layers in Base . keys (kw) && throw (ArgumentError (" a Chain cannot have a named layer called `layers`" ))
41+ :layers in keys (kw) && throw (ArgumentError (" a Chain cannot have a named layer called `layers`" ))
4242 isempty (kw) && return Chain (())
4343 Chain (values (kw))
4444end
@@ -498,12 +498,18 @@ end
498498
499499(m:: Parallel )(x) = m. connection (map (f -> f (x), Tuple (m. layers))... )
500500(m:: Parallel )(xs:: Tuple ) = m (xs... )
501- function (m:: Parallel )(xs... )
502- nl = length (m. layers)
503- nx = length (xs)
504- if nl != nx
501+
502+ function _parallel_check (layers, xs)
503+ nl = length (layers)
504+ nx = length (xs)
505+ if (nl != nx)
505506 throw (ArgumentError (" Parallel with $nl sub-layers can take one input or $nl inputs, but got $nx inputs" ))
506507 end
508+ end
509+ ChainRulesCore. @non_differentiable _parallel_check (nl, nx)
510+
511+ function (m:: Parallel )(xs... )
512+ _parallel_check (m. layers, xs)
507513 m. connection (map (|> , xs, Tuple (m. layers))... )
508514end
509515
@@ -563,32 +569,36 @@ end
563569
564570A tuple of length `N` with the output of each fusion ((`y1`, `y2`, ..., `yN`) in the example above).
565571"""
566- struct PairwiseFusion{F, T <: NamedTuple }
572+ struct PairwiseFusion{F, T<: Union{Tuple, NamedTuple} }
567573 connection:: F
568574 layers:: T
569575end
570576
571- function PairwiseFusion (connection, layers... )
572- names = ntuple (i -> Symbol (" layer_$i " ), length (layers))
573- return PairwiseFusion (connection, NamedTuple {names} (layers))
577+ PairwiseFusion (connection, layers... ) = PairwiseFusion (connection, layers)
578+ function PairwiseFusion (connection; kw... )
579+ layers = NamedTuple (kw)
580+ if :layers in keys (layers) || :connection in keys (layers)
581+ throw (ArgumentError (" a PairwiseFusion layer cannot have a named sub-layer called `connection` or `layers`" ))
582+ end
583+ isempty (layers) && return PairwiseFusion (connection, ())
584+ PairwiseFusion (connection, layers)
574585end
575586
576- function _pairwise_check (lx, N, T)
587+ function _pairwise_check (x, layers, T)
588+ lx = length (x)
589+ N = length (layers)
577590 if T <: Tuple && lx != N
578591 throw (ArgumentError (" PairwiseFusion with $N sub-layers can take one input or $N inputs, but got $lx inputs" ))
579592 end
580593end
581594ChainRulesCore. @non_differentiable _pairwise_check (lx, N, T)
582595
583596function (m:: PairwiseFusion )(x:: T ) where {T}
584- lx = length (x)
585- N = length (m. layers)
586- _pairwise_check (lx, N, T)
597+ _pairwise_check (x, m. layers, T)
587598 applypairwisefusion (m. layers, m. connection, x)
588599end
589600
590- @generated function applypairwisefusion (layers:: NamedTuple{names} , connection, x:: T ) where {names, T}
591- N = length (names)
601+ @generated function applypairwisefusion (layers:: Tuple{Vararg{<:Any,N}} , connection, x:: T ) where {N, T}
592602 y_symbols = [gensym () for _ in 1 : (N + 1 )]
593603 getinput (i) = T <: Tuple ? :(x[$ i]) : :x
594604 calls = [:($ (y_symbols[N + 1 ]) = $ (getinput (1 )))]
@@ -602,10 +612,12 @@ end
602612 push! (calls, :(return tuple ($ (Tuple (y_symbols[1 : N])... ))))
603613 return Expr (:block , calls... )
604614end
615+ applypairwisefusion (layers:: NamedTuple , connection, x) = applypairwisefusion (Tuple (layers), connection, x)
605616
606617@functor PairwiseFusion
607618
608619Base. getindex (m:: PairwiseFusion , i) = m. layers[i]
620+ Base. getindex (m:: PairwiseFusion , i:: AbstractVector ) = PairwiseFusion (m. connection, m. layers[i])
609621Base. getindex (m:: PairwiseFusion{<:Any, <:NamedTuple} , i:: AbstractVector ) =
610622 PairwiseFusion (m. connection, NamedTuple {keys(m)[i]} (Tuple (m. layers)[i]))
611623
0 commit comments