Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 9679805

Browse files
committed
Add activation function to Conv2D
1 parent 8b922f5 commit 9679805

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

Sources/DeepLearning/Layer.swift

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -155,14 +155,16 @@ public extension Dense where Scalar.RawSignificand: FixedWidthInteger {
155155
public struct Conv2D<Scalar: TensorFlowFloatingPoint>: Layer {
156156
public var filter: Tensor<Scalar>
157157
public var bias: Tensor<Scalar>
158+
public typealias Activation = @differentiable (Tensor<Scalar>) -> Tensor<Scalar>
159+
@noDerivative public let activation: Activation
158160
@noDerivative public let strides: (Int32, Int32)
159161
@noDerivative public let padding: Padding
160162

161163
@differentiable(wrt: (self, input))
162164
public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
163-
return input.convolved2D(withFilter: filter,
164-
strides: (1, strides.0, strides.1, 1),
165-
padding: padding) + bias
165+
return activation(input.convolved2D(withFilter: filter,
166+
strides: (1, strides.0, strides.1, 1),
167+
padding: padding) + bias)
166168
}
167169
}
168170

@@ -171,6 +173,7 @@ public extension Conv2D where Scalar.RawSignificand: FixedWidthInteger {
171173
filterShape: (Int, Int, Int, Int),
172174
strides: (Int, Int) = (1, 1),
173175
padding: Padding,
176+
activation: @escaping Activation = identity,
174177
generator: inout G
175178
) {
176179
let filterTensorShape = TensorShape([
@@ -179,15 +182,19 @@ public extension Conv2D where Scalar.RawSignificand: FixedWidthInteger {
179182
self.init(
180183
filter: Tensor(glorotUniform: filterTensorShape),
181184
bias: Tensor(zeros: TensorShape([Int32(filterShape.3)])),
182-
strides: (Int32(strides.0), Int32(strides.1)), padding: padding)
185+
activation: activation,
186+
strides: (Int32(strides.0), Int32(strides.1)),
187+
padding: padding)
183188
}
184189

185190
init(
186191
filterShape: (Int, Int, Int, Int),
187192
strides: (Int, Int) = (1, 1),
188-
padding: Padding
193+
padding: Padding,
194+
activation: @escaping Activation = identity
189195
) {
190196
self.init(filterShape: filterShape, strides: strides, padding: padding,
197+
activation: activation,
191198
generator: &PhiloxRandomNumberGenerator.global)
192199
}
193200
}

Sources/DeepLearning/Operators.swift

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@ import TensorFlow
1717
#endif
1818

1919
// Rounds the values of a tensor to the nearest integer, element-wise.
20-
func round<Scalar: BinaryFloatingPoint>(_ x: Tensor<Scalar>) -> Tensor<Scalar> {
20+
public func round<Scalar: BinaryFloatingPoint>(_ x: Tensor<Scalar>) -> Tensor<Scalar> {
2121
return Raw.round(x)
2222
}
2323

2424
// Return a tensor with the same shape and contents as input.
25-
func identity<Scalar>(_ x: Tensor<Scalar>) -> Tensor<Scalar> {
25+
@differentiable
26+
public func identity<Scalar>(_ x: Tensor<Scalar>) -> Tensor<Scalar> {
2627
return x
2728
}

0 commit comments

Comments
 (0)