Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit b2f8ec3

Browse files
committed
Merge branch 'master' of github.com:tensorflow/swift-apis
2 parents 8947c1c + ee31b7f commit b2f8ec3

File tree

4 files changed

+148
-2
lines changed

4 files changed

+148
-2
lines changed

Sources/DeepLearning/Layer.swift

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,80 @@ public extension Layer {
102102
}
103103
}
104104

105+
/// Adds helpers for standard feed-forward, sequential models.
public extension Differentiable {
    /// Returns the output of applying the layers `l1` and `l2`, in sequence,
    /// to `self` in the given context.
    ///
    /// - Parameters:
    ///   - context: The context in which the layers are applied (e.g. the
    ///     learning phase).
    ///   - l1: The first layer; its input type must equal `Self`.
    ///   - l2: The second layer; its input type must equal `l1`'s output type.
    /// - Returns: The final layer's output.
    @differentiable(wrt: (self, l1, l2))
    func sequenced<L1: Layer, L2: Layer>(
        in context: Context, through l1: L1, _ l2: L2)
        -> L2.Output
        where L1.Input == Self,
              L1.Output == L2.Input {
        let o1 = l1.applied(to: self, in: context)
        return l2.applied(to: o1, in: context)
    }

    /// Returns the output of applying the layers `l1`, `l2`, and `l3`, in
    /// sequence, to `self` in the given context.
    ///
    /// - Parameters:
    ///   - context: The context in which the layers are applied.
    ///   - l1: The first layer; its input type must equal `Self`.
    ///   - l2: The second layer.
    ///   - l3: The third layer. Each layer's input type must equal the
    ///     previous layer's output type.
    /// - Returns: The final layer's output.
    @differentiable(wrt: (self, l1, l2, l3))
    func sequenced<L1: Layer, L2: Layer, L3: Layer>(
        in context: Context, through l1: L1, _ l2: L2, _ l3: L3)
        -> L3.Output
        where L1.Input == Self,
              L1.Output == L2.Input,
              L2.Output == L3.Input {
        let o1 = l1.applied(to: self, in: context)
        let o2 = l2.applied(to: o1, in: context)
        return l3.applied(to: o2, in: context)
    }

    /// Returns the output of applying the layers `l1` through `l4`, in
    /// sequence, to `self` in the given context.
    ///
    /// - Parameters:
    ///   - context: The context in which the layers are applied.
    ///   - l1: The first layer; its input type must equal `Self`.
    ///   - l2: The second layer.
    ///   - l3: The third layer.
    ///   - l4: The fourth layer. Each layer's input type must equal the
    ///     previous layer's output type.
    /// - Returns: The final layer's output.
    @differentiable(wrt: (self, l1, l2, l3, l4))
    func sequenced<L1: Layer, L2: Layer, L3: Layer, L4: Layer>(
        in context: Context, through l1: L1, _ l2: L2, _ l3: L3, _ l4: L4)
        -> L4.Output
        where L1.Input == Self,
              L1.Output == L2.Input,
              L2.Output == L3.Input,
              L3.Output == L4.Input {
        let o1 = l1.applied(to: self, in: context)
        let o2 = l2.applied(to: o1, in: context)
        let o3 = l3.applied(to: o2, in: context)
        return l4.applied(to: o3, in: context)
    }

    /// Returns the output of applying the layers `l1` through `l5`, in
    /// sequence, to `self` in the given context.
    ///
    /// - Parameters:
    ///   - context: The context in which the layers are applied.
    ///   - l1: The first layer; its input type must equal `Self`.
    ///   - l2: The second layer.
    ///   - l3: The third layer.
    ///   - l4: The fourth layer.
    ///   - l5: The fifth layer. Each layer's input type must equal the
    ///     previous layer's output type.
    /// - Returns: The final layer's output.
    @differentiable(wrt: (self, l1, l2, l3, l4, l5))
    func sequenced<L1: Layer, L2: Layer, L3: Layer, L4: Layer, L5: Layer>(
        in context: Context, through l1: L1, _ l2: L2, _ l3: L3, _ l4: L4, _ l5: L5)
        -> L5.Output
        where L1.Input == Self,
              L1.Output == L2.Input,
              L2.Output == L3.Input,
              L3.Output == L4.Input,
              L4.Output == L5.Input {
        let o1 = l1.applied(to: self, in: context)
        let o2 = l2.applied(to: o1, in: context)
        let o3 = l3.applied(to: o2, in: context)
        let o4 = l4.applied(to: o3, in: context)
        return l5.applied(to: o4, in: context)
    }

