Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 8947c1c

Browse files
committed
Fix test failure and simplify code.
* Bump up the number of epochs to `2000`. * Use `Layer.inferring(from:)` instead of `Layer.applied(to:)`.
1 parent 8c228db commit 8947c1c

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

Tests/DeepLearningTests/TrivialModelTests.swift

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,14 @@ final class TrivialModelTests: XCTestCase {
4646
let y: Tensor<Float> = [[0], [1], [1], [0]]
4747

4848
let trainingContext = Context(learningPhase: .training)
49-
for _ in 0..<1000 {
50-
let (_, 𝛁model) = classifier.valueWithGradient { classifier -> Tensor<Float> in
49+
for _ in 0..<2000 {
50+
let 𝛁model = classifier.gradient { classifier -> Tensor<Float> in
5151
let ŷ = classifier.applied(to: x, in: trainingContext)
5252
return meanSquaredError(predicted: ŷ, expected: y)
5353
}
5454
optimizer.update(&classifier.allDifferentiableVariables, along: 𝛁model)
5555
}
56-
57-
let inferenceContext = Context(learningPhase: .inference)
58-
let ŷ = classifier.applied(to: x, in: inferenceContext)
56+
let ŷ = classifier.inferring(from: x)
5957
XCTAssertEqual(round(ŷ), y)
6058
}
6159

0 commit comments

Comments
 (0)