564564
565565## Returns
566566
567- `PairwiseFusion` returns a tuple of length `N` with the output of each fusion ((`y1`, `y2`, ..., `yN`) in the example above).
567+ A tuple of length `N` with the output of each fusion ((`y1`, `y2`, ..., `yN`) in the example above).
568568"""
569569struct PairwiseFusion{F, T <: NamedTuple }
570570 connection:: F
@@ -576,12 +576,17 @@ function PairwiseFusion(connection, layers...)
576576 return PairwiseFusion (connection, NamedTuple {names} (layers))
577577end
578578
579- function (m:: PairwiseFusion )(x:: T ) where {T}
580- lx = length (x)
581- N = length (m. layers)
579+ function _pairwise_check (lx, N, T)
582580 if T <: Tuple && lx != N
583581 throw (ArgumentError (" PairwiseFusion with $N sub-layers can take one input or $N inputs, but got $lx inputs" ))
584582 end
583+ end
584+ ChainRulesCore. @non_differentiable _pairwise_check (lx, N, T)
585+
586+ function (m:: PairwiseFusion )(x:: T ) where {T}
587+ lx = length (x)
588+ N = length (m. layers)
589+ _pairwise_check (lx, N, T)
585590 applypairwisefusion (m. layers, m. connection, x)
586591end
587592
@@ -590,10 +595,12 @@ end
590595 y_symbols = [gensym () for _ in 1 : (N + 1 )]
591596 getinput (i) = T <: Tuple ? :(x[$ i]) : :x
592597 calls = [:($ (y_symbols[N + 1 ]) = $ (getinput (1 )))]
593- append! (calls,
594- [:($ (y_symbols[i]) = layers[$ i]($ (y_symbols[N + 1 ]));
595- $ (y_symbols[N + 1 ]) = connection ($ (y_symbols[i]), $ (getinput (i + 1 ))))
596- for i in 1 : N - 1 ])
598+ for i in 1 : N - 1
599+ push! (calls, quote
600+ $ (y_symbols[i]) = layers[$ i]($ (y_symbols[N + 1 ]))
601+ $ (y_symbols[N + 1 ]) = connection ($ (y_symbols[i]), $ (getinput (i + 1 )))
602+ end )
603+ end
597604 push! (calls, :($ (y_symbols[N]) = layers[$ N]($ (y_symbols[N + 1 ]))))
598605 push! (calls, :(return tuple ($ (Tuple (y_symbols[1 : N])... ))))
599606 return Expr (:block , calls... )
0 commit comments