3838
3939Chain (xs... ) = Chain (xs)
4040function Chain (; kw... )
41- :layers in Base . keys (kw) && throw (ArgumentError (" a Chain cannot have a named layer called `layers`" ))
41+ :layers in keys (kw) && throw (ArgumentError (" a Chain cannot have a named layer called `layers`" ))
4242 isempty (kw) && return Chain (())
4343 Chain (values (kw))
4444end
6767
6868Base. getindex (c:: Chain , i:: AbstractArray ) = Chain (c. layers[i])
6969Base. getindex (c:: Chain{<:NamedTuple} , i:: AbstractArray ) =
70- Chain (NamedTuple {Base. keys(c)[i]} (Tuple (c. layers)[i]))
70+ Chain (NamedTuple {keys(c)[i]} (Tuple (c. layers)[i]))
7171function Base. show (io:: IO , c:: Chain )
7272 print (io, " Chain(" )
7373 _show_layers (io, c. layers)
487487Parallel (connection, layers... ) = Parallel (connection, layers)
488488function Parallel (connection; kw... )
489489 layers = NamedTuple (kw)
490- if :layers in Base . keys (layers) || :connection in Base . keys (layers)
490+ if :layers in keys (layers) || :connection in keys (layers)
491491 throw (ArgumentError (" a Parallel layer cannot have a named sub-layer called `connection` or `layers`" ))
492492 end
493493 isempty (layers) && return Parallel (connection, ())
@@ -498,28 +498,138 @@ end
498498
499499(m:: Parallel )(x) = m. connection (map (f -> f (x), Tuple (m. layers))... )
500500(m:: Parallel )(xs:: Tuple ) = m (xs... )
501- function (m:: Parallel )(xs... )
502- nl = length (m. layers)
503- nx = length (xs)
504- if nl != nx
501+
502+ function _parallel_check (layers, xs)
503+ nl = length (layers)
504+ nx = length (xs)
505+ if (nl != nx)
505506 throw (ArgumentError (" Parallel with $nl sub-layers can take one input or $nl inputs, but got $nx inputs" ))
506507 end
508+ end
509+ ChainRulesCore. @non_differentiable _parallel_check (nl, nx)
510+
511+ function (m:: Parallel )(xs... )
512+ _parallel_check (m. layers, xs)
507513 m. connection (map (|> , xs, Tuple (m. layers))... )
508514end
509515
510516Base. getindex (m:: Parallel , i) = m. layers[i]
511517Base. getindex (m:: Parallel , i:: AbstractVector ) = Parallel (m. connection, m. layers[i])
512518Base. getindex (m:: Parallel{<:Any, <:NamedTuple} , i:: AbstractVector ) =
513- Parallel (m. connection, NamedTuple {Base. keys(m)[i]} (Tuple (m. layers)[i]))
519+ Parallel (m. connection, NamedTuple {keys(m)[i]} (Tuple (m. layers)[i]))
514520
515- Base. keys (m:: Parallel ) = Base . keys (getfield (m, :layers ))
521+ Base. keys (m:: Parallel ) = keys (getfield (m, :layers ))
516522
517523function Base. show (io:: IO , m:: Parallel )
518524 print (io, " Parallel(" , m. connection, " , " )
519525 _show_layers (io, m. layers)
520526 print (io, " )" )
521527end
522528
529+ """
530+ PairwiseFusion(connection, layers...)
531+
532+ ## Arguments
533+
534+ - `connection`: A function taking 2 inputs and combining them into a single output
535+ - `layers`: The layers whose outputs are combined
536+
537+ ## Inputs
538+
539+ This layer behaves differently based on input type:
540+
541+ 1. If input `x` is a tuple of length N (or the input is `xs` with N `x`'s), matching the number of `layers`,
542+ then each layer receives a new input `x[i]` combined with the previous output `y[i-1]` using `connection`.
543+ Thus `(y1, y2, y3) = PairwiseFusion(connection, layer1, layer2, layer3)((x1, x2, x3))`
544+ may be drawn as:
545+ ```
546+ x1 → layer1 → y1 ↘
547+ connection → layer2 → y2 ↘
548+ x2 ↗ connection → layer3 → y3
549+ x3 ↗
550+ ```
551+ ... or written as:
552+ ```julia
553+ y1 = layer1(x1)
554+ y2 = layer2(connection(x2, y1))
555+ y3 = layer3(connection(x3, y2))
556+ ```
557+
558+ 2. With just one input, each layer receives the same `x` combined with the previous output.
559+ Thus `y = PairwiseFusion(connection, layers...)(x)` obeys:
560+
561+ ```julia
562+ y[1] == layers[1](x)
563+ for i in 2:length(layers)
564+ y[i] == connection(x, layers[i](y[i-1]))
565+ end
566+ ```
567+
568+ ## Returns
569+
570+ A tuple of length N with the output of each fusion ((`y1`, `y2`, ..., `yN`) in the example above).
571+ """
572+ struct PairwiseFusion{F, T<: Union{Tuple, NamedTuple} }
573+ connection:: F
574+ layers:: T
575+ end
576+
577+ PairwiseFusion (connection, layers... ) = PairwiseFusion (connection, layers)
578+ function PairwiseFusion (connection; kw... )
579+ layers = NamedTuple (kw)
580+ if :layers in keys (layers) || :connection in keys (layers)
581+ throw (ArgumentError (" a PairwiseFusion layer cannot have a named sub-layer called `connection` or `layers`" ))
582+ end
583+ isempty (layers) && return PairwiseFusion (connection, ())
584+ PairwiseFusion (connection, layers)
585+ end
586+
587+ function _pairwise_check (x, layers, T)
588+ lx = length (x)
589+ N = length (layers)
590+ if T <: Tuple && lx != N
591+ throw (ArgumentError (" PairwiseFusion with $N sub-layers can take one input or $N inputs, but got $lx inputs" ))
592+ end
593+ end
594+ ChainRulesCore. @non_differentiable _pairwise_check (lx, N, T)
595+
596+ function (m:: PairwiseFusion )(x:: T ) where {T}
597+ _pairwise_check (x, m. layers, T)
598+ applypairwisefusion (m. layers, m. connection, x)
599+ end
600+ (m:: PairwiseFusion )(xs... ) = m (xs)
601+
602+ @generated function applypairwisefusion (layers:: Tuple{Vararg{<:Any,N}} , connection, x:: T ) where {N, T}
603+ y_symbols = [gensym () for _ in 1 : (N + 1 )]
604+ getinput (i) = T <: Tuple ? :(x[$ i]) : :x
605+ calls = [:($ (y_symbols[N + 1 ]) = $ (getinput (1 )))]
606+ for i in 1 : N - 1
607+ push! (calls, quote
608+ $ (y_symbols[i]) = layers[$ i]($ (y_symbols[N + 1 ]))
609+ $ (y_symbols[N + 1 ]) = connection ($ (y_symbols[i]), $ (getinput (i + 1 )))
610+ end )
611+ end
612+ push! (calls, :($ (y_symbols[N]) = layers[$ N]($ (y_symbols[N + 1 ]))))
613+ push! (calls, :(return tuple ($ (Tuple (y_symbols[1 : N])... ))))
614+ return Expr (:block , calls... )
615+ end
616+ applypairwisefusion (layers:: NamedTuple , connection, x) = applypairwisefusion (Tuple (layers), connection, x)
617+
618+ @functor PairwiseFusion
619+
620+ Base. getindex (m:: PairwiseFusion , i) = m. layers[i]
621+ Base. getindex (m:: PairwiseFusion , i:: AbstractVector ) = PairwiseFusion (m. connection, m. layers[i])
622+ Base. getindex (m:: PairwiseFusion{<:Any, <:NamedTuple} , i:: AbstractVector ) =
623+ PairwiseFusion (m. connection, NamedTuple {keys(m)[i]} (Tuple (m. layers)[i]))
624+
625+ Base. keys (m:: PairwiseFusion ) = keys (getfield (m, :layers ))
626+
627+ function Base. show (io:: IO , m:: PairwiseFusion )
628+ print (io, " PairwiseFusion(" , m. connection, " , " )
629+ _show_layers (io, m. layers)
630+ print (io, " )" )
631+ end
632+
523633"""
524634 Embedding(in => out; init=randn)
525635
556666@functor Embedding
557667
558668Embedding ((in, out):: Pair{<:Integer, <:Integer} ; init = randn32) = Embedding (init (out, in))
559-
669+
560670(m:: Embedding )(x:: Integer ) = m. weight[:, x]
561671(m:: Embedding )(x:: AbstractVector ) = NNlib. gather (m. weight, x)
562672(m:: Embedding )(x:: AbstractArray ) = reshape (m (vec (x)), :, size (x)... )
@@ -565,7 +675,7 @@ function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T
565675 size (m. weight, 2 ) == L || throw (DimensionMismatch (" Matrix column must correspond with OneHot size: $(size (m. weight, 2 )) != $L " ))
566676 return m (onecold (x))
567677end
568-
678+
569679function Base. show (io:: IO , m:: Embedding )
570680 print (io, " Embedding(" , size (m. weight, 2 ), " => " , size (m. weight, 1 ), " )" )
571681end
0 commit comments