Skip to content

Commit 0fd46d1

Browse files
committed
Enable Array.subscript._modify differentiation
1 parent f6eb061 commit 0fd46d1

File tree

5 files changed

+30
-19
lines changed

5 files changed

+30
-19
lines changed

lib/Serialization/DeserializeSIL.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5135,7 +5135,8 @@ SILDeserializer::readDifferentiabilityWitness(DeclID DId) {
51355135
ArrayRef<unsigned>(parameterAndResultIndices)
51365136
.take_front(numParameterIndices));
51375137
auto numResults = originalFnType->getNumResults() +
5138-
originalFnType->getNumIndirectMutatingParameters();
5138+
originalFnType->getNumIndirectMutatingParameters() +
5139+
originalFnType->getNumYields();
51395140
auto *resultIndices =
51405141
IndexSubset::get(MF->getContext(), numResults,
51415142
ArrayRef<unsigned>(parameterAndResultIndices)

lib/Serialization/SerializeSIL.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3547,11 +3547,9 @@ void SILSerializer::writeSILDifferentiabilityWitness(
35473547
dw.getParameterIndices()->getCapacity() &&
35483548
"Original function parameter count should match differentiability "
35493549
"witness parameter indices capacity");
3550-
unsigned numInoutParameters = llvm::count_if(
3551-
originalFnType->getParameters(), [](SILParameterInfo paramInfo) {
3552-
return paramInfo.isIndirectMutating();
3553-
});
3554-
assert(originalFnType->getNumResults() + numInoutParameters ==
3550+
assert(originalFnType->getNumResults() +
3551+
originalFnType->getNumIndirectMutatingParameters() +
3552+
originalFnType->getNumYields() ==
35553553
dw.getResultIndices()->getCapacity() &&
35563554
"Original function result count should match differentiability "
35573555
"witness result indices capacity");

stdlib/public/Differentiation/ArrayDifferentiation.swift

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,10 +170,15 @@ where Element: AdditiveArithmetic & Differentiable {
170170

171171
@inlinable
172172
public subscript(_ index: Int) -> Element {
173-
if index < base.count {
174-
return base[index]
175-
} else {
176-
return Element.zero
173+
get {
174+
if index < base.count {
175+
return base[index]
176+
} else {
177+
return Element.zero
178+
}
179+
}
180+
_modify {
181+
yield &base[index]
177182
}
178183
}
179184
}
@@ -228,6 +233,21 @@ extension Array where Element: Differentiable {
228233
return (self[index], differential)
229234
}
230235

236+
@inlinable
237+
@derivative(of: subscript._modify)
238+
@yield_once
239+
mutating func _vjpModify(index: Int) -> (
240+
value: inout @yields Element, pullback: @yield_once (inout TangentVector) -> inout @yields Element.TangentVector
241+
) {
242+
yield &self[index]
243+
244+
@yield_once
245+
func pullback(_ v: inout TangentVector) -> inout @yields Element.TangentVector {
246+
yield &v[index]
247+
}
248+
return pullback
249+
}
250+
231251
@inlinable
232252
@derivative(of: +)
233253
static func _vjpConcatenate(_ lhs: Self, _ rhs: Self) -> (

test/AutoDiff/SILOptimizer/activity_analysis.swift

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -641,8 +641,6 @@ func testAccessorCoroutinesModify(_ x: HasCoroutineModifyAccessors) -> Float {
641641
func testBeginApplyActiveInoutArgument(array: [Float], x: Float) -> Float {
642642
var array = array
643643
// Array subscript assignment below calls `Array.subscript.modify`.
644-
// expected-error @+2 {{expression is not differentiable}}
645-
// expected-note @+1 {{cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files}}
646644
array[0] = x
647645
return array[0]
648646
}
@@ -678,8 +676,6 @@ func testBeginApplyActiveButInitiallyNonactiveInoutArgument(x: Float) -> Float {
678676
// `var array` is initially non-active.
679677
var array: [Float] = [0]
680678
// Array subscript assignment below calls `Array.subscript.modify`.
681-
// expected-error @+2 {{expression is not differentiable}}
682-
// expected-note @+1 {{cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files}}
683679
array[0] = x
684680
return array[0]
685681
}

test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -779,7 +779,7 @@ public func fragileDifferentiable(_ x: Float) -> Float {
779779
}
780780

781781
//===----------------------------------------------------------------------===//
782-
// Coroutines (SIL function yields, `begin_apply`) (not yet supported)
782+
// Coroutines (SIL function yields, `begin_apply`)
783783
//===----------------------------------------------------------------------===//
784784

785785
struct HasReadAccessors: Differentiable {
@@ -819,8 +819,6 @@ func testModifyAccessorCoroutines(_ x: HasModifyAccessors) -> Float {
819819
func TF_1078(array: [Float], x: Float) -> Float {
820820
var array = array
821821
// Array subscript assignment below calls `Array.subscript.modify`.
822-
// expected-error @+2 {{expression is not differentiable}}
823-
// expected-note @+1 {{cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files}}
824822
array[0] = x
825823
return array[0]
826824
}
@@ -830,8 +828,6 @@ func TF_1078(array: [Float], x: Float) -> Float {
830828
func TF_1115(_ x: Float) -> Float {
831829
var array: [Float] = [0]
832830
// Array subscript assignment below calls `Array.subscript.modify`.
833-
// expected-error @+2 {{expression is not differentiable}}
834-
// expected-note @+1 {{cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files}}
835831
array[0] = x
836832
return array[0]
837833
}

0 commit comments

Comments
 (0)