    /// Returns the output of applying the layers `l1` through `l6`, in
    /// sequence, to `self` in the given context.
    ///
    /// - Parameters:
    ///   - context: The context in which the layers are applied.
    ///   - l1: The first layer; its input type must equal `Self`.
    ///   - l2: The second layer.
    ///   - l3: The third layer.
    ///   - l4: The fourth layer.
    ///   - l5: The fifth layer.
    ///   - l6: The sixth layer. Each layer's input type must equal the
    ///     previous layer's output type.
    /// - Returns: The final layer's output.
    @differentiable(wrt: (self, l1, l2, l3, l4, l5, l6))
    func sequenced<L1: Layer, L2: Layer, L3: Layer, L4: Layer, L5: Layer, L6: Layer>(
        in context: Context, through l1: L1, _ l2: L2, _ l3: L3, _ l4: L4, _ l5: L5, _ l6: L6)
        -> L6.Output
        where L1.Input == Self,
              L1.Output == L2.Input,
              L2.Output == L3.Input,
              L3.Output == L4.Input,
              L4.Output == L5.Input,
              L5.Output == L6.Input {
        let o1 = l1.applied(to: self, in: context)
        let o2 = l2.applied(to: o1, in: context)
        let o3 = l3.applied(to: o2, in: context)
        let o4 = l4.applied(to: o3, in: context)
        let o5 = l5.applied(to: o4, in: context)
        return l6.applied(to: o5, in: context)
    }
}
177+
178+
105179
/// A mutable, shareable, owning reference to a tensor.
106180
public final class Parameter<Scalar: TensorFlowScalar> {
107181
public var value: Tensor<Scalar>
@@ -137,7 +211,7 @@ public extension Dense where Scalar.RawSignificand: FixedWidthInteger {
137211
init<G: RandomNumberGenerator>(
138212
inputSize: Int,
139213
outputSize: Int,
140-
activation: @escaping Activation,
214+
activation: @escaping Activation = identity,
141215
generator: inout G
142216
) {
143217
self.init(weight: Tensor(glorotUniform: [Int32(inputSize), Int32(outputSize)],
@@ -146,7 +220,7 @@ public extension Dense where Scalar.RawSignificand: FixedWidthInteger {
146220
activation: activation)
147221
}
148222

149-
/// Creates a `Dense` layer with the given input and output sizes, using the
/// shared global Philox random number generator for weight initialization.
///
/// - Parameters:
///   - inputSize: The number of input features.
///   - outputSize: The number of output features.
///   - activation: The elementwise activation function to apply; defaults to
///     `identity` (no activation).
init(inputSize: Int, outputSize: Int, activation: @escaping Activation = identity) {
    self.init(inputSize: inputSize, outputSize: outputSize, activation: activation,
              generator: &PhiloxRandomNumberGenerator.global)
}

Sources/DeepLearning/Loss.swift

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,40 @@
1616
import TensorFlow
1717
#endif
1818

19+
/// Computes the mean squared error between predictions and expected values.
///
/// - Parameters:
///   - predicted: Predicted outputs from a neural network.
///   - expected: Expected (target) values that correspond to the correct output.
/// - Returns: The mean of the elementwise squared differences, as a scalar tensor.
@differentiable
public func meanSquaredError<Scalar: TensorFlowFloatingPoint>(
    predicted: Tensor<Scalar>, expected: Tensor<Scalar>
) -> Tensor<Scalar> {
    return (expected - predicted).squared().mean()
}
2530

31+
/// Computes the softmax cross entropy (categorical cross entropy) between logits and labels.
///
/// - Parameters:
///   - logits: Unnormalized log-probabilities (raw outputs of a neural network,
///     before softmax).
///   - labels: One-hot encoded values that correspond to the correct output.
/// - Returns: The loss, averaged over the batch (axis 0) and summed over the
///   remaining axes.
@differentiable
public func softmaxCrossEntropy<Scalar: TensorFlowFloatingPoint>(
    logits: Tensor<Scalar>, labels: Tensor<Scalar>
) -> Tensor<Scalar> {
    return -(labels * logSoftmax(logits)).mean(alongAxes: 0).sum()
}
42+
43+
/// Computes the binary cross entropy between predicted probabilities and labels.
///
/// - Note(review): Despite the parameter name, this implementation takes
///   `log(logits)` directly, so `logits` must already be probabilities strictly
///   between `0` and `1` (e.g. sigmoid outputs). Values at or outside `(0, 1)`
///   produce non-finite results. If callers pass raw, unbounded logits, this
///   should be rewritten using the numerically stable formulation
///   `max(x, 0) - x * z + log(1 + exp(-|x|))` — confirm intended input domain
///   before changing.
///
/// - Parameters:
///   - logits: Predicted probabilities, continuous values strictly between `0` and `1`.
///   - labels: Binary values (`0` or `1`) that correspond to the correct output.
/// - Returns: The loss, averaged over the batch (axis 0) and summed over the
///   remaining axes.
@differentiable
public func sigmoidCrossEntropy<Scalar: TensorFlowFloatingPoint>(
    logits: Tensor<Scalar>, labels: Tensor<Scalar>
) -> Tensor<Scalar> {
    let loss = labels * log(logits) +
        (Tensor<Scalar>(1) - labels) * log(Tensor<Scalar>(1) - logits)
    return -loss.mean(alongAxes: 0).sum()
}
Tests/DeepLearningTests/SequentialTests.swift

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import XCTest
16+
@testable import DeepLearning
17+
18+
/// Tests for the `sequenced(in:through:)` feed-forward helpers.
final class SequentialTests: XCTestCase {
    /// Trains a two-layer model on the XOR dataset using `sequenced(in:through:)`
    /// and checks that inference produces an output of the expected shape.
    func testSequential() {
        struct Model: Layer {
            var dense1 = Dense<Float>(inputSize: 2, outputSize: 4, activation: relu)
            var dense2 = Dense<Float>(inputSize: 4, outputSize: 1, activation: relu)

            @differentiable(wrt: (self, input))
            func applied(to input: Tensor<Float>, in context: Context) -> Tensor<Float> {
                // The helper under test: chains dense1 then dense2.
                return input.sequenced(in: context, through: dense1, dense2)
            }
        }
        var model = Model()
        let optimizer = SGD(learningRate: 0.02, modelType: type(of: model), scalarType: Float.self)
        // XOR truth table.
        let x: Tensor<Float> = [[0, 0], [0, 1], [1, 0], [1, 1]]
        let y: Tensor<Float> = [0, 1, 1, 0]
        let context = Context(learningPhase: .training)
        for _ in 0..<1000 {
            let 𝛁model = model.gradient { model -> Tensor<Float> in
                let ŷ = model.applied(to: x, in: context)
                return meanSquaredError(predicted: ŷ, expected: y)
            }
            optimizer.update(&model.allDifferentiableVariables, along: 𝛁model)
        }
        // The original test only printed the predictions; assert at least that
        // inference runs and yields one output per example.
        let predictions = model.inferring(from: x)
        XCTAssertEqual(predictions.shape, TensorShape([4, 1]))
    }

    static var allTests = [
        ("testSequential", testSequential)
    ]
}

Tests/DeepLearningTests/XCTestManifests.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
/// Returns the full list of test cases for non-ObjC platforms.
public func allTests() -> [XCTestCaseEntry] {
    let cases: [XCTestCaseEntry] = [
        testCase(PRNGTests.allTests),
        testCase(TrivialModelTests.allTests),
        testCase(SequentialTests.allTests),
    ]
    return cases
}
2425
#endif

0 commit comments

Comments
 (0)