
Commit 9b3b609

Add a helper function for sequential models. (#20)
Many deep learning models are composed of sequential layers stacked one on top of the other. It can be tedious to write out the explicit `applied(to:)` function for such models, because it's repetitive and it obscures the underlying intent. It's especially bothersome because it's the second (or third) time you're writing out all the layers: the first time is declaring the instance variables, and the second (if necessary) is in the initializer. Fortunately, with helper functions, we can make everything type safe as well as convenient, easily expressible, and readable!

This commit adds a family of `sequenced(in:through:)` functions that take in a context, an input, and a variable number of layers, and chain the output of each layer into the input of the next. (A usage sketch is shown below.)

This API approach has a number of advantages:

1. It avoids introducing new symbolic operators, which can be very confusing to new users.
2. It works with today's AutoDiff implementation. (Yay!)
3. It is very readable and clean.
4. It avoids users "getting stuck". Concretely, if someone implemented a model using my previously proposed `>>>` operator and then wanted to add a residual (or skip) connection, they would have to basically re-write their whole model using a struct, etc. With this API structure, only "local" changes are required. (e.g. If only one skip-connection is required, they can split the sequential chain into two pieces.)

Downsides of this approach:

1. It doesn't DRY-out the types required to define a model. (I have some thoughts here, but there isn't enough room in this margin^H^H^H^H^H^Hcommit message.)
2. We should think hard about how things should look when we have loops.
3. We should switch to gyb to generate the code for all the different arities.
1 parent 6add73c commit 9b3b609
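
As a quick illustration of what the helper buys, here is a sketch of a three-layer model. The `MLP` model and its layer sizes are hypothetical, not part of this commit; the commented-out lines show the manual chain that the single `sequenced(in:through:)` call replaces:

struct MLP: Layer {
    var dense1 = Dense<Float>(inputSize: 4, outputSize: 8, activation: relu)
    var dense2 = Dense<Float>(inputSize: 8, outputSize: 8, activation: relu)
    var dense3 = Dense<Float>(inputSize: 8, outputSize: 1, activation: relu)

    @differentiable(wrt: (self, input))
    func applied(to input: Tensor<Float>, in context: Context) -> Tensor<Float> {
        // The manual chain this helper replaces:
        //   let o1 = dense1.applied(to: input, in: context)
        //   let o2 = dense2.applied(to: o1, in: context)
        //   return dense3.applied(to: o2, in: context)
        return input.sequenced(in: context, through: dense1, dense2, dense3)
    }
}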

File tree (3 files changed: +122, -0 lines)

Sources/DeepLearning/Layer.swift
Tests/DeepLearningTests/SequentialTests.swift
Tests/DeepLearningTests/XCTestManifests.swift

Sources/DeepLearning/Layer.swift

Lines changed: 74 additions & 0 deletions
@@ -102,6 +102,80 @@ public extension Layer {
     }
 }
 
+/// Adds helpers for standard feed-forward, sequential models.
+public extension Differentiable {
+    @differentiable(wrt: (self, l1, l2))
+    func sequenced<L1: Layer, L2: Layer>(
+        in context: Context, through l1: L1, _ l2: L2)
+        -> L2.Output
+        where L1.Input == Self,
+              L1.Output == L2.Input {
+        let o1 = l1.applied(to: self, in: context)
+        return l2.applied(to: o1, in: context)
+    }
+
+    @differentiable(wrt: (self, l1, l2, l3))
+    func sequenced<L1: Layer, L2: Layer, L3: Layer>(
+        in context: Context, through l1: L1, _ l2: L2, _ l3: L3)
+        -> L3.Output
+        where L1.Input == Self,
+              L1.Output == L2.Input,
+              L2.Output == L3.Input {
+        let o1 = l1.applied(to: self, in: context)
+        let o2 = l2.applied(to: o1, in: context)
+        return l3.applied(to: o2, in: context)
+    }
+
+    @differentiable(wrt: (self, l1, l2, l3, l4))
+    func sequenced<L1: Layer, L2: Layer, L3: Layer, L4: Layer>(
+        in context: Context, through l1: L1, _ l2: L2, _ l3: L3, _ l4: L4)
+        -> L4.Output
+        where L1.Input == Self,
+              L1.Output == L2.Input,
+              L2.Output == L3.Input,
+              L3.Output == L4.Input {
+        let o1 = l1.applied(to: self, in: context)
+        let o2 = l2.applied(to: o1, in: context)
+        let o3 = l3.applied(to: o2, in: context)
+        return l4.applied(to: o3, in: context)
+    }
+
+    @differentiable(wrt: (self, l1, l2, l3, l4, l5))
+    func sequenced<L1: Layer, L2: Layer, L3: Layer, L4: Layer, L5: Layer>(
+        in context: Context, through l1: L1, _ l2: L2, _ l3: L3, _ l4: L4, _ l5: L5)
+        -> L5.Output
+        where L1.Input == Self,
+              L1.Output == L2.Input,
+              L2.Output == L3.Input,
+              L3.Output == L4.Input,
+              L4.Output == L5.Input {
+        let o1 = l1.applied(to: self, in: context)
+        let o2 = l2.applied(to: o1, in: context)
+        let o3 = l3.applied(to: o2, in: context)
+        let o4 = l4.applied(to: o3, in: context)
+        return l5.applied(to: o4, in: context)
+    }
+
+    @differentiable(wrt: (self, l1, l2, l3, l4, l5, l6))
+    func sequenced<L1: Layer, L2: Layer, L3: Layer, L4: Layer, L5: Layer, L6: Layer>(
+        in context: Context, through l1: L1, _ l2: L2, _ l3: L3, _ l4: L4, _ l5: L5, _ l6: L6)
+        -> L6.Output
+        where L1.Input == Self,
+              L1.Output == L2.Input,
+              L2.Output == L3.Input,
+              L3.Output == L4.Input,
+              L4.Output == L5.Input,
+              L5.Output == L6.Input {
+        let o1 = l1.applied(to: self, in: context)
+        let o2 = l2.applied(to: o1, in: context)
+        let o3 = l3.applied(to: o2, in: context)
+        let o4 = l4.applied(to: o3, in: context)
+        let o5 = l5.applied(to: o4, in: context)
+        return l6.applied(to: o5, in: context)
+    }
+}
+
+
 /// A mutable, shareable, owning reference to a tensor.
 public final class Parameter<Scalar: TensorFlowScalar> {
     public var value: Tensor<Scalar>
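
To make advantage 4 of the commit message concrete, here is a sketch of a residual (skip) connection expressed by splitting one sequential chain into two `sequenced` calls. The `ResidualBlock` model and its layer sizes are hypothetical, not part of this commit; only the shape of the change matters:

struct ResidualBlock: Layer {
    var dense1 = Dense<Float>(inputSize: 8, outputSize: 8, activation: relu)
    var dense2 = Dense<Float>(inputSize: 8, outputSize: 8, activation: relu)
    var dense3 = Dense<Float>(inputSize: 8, outputSize: 8, activation: relu)
    var dense4 = Dense<Float>(inputSize: 8, outputSize: 8, activation: relu)

    @differentiable(wrt: (self, input))
    func applied(to input: Tensor<Float>, in context: Context) -> Tensor<Float> {
        // First piece of the chain.
        let hidden = input.sequenced(in: context, through: dense1, dense2)
        // The skip connection is a purely "local" change; the second
        // piece of the chain continues from the sum.
        return (hidden + input).sequenced(in: context, through: dense3, dense4)
    }
}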
Tests/DeepLearningTests/SequentialTests.swift

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import XCTest
+@testable import DeepLearning
+
+final class SequentialTests: XCTestCase {
+    func testSequential() {
+        struct Model: Layer {
+            var dense1 = Dense<Float>(inputSize: 2, outputSize: 4, activation: relu)
+            var dense2 = Dense<Float>(inputSize: 4, outputSize: 1, activation: relu)
+
+            @differentiable(wrt: (self, input))
+            func applied(to input: Tensor<Float>, in context: Context) -> Tensor<Float> {
+                return input.sequenced(in: context, through: dense1, dense2)
+            }
+        }
+        var model = Model()
+        let optimizer = SGD(learningRate: 0.02, modelType: type(of: model), scalarType: Float.self)
+        let x: Tensor<Float> = [[0, 0], [0, 1], [1, 0], [1, 1]]
+        let y: Tensor<Float> = [0, 1, 1, 0]
+        let context = Context(learningPhase: .training)
+        for _ in 0..<1000 {
+            let 𝛁model = model.gradient { model -> Tensor<Float> in
+                let ŷ = model.applied(to: x, in: context)
+                return meanSquaredError(predicted: ŷ, expected: y)
+            }
+            optimizer.update(&model.allDifferentiableVariables, along: 𝛁model)
+        }
+        print(model.inferring(from: [[0, 0], [0, 1], [1, 0], [1, 1]]))
+    }
+
+    static var allTests = [
+        ("testSequential", testSequential)
+    ]
+}

Tests/DeepLearningTests/XCTestManifests.swift

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ public func allTests() -> [XCTestCaseEntry] {
     return [
         testCase(PRNGTests.allTests),
         testCase(TrivialModelTests.allTests),
+        testCase(SequentialTests.allTests),
     ]
 }
 #endif
