Skip to content

Commit b4fb574

Browse files
authored
[AutoDiff] Make sure rhs is negated in all cases when subtracting two Array.TangentVectors (#84731)
`Array.TangentVector` conformance to `AdditiveArithmetic` was incorrect as the returned values weren't negated if the lhs was an empty vector (considered to be a zero tangentvector)
1 parent 1d8536e commit b4fb574

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

stdlib/public/Differentiation/ArrayDifferentiation.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ where Element: AdditiveArithmetic & Differentiable {
157157
rhs: Array.DifferentiableView
158158
) -> Array.DifferentiableView {
159159
if lhs.base.count == 0 {
160-
return rhs
160+
return Array.DifferentiableView(rhs.base.map { .zero - $0 })
161161
}
162162
if rhs.base.count == 0 {
163163
return lhs
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// RUN: %target-run-simple-swift
2+
// REQUIRES: executable_test
3+
4+
import _Differentiation
5+
import StdlibUnittest
6+
7+
var ArrayDifferentiationTests = TestSuite("ArrayDifferentiation")
8+
9+
ArrayDifferentiationTests.test("Array.DifferentiableView+") {
10+
let zero1: Array<Float>.DifferentiableView = [0, 0, 0]
11+
let zero2: Array<Float>.DifferentiableView = .zero
12+
let a: Array<Float>.DifferentiableView = [1, 2, 3]
13+
14+
expectEqual(a + a, [2, 4, 6])
15+
16+
expectEqual(a + zero1, [1, 2, 3])
17+
expectEqual(zero1 + a, [1, 2, 3])
18+
19+
expectEqual(a + zero2, [1, 2, 3])
20+
expectEqual(zero2 + a, [1, 2, 3])
21+
}
22+
23+
ArrayDifferentiationTests.test("Array.DifferentiableView-") {
24+
let zero1: Array<Float>.DifferentiableView = [0, 0, 0]
25+
let zero2: Array<Float>.DifferentiableView = .zero
26+
let a: Array<Float>.DifferentiableView = [1, 2, 3]
27+
28+
expectEqual(a - a, [0, 0, 0])
29+
30+
expectEqual(a - zero1, [1, 2, 3])
31+
expectEqual(zero1 - a, [-1, -2, -3])
32+
33+
expectEqual(a - zero2, [1, 2, 3])
34+
expectEqual(zero2 - a, [-1, -2, -3])
35+
}
36+
37+
runAllTests()

0 commit comments

Comments
 (0)