Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit b9a05df

Browse files
authored
Add default padding values. (#21)
In the interest of progressive disclosure of complexity, and good default values, set the padding on the layer types to be `.valid`. This follows the Keras convention as well (https://keras.io/layers/convolutional/).
1 parent 523781c commit b9a05df

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

Sources/DeepLearning/Layer.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ public extension Conv2D where Scalar.RawSignificand: FixedWidthInteger {
172172
init<G: RandomNumberGenerator>(
173173
filterShape: (Int, Int, Int, Int),
174174
strides: (Int, Int) = (1, 1),
175-
padding: Padding,
175+
padding: Padding = .valid,
176176
activation: @escaping Activation = identity,
177177
generator: inout G
178178
) {
@@ -190,7 +190,7 @@ public extension Conv2D where Scalar.RawSignificand: FixedWidthInteger {
190190
init(
191191
filterShape: (Int, Int, Int, Int),
192192
strides: (Int, Int) = (1, 1),
193-
padding: Padding,
193+
padding: Padding = .valid,
194194
activation: @escaping Activation = identity
195195
) {
196196
self.init(filterShape: filterShape, strides: strides, padding: padding,
@@ -284,7 +284,7 @@ public struct MaxPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
284284
/// The padding algorithm for pooling.
285285
@noDerivative let padding: Padding
286286

287-
public init(poolSize: (Int, Int), strides: (Int, Int), padding: Padding) {
287+
public init(poolSize: (Int, Int), strides: (Int, Int), padding: Padding = .valid) {
288288
self.poolSize = (1, Int32(poolSize.0), Int32(poolSize.1), 1)
289289
self.strides = (1, Int32(strides.0), Int32(strides.1), 1)
290290
self.padding = padding
@@ -307,7 +307,7 @@ public struct AvgPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
307307
/// The padding algorithm for pooling.
308308
@noDerivative let padding: Padding
309309

310-
public init(poolSize: (Int, Int), strides: (Int, Int), padding: Padding) {
310+
public init(poolSize: (Int, Int), strides: (Int, Int), padding: Padding = .valid) {
311311
self.poolSize = (1, Int32(poolSize.0), Int32(poolSize.1), 1)
312312
self.strides = (1, Int32(strides.0), Int32(strides.1), 1)
313313
self.padding = padding

0 commit comments

Comments
 (0)