@@ -155,14 +155,16 @@ public extension Dense where Scalar.RawSignificand: FixedWidthInteger {
155155public struct Conv2D < Scalar: TensorFlowFloatingPoint > : Layer {
156156 public var filter : Tensor < Scalar >
157157 public var bias : Tensor < Scalar >
158+ public typealias Activation = @differentiable ( Tensor < Scalar > ) -> Tensor < Scalar >
159+ @noDerivative public let activation : Activation
158160 @noDerivative public let strides : ( Int32 , Int32 )
159161 @noDerivative public let padding : Padding
160162
161163 @differentiable ( wrt: ( self , input) )
162164 public func applied( to input: Tensor < Scalar > , in _: Context ) -> Tensor < Scalar > {
163- return input. convolved2D ( withFilter: filter,
164- strides: ( 1 , strides. 0 , strides. 1 , 1 ) ,
165- padding: padding) + bias
165+ return activation ( input. convolved2D ( withFilter: filter,
166+ strides: ( 1 , strides. 0 , strides. 1 , 1 ) ,
167+ padding: padding) + bias)
166168 }
167169}
168170
@@ -171,6 +173,7 @@ public extension Conv2D where Scalar.RawSignificand: FixedWidthInteger {
171173 filterShape: ( Int , Int , Int , Int ) ,
172174 strides: ( Int , Int ) = ( 1 , 1 ) ,
173175 padding: Padding ,
176+ activation: @escaping Activation = identity,
174177 generator: inout G
175178 ) {
176179 let filterTensorShape = TensorShape ( [
@@ -179,15 +182,19 @@ public extension Conv2D where Scalar.RawSignificand: FixedWidthInteger {
179182 self . init (
180183 filter: Tensor ( glorotUniform: filterTensorShape) ,
181184 bias: Tensor ( zeros: TensorShape ( [ Int32 ( filterShape. 3 ) ] ) ) ,
182- strides: ( Int32 ( strides. 0 ) , Int32 ( strides. 1 ) ) , padding: padding)
185+ activation: activation,
186+ strides: ( Int32 ( strides. 0 ) , Int32 ( strides. 1 ) ) ,
187+ padding: padding)
183188 }
184189
185190 init (
186191 filterShape: ( Int , Int , Int , Int ) ,
187192 strides: ( Int , Int ) = ( 1 , 1 ) ,
188- padding: Padding
193+ padding: Padding ,
194+ activation: @escaping Activation = identity
189195 ) {
190196 self . init ( filterShape: filterShape, strides: strides, padding: padding,
197+ activation: activation,
191198 generator: & PhiloxRandomNumberGenerator. global)
192199 }
193200}
0 commit comments