@@ -626,7 +626,7 @@ function (g::GlobalMeanPool)(x)
626626end
627627
628628"""
629- GlobalLPNormPool(p::Float64 )
629+ GlobalLPNormPool(p::T )
630630
631631Global lp norm pooling layer.
632632
@@ -636,16 +636,16 @@ by performing lp norm pooling on the complete (w,h)-shaped feature maps.
636636See also [`LPNormPool`](@ref).
637637
638638```jldoctest
639- julia> xs = rand(Float32, 100, 100, 3, 50)
639+ julia> xs = rand(Float32, 100, 100, 3, 50);
640640
641- julia> m = Chain(Conv((3,3), 3 => 7), GlobalLPNormPool(2.0))
641+ julia> m = Chain(Conv((3,3), 3 => 7), GlobalLPNormPool(2.0));
642642
643643julia> m(xs) |> size
644644(1, 1, 7, 50)
645645```
646646"""
647- struct GlobalLPNormPool
648- p:: Float64
647+ struct GlobalLPNormPool{T <: Number }
648+ p:: T
649649end
650650
651651function (g:: GlobalLPNormPool )(x)
@@ -778,7 +778,7 @@ function Base.show(io::IO, m::MeanPool)
778778end
779779
780780"""
781- LPNormPool(window::NTuple, p::Float64 ; pad=0, stride=window)
781+ LPNormPool(window::NTuple, p::T ; pad=0, stride=window)
782782
783783Lp norm pooling layer, calculating p-norm distance for each window,
784784also known as LPPool in pytorch.
@@ -801,7 +801,7 @@ julia> xs = rand(Float32, 100, 100, 3, 50);
801801julia> m = Chain(Conv((5,5), 3 => 7), LPNormPool((5,5), 2.0; pad=SamePad()))
802802Chain(
803803 Conv((5, 5), 3 => 7), # 532 parameters
804- LPNormPool((5, 5), p=2 , pad=2),
804+ LPNormPool((5, 5), 2.0 , pad=2),
805805)
806806
807807julia> m[1](xs) |> size
@@ -811,20 +811,20 @@ julia> m(xs) |> size
811811(20, 20, 7, 50)
812812
813813julia> layer = LPNormPool((5,), 2.0, pad=2, stride=(3,)) # one-dimensional window
814- LPNormPool((5,), p=2 , pad=2, stride=3)
814+ LPNormPool((5,), 2.0 , pad=2, stride=3)
815815
816816julia> layer(rand(Float32, 100, 7, 50)) |> size
817817(34, 7, 50)
818818```
819819"""
820- struct LPNormPool{N,M}
820+ struct LPNormPool{N,M,T <: Number }
821821 k:: NTuple{N,Int}
822- p:: Float64
822+ p:: T
823823 pad:: NTuple{M,Int}
824824 stride:: NTuple{N,Int}
825825end
826826
827- function LPNormPool (k:: NTuple{N,Integer} , p:: Float64 ; pad = 0 , stride = k) where N
827+ function LPNormPool (k:: NTuple{N,Integer} , p:: T ; pad = 0 , stride = k) where {N,T}
828828 stride = expand (Val (N), stride)
829829 pad = calc_padding (LPNormPool, pad, k, 1 , stride)
830830 return LPNormPool (k, p, pad, stride)
0 commit comments