This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 3e8d86e

Reduce explicit types required when instantiating the SGD optimizer. (#28)
Previously, in order to instantiate an optimizer, you had to call `type(of: model)` and pass the result into the optimizer's initializer so that type inference could pick the right type for `Model`. This could be confusing for new users. This commit proposes an alternate way to write this:

```swift
let optimizer = SGD(for: model, learningRate: 0.01, scalarType: Float.self)
```

This formulation is clear and readable, and it avoids spelling out any generic argument types. By annotating the model parameter as `__shared`, we ensure that we don't pay the cost of a model copy (which could eventually become very expensive).
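As a minimal, standalone illustration of the inference trick the commit relies on (`Box` and `MyModel` are hypothetical names, not code from this repository): passing a value binds a generic parameter that a defaulted `Model.Type` argument never could, so the caller spells out no generics at all.

```swift
// A generic type whose parameter is bound by a value the caller already has.
struct Box<T> {
    // The argument is used only for its type; it is never stored.
    init(for _: T) {}
}

struct MyModel {}

let model = MyModel()
let box = Box(for: model)   // T is inferred as MyModel; no `Box<MyModel>` spelling needed.
```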
1 parent: 45c2bcd

3 files changed (+9 lines, −9 lines)

Sources/DeepLearning/Optimizer.swift

Lines changed: 7 additions & 7 deletions
```diff
@@ -35,13 +35,13 @@ public class Adam<Model: Layer, Scalar: TensorFlowFloatingPoint>: Optimizer
     public let decay: Scalar
 
     public init(
+        for _: __shared Model,
         learningRate: Scalar = 1e-3,
         beta1: Scalar = 0.9,
         beta2: Scalar = 0.999,
         epsilon: Scalar = 1e-8,
         decay: Scalar = 0,
-        modelType: Model.Type = Model.self,
-        scalarType: Scalar.Type = Scalar.self
+        scalarType: Scalar.Type
     ) {
         precondition(learningRate >= 0, "Learning rate must be non-negative")
         precondition(0 <= beta1 && beta1 <= 1, "Beta parameter must be between 0 and 1")
@@ -84,12 +84,12 @@ public class RMSProp<Model: Layer, Scalar: TensorFlowFloatingPoint>: Optimizer
     public let decay: Scalar
 
     public init(
+        for _: __shared Model,
         learningRate: Scalar = 0.001,
         rho: Scalar = 0.9,
         epsilon: Scalar = 1e-8,
         decay: Scalar = 0,
-        modelType: Model.Type = Model.self,
-        scalarType: Scalar.Type = Scalar.self
+        scalarType: Scalar.Type
     ) {
         precondition(learningRate >= 0, "Learning rate must be non-negative")
         precondition(rho >= 0, "Rho must be non-negative")
@@ -125,12 +125,12 @@ public class SGD<Model: Layer, Scalar: TensorFlowFloatingPoint>: Optimizer
     public let nesterov: Bool
 
     public init(
+        for _: __shared Model,
         learningRate: Scalar = 0.01,
         momentum: Scalar = 0,
         decay: Scalar = 0,
         nesterov: Bool = false,
-        modelType: Model.Type = Model.self,
-        scalarType: Scalar.Type = Scalar.self
+        scalarType: Scalar.Type
     ) {
         precondition(learningRate >= 0, "Learning rate must be non-negative")
         precondition(momentum >= 0, "Momentum must be non-negative")
@@ -171,7 +171,7 @@ public class RiemannSGD<Model: Layer, Scalar: FloatingPoint>: Optimizer
     public init(
         learningRate: Scalar,
         modelType: Model.Type = Model.self,
-        scalarType: Scalar.Type = Scalar.self
+        scalarType: Scalar.Type
     ) {
         self.learningRate = learningRate
     }
```
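For readers skimming the diff, here is a self-contained sketch of the resulting initializer shape. `ToyOptimizer` and `TinyModel` are hypothetical stand-ins, not the library's actual types; the point is how the `__shared` model parameter and the non-defaulted `scalarType` interact.

```swift
// Hypothetical mirror of the pattern applied above; not the library's real class.
class ToyOptimizer<Model, Scalar: BinaryFloatingPoint> {
    let learningRate: Scalar

    init(
        for _: __shared Model,      // borrowed: binds `Model` without copying the value
        learningRate: Scalar = 0.01,
        scalarType: Scalar.Type     // no default, so the call site must anchor `Scalar`
    ) {
        precondition(learningRate >= 0, "Learning rate must be non-negative")
        self.learningRate = learningRate
    }
}

struct TinyModel {}

let optimizer = ToyOptimizer(for: TinyModel(), scalarType: Float.self)
// `Model` == TinyModel and `Scalar` == Float, with no generic arguments spelled out.
```

Dropping the `= Scalar.self` default is what forces the call site to pass `Float.self`; with the default in place, nothing anchored `Scalar`, which is why the old API needed explicit generic arguments or `type(of:)`.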

Tests/DeepLearningTests/SequentialTests.swift

Lines changed: 1 addition & 1 deletion
```diff
@@ -27,7 +27,7 @@ final class SequentialTests: XCTestCase {
             }
         }
         var model = Model()
-        let optimizer = SGD(learningRate: 0.02, modelType: type(of: model), scalarType: Float.self)
+        let optimizer = SGD(for: model, learningRate: 0.02, scalarType: Float.self)
         let x: Tensor<Float> = [[0, 0], [0, 1], [1, 0], [1, 1]]
         let y: Tensor<Float> = [0, 1, 1, 0]
         let context = Context(learningPhase: .training)
```
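Note that the updated call site spells out no generic arguments at all: `model` binds `Model` by value and `Float.self` binds `Scalar`, which is exactly the readability win described in the commit message.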

Tests/DeepLearningTests/TrivialModelTests.swift

Lines changed: 1 addition & 1 deletion
```diff
@@ -40,8 +40,8 @@ final class TrivialModelTests: XCTestCase {
                 return l2.applied(to: h1, in: context)
             }
         }
-        let optimizer = SGD<Classifier, Float>(learningRate: 0.02)
         var classifier = Classifier(hiddenSize: 4)
+        let optimizer = SGD(for: classifier, learningRate: 0.02, scalarType: Float.self)
         let x: Tensor<Float> = [[0, 0], [0, 1], [1, 0], [1, 1]]
         let y: Tensor<Float> = [[0], [1], [1], [0]]
```
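Also note the reordering here: because the new initializer takes the model value rather than its type, `classifier` must exist before the optimizer can be constructed, so the two declarations swap places.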
