Skip to content

Commit 01ecc07

Browse files
committed
Enable Array.subscript._modify differentiation
1 parent b9d8f0e commit 01ecc07

File tree

5 files changed

+30
-19
lines changed

5 files changed

+30
-19
lines changed

lib/Serialization/DeserializeSIL.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4805,7 +4805,8 @@ SILDeserializer::readDifferentiabilityWitness(DeclID DId) {
48054805
ArrayRef<unsigned>(parameterAndResultIndices)
48064806
.take_front(numParameterIndices));
48074807
auto numResults = originalFnType->getNumResults() +
4808-
originalFnType->getNumIndirectMutatingParameters();
4808+
originalFnType->getNumIndirectMutatingParameters() +
4809+
originalFnType->getNumYields();
48094810
auto *resultIndices =
48104811
IndexSubset::get(MF->getContext(), numResults,
48114812
ArrayRef<unsigned>(parameterAndResultIndices)

lib/Serialization/SerializeSIL.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3342,11 +3342,9 @@ void SILSerializer::writeSILDifferentiabilityWitness(
33423342
dw.getParameterIndices()->getCapacity() &&
33433343
"Original function parameter count should match differentiability "
33443344
"witness parameter indices capacity");
3345-
unsigned numInoutParameters = llvm::count_if(
3346-
originalFnType->getParameters(), [](SILParameterInfo paramInfo) {
3347-
return paramInfo.isIndirectMutating();
3348-
});
3349-
assert(originalFnType->getNumResults() + numInoutParameters ==
3345+
assert(originalFnType->getNumResults() +
3346+
originalFnType->getNumIndirectMutatingParameters() +
3347+
originalFnType->getNumYields() ==
33503348
dw.getResultIndices()->getCapacity() &&
33513349
"Original function result count should match differentiability "
33523350
"witness result indices capacity");

stdlib/public/Differentiation/ArrayDifferentiation.swift

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,10 +170,15 @@ where Element: AdditiveArithmetic & Differentiable {
170170

171171
@inlinable
172172
public subscript(_ index: Int) -> Element {
173-
if index < base.count {
174-
return base[index]
175-
} else {
176-
return Element.zero
173+
get {
174+
if index < base.count {
175+
return base[index]
176+
} else {
177+
return Element.zero
178+
}
179+
}
180+
_modify {
181+
yield &base[index]
177182
}
178183
}
179184
}
@@ -228,6 +233,21 @@ extension Array where Element: Differentiable {
228233
return (self[index], differential)
229234
}
230235

236+
@inlinable
237+
@derivative(of: subscript._modify)
238+
@yield_once
239+
mutating func _vjpModify(index: Int) -> (
240+
value: inout @yields Element, pullback: @yield_once (inout TangentVector) -> inout @yields Element.TangentVector
241+
) {
242+
yield &self[index]
243+
244+
@yield_once
245+
func pullback(_ v: inout TangentVector) -> inout @yields Element.TangentVector {
246+
yield &v[index]
247+
}
248+
return pullback
249+
}
250+
231251
@inlinable
232252
@derivative(of: +)
233253
static func _vjpConcatenate(_ lhs: Self, _ rhs: Self) -> (

test/AutoDiff/SILOptimizer/activity_analysis.swift

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -641,8 +641,6 @@ func testAccessorCoroutinesModify(_ x: HasCoroutineModifyAccessors) -> Float {
641641
func testBeginApplyActiveInoutArgument(array: [Float], x: Float) -> Float {
642642
var array = array
643643
// Array subscript assignment below calls `Array.subscript.modify`.
644-
// expected-error @+2 {{expression is not differentiable}}
645-
// expected-note @+1 {{cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files}}
646644
array[0] = x
647645
return array[0]
648646
}
@@ -678,8 +676,6 @@ func testBeginApplyActiveButInitiallyNonactiveInoutArgument(x: Float) -> Float {
678676
// `var array` is initially non-active.
679677
var array: [Float] = [0]
680678
// Array subscript assignment below calls `Array.subscript.modify`.
681-
// expected-error @+2 {{expression is not differentiable}}
682-
// expected-note @+1 {{cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files}}
683679
array[0] = x
684680
return array[0]
685681
}

test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -798,7 +798,7 @@ public func f_54819<Scalar: Differentiable>(
798798
#endif
799799

800800
//===----------------------------------------------------------------------===//
801-
// Coroutines (SIL function yields, `begin_apply`) (not yet supported)
801+
// Coroutines (SIL function yields, `begin_apply`)
802802
//===----------------------------------------------------------------------===//
803803

804804
struct HasReadAccessors: Differentiable {
@@ -838,8 +838,6 @@ func testModifyAccessorCoroutines(_ x: HasModifyAccessors) -> Float {
838838
func TF_1078(array: [Float], x: Float) -> Float {
839839
var array = array
840840
// Array subscript assignment below calls `Array.subscript.modify`.
841-
// expected-error @+2 {{expression is not differentiable}}
842-
// expected-note @+1 {{cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files}}
843841
array[0] = x
844842
return array[0]
845843
}
@@ -849,8 +847,6 @@ func TF_1078(array: [Float], x: Float) -> Float {
849847
func TF_1115(_ x: Float) -> Float {
850848
var array: [Float] = [0]
851849
// Array subscript assignment below calls `Array.subscript.modify`.
852-
// expected-error @+2 {{expression is not differentiable}}
853-
// expected-note @+1 {{cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files}}
854850
array[0] = x
855851
return array[0]
856852
}

0 commit comments

Comments
 (0)