Skip to content

Commit 5346e78

Browse files
committed
Allow plain tuples as layer inputs for PairwiseFusion
Also move checks to non-differentiable helper functions
1 parent cf758eb commit 5346e78

File tree

2 files changed

+28
-16
lines changed

2 files changed

+28
-16
lines changed

src/layers/basic.jl

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ end
3838

3939
Chain(xs...) = Chain(xs)
4040
function Chain(; kw...)
41-
:layers in Base.keys(kw) && throw(ArgumentError("a Chain cannot have a named layer called `layers`"))
41+
:layers in keys(kw) && throw(ArgumentError("a Chain cannot have a named layer called `layers`"))
4242
isempty(kw) && return Chain(())
4343
Chain(values(kw))
4444
end
@@ -498,12 +498,18 @@ end
498498

499499
(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...)
500500
(m::Parallel)(xs::Tuple) = m(xs...)
501-
function (m::Parallel)(xs...)
502-
nl = length(m.layers)
503-
nx = length(xs)
504-
if nl != nx
501+
502+
function _parallel_check(layers, xs)
503+
nl = length(layers)
504+
nx = length(xs)
505+
if (nl != nx)
505506
throw(ArgumentError("Parallel with $nl sub-layers can take one input or $nl inputs, but got $nx inputs"))
506507
end
508+
end
509+
ChainRulesCore.@non_differentiable _parallel_check(nl, nx)
510+
511+
function (m::Parallel)(xs...)
512+
_parallel_check(m.layers, xs)
507513
m.connection(map(|>, xs, Tuple(m.layers))...)
508514
end
509515

@@ -563,32 +569,36 @@ end
563569
564570
A tuple of length `N` with the output of each fusion ((`y1`, `y2`, ..., `yN`) in the example above).
565571
"""
566-
struct PairwiseFusion{F, T <: NamedTuple}
572+
struct PairwiseFusion{F, T<:Union{Tuple, NamedTuple}}
567573
connection::F
568574
layers::T
569575
end
570576

571-
function PairwiseFusion(connection, layers...)
572-
names = ntuple(i -> Symbol("layer_$i"), length(layers))
573-
return PairwiseFusion(connection, NamedTuple{names}(layers))
577+
PairwiseFusion(connection, layers...) = PairwiseFusion(connection, layers)
578+
function PairwiseFusion(connection; kw...)
579+
layers = NamedTuple(kw)
580+
if :layers in keys(layers) || :connection in keys(layers)
581+
throw(ArgumentError("a PairwiseFusion layer cannot have a named sub-layer called `connection` or `layers`"))
582+
end
583+
isempty(layers) && return PairwiseFusion(connection, ())
584+
PairwiseFusion(connection, layers)
574585
end
575586

576-
function _pairwise_check(lx, N, T)
587+
function _pairwise_check(x, layers, T)
588+
lx = length(x)
589+
N = length(layers)
577590
if T <: Tuple && lx != N
578591
throw(ArgumentError("PairwiseFusion with $N sub-layers can take one input or $N inputs, but got $lx inputs"))
579592
end
580593
end
581594
ChainRulesCore.@non_differentiable _pairwise_check(lx, N, T)
582595

583596
function (m::PairwiseFusion)(x::T) where {T}
584-
lx = length(x)
585-
N = length(m.layers)
586-
_pairwise_check(lx, N, T)
597+
_pairwise_check(x, m.layers, T)
587598
applypairwisefusion(m.layers, m.connection, x)
588599
end
589600

590-
@generated function applypairwisefusion(layers::NamedTuple{names}, connection, x::T) where {names, T}
591-
N = length(names)
601+
@generated function applypairwisefusion(layers::Tuple{Vararg{<:Any,N}}, connection, x::T) where {N, T}
592602
y_symbols = [gensym() for _ in 1:(N + 1)]
593603
getinput(i) = T <: Tuple ? :(x[$i]) : :x
594604
calls = [:($(y_symbols[N + 1]) = $(getinput(1)))]
@@ -602,10 +612,12 @@ end
602612
push!(calls, :(return tuple($(Tuple(y_symbols[1:N])...))))
603613
return Expr(:block, calls...)
604614
end
615+
applypairwisefusion(layers::NamedTuple, connection, x) = applypairwisefusion(Tuple(layers), connection, x)
605616

606617
@functor PairwiseFusion
607618

608619
Base.getindex(m::PairwiseFusion, i) = m.layers[i]
620+
Base.getindex(m::PairwiseFusion, i::AbstractVector) = PairwiseFusion(m.connection, m.layers[i])
609621
Base.getindex(m::PairwiseFusion{<:Any, <:NamedTuple}, i::AbstractVector) =
610622
PairwiseFusion(m.connection, NamedTuple{keys(m)[i]}(Tuple(m.layers)[i]))
611623

src/layers/show.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing)
2525
for k in Base.keys(obj)
2626
_big_show(io, obj[k], indent+2, k)
2727
end
28-
elseif obj isa Parallel{<:Any, <:NamedTuple}
28+
elseif obj isa Parallel{<:Any, <:NamedTuple} || obj isa PairwiseFusion{<:Any, <:NamedTuple}
2929
_big_show(io, obj.connection, indent+2)
3030
for k in Base.keys(obj)
3131
_big_show(io, obj[k], indent+2, k)

0 commit comments

Comments
 (0)