@@ -42,6 +42,14 @@ where Element: Differentiable {
4242 return ( base, { $0 } )
4343 }
4444
45+ @usableFromInline
46+ @derivative ( of: base)
47+ func _jvpBase( ) -> (
48+ value: [ Element ] , differential: ( Array < Element > . TangentVector ) -> TangentVector
49+ ) {
50+ return ( base, { $0 } )
51+ }
52+
4553 /// Creates a differentiable view of the given array.
4654 public init ( _ base: [ Element ] ) { self . _base = base }
4755
@@ -53,6 +61,14 @@ where Element: Differentiable {
5361 return ( Array . DifferentiableView ( base) , { $0 } )
5462 }
5563
64+ @usableFromInline
65+ @derivative ( of: init ( _: ) )
66+ static func _jvpInit( _ base: [ Element ] ) -> (
67+ value: Array . DifferentiableView , differential: ( TangentVector ) -> TangentVector
68+ ) {
69+ return ( Array . DifferentiableView ( base) , { $0 } )
70+ }
71+
5672 public typealias TangentVector =
5773 Array < Element . TangentVector > . DifferentiableView
5874
@@ -191,6 +207,17 @@ extension Array where Element: Differentiable {
191207 return ( self [ index] , pullback)
192208 }
193209
210+ @usableFromInline
211+ @derivative ( of: subscript)
212+ func _jvpSubscript( index: Int ) -> (
213+ value: Element , differential: ( TangentVector ) -> Element . TangentVector
214+ ) {
215+ func differential( _ v: TangentVector ) -> Element . TangentVector {
216+ return v [ index]
217+ }
218+ return ( self [ index] , differential)
219+ }
220+
194221 @usableFromInline
195222 @derivative ( of: + )
196223 static func _vjpConcatenate( _ lhs: Self , _ rhs: Self ) -> (
@@ -210,8 +237,26 @@ extension Array where Element: Differentiable {
210237 }
211238 return ( lhs + rhs, pullback)
212239 }
240+
241+ @usableFromInline
242+ @derivative ( of: + )
243+ static func _jvpConcatenate( _ lhs: Self , _ rhs: Self ) -> (
244+ value: Self ,
245+ differential: ( TangentVector , TangentVector ) -> TangentVector
246+ ) {
247+ func differential( _ l: TangentVector , _ r: TangentVector ) -> TangentVector {
248+ precondition (
249+ l. base. count == lhs. count && r. base. count == rhs. count, """
250+ Tangent vectors with invalid count; expected to equal the \
251+ operand counts \( lhs. count) and \( rhs. count)
252+ """ )
253+ return . init( l. base + r. base)
254+ }
255+ return ( lhs + rhs, differential)
256+ }
213257}
214258
259+
215260extension Array where Element: Differentiable {
216261 @usableFromInline
217262 @derivative ( of: append)
@@ -277,6 +322,17 @@ extension Array where Element: Differentiable {
277322 }
278323 )
279324 }
325+
326+ @usableFromInline
327+ @derivative ( of: init ( repeating: count: ) )
328+ static func _jvpInit( repeating repeatedValue: Element , count: Int ) -> (
329+ value: Self , differential: ( Element . TangentVector ) -> TangentVector
330+ ) {
331+ (
332+ value: Self ( repeating: repeatedValue, count: count) ,
333+ differential: { v in TangentVector ( . init( repeating: v, count: count) ) }
334+ )
335+ }
280336}
281337
282338//===----------------------------------------------------------------------===//
@@ -312,6 +368,27 @@ extension Array where Element: Differentiable {
312368 }
313369 return ( value: values, pullback: pullback)
314370 }
371+
372+ @inlinable
373+ @derivative ( of: differentiableMap)
374+ internal func _jvpDifferentiableMap< Result: Differentiable > (
375+ _ body: @differentiable ( Element ) -> Result
376+ ) -> (
377+ value: [ Result ] ,
378+ differential: ( Array . TangentVector ) -> Array < Result > . TangentVector
379+ ) {
380+ var values : [ Result ] = [ ]
381+ var differentials : [ ( Element . TangentVector ) -> Result . TangentVector ] = [ ]
382+ for x in self {
383+ let ( y, df) = valueWithDifferential ( at: x, in: body)
384+ values. append ( y)
385+ differentials. append ( df)
386+ }
387+ func differential( _ tans: Array . TangentVector ) -> Array < Result > . TangentVector {
388+ . init( zip ( tans. base, differentials) . map { tan, df in df ( tan) } )
389+ }
390+ return ( value: values, differential: differential)
391+ }
315392}
316393
317394extension Array where Element: Differentiable {
@@ -361,4 +438,33 @@ extension Array where Element: Differentiable {
361438 }
362439 )
363440 }
441+
442+ @inlinable
443+ @derivative ( of: differentiableReduce, wrt: ( self , initialResult) )
444+ func _jvpDifferentiableReduce< Result: Differentiable > (
445+ _ initialResult: Result ,
446+ _ nextPartialResult: @differentiable ( Result , Element ) -> Result
447+ ) -> ( value: Result ,
448+ differential: ( Array . TangentVector , Result . TangentVector )
449+ -> Result . TangentVector ) {
450+ var differentials :
451+ [ ( Result . TangentVector , Element . TangentVector ) -> Result . TangentVector ]
452+ = [ ]
453+ let count = self . count
454+ differentials. reserveCapacity ( count)
455+ var result = initialResult
456+ for element in self {
457+ let ( y, df) =
458+ valueWithDifferential ( at: result, element, in: nextPartialResult)
459+ result = y
460+ differentials. append ( df)
461+ }
462+ return ( value: result, differential: { dSelf, dInitial in
463+ var dResult = dInitial
464+ for (dElement, df) in zip ( dSelf. base, differentials) {
465+ dResult = df ( dResult, dElement)
466+ }
467+ return dResult
468+ } )
469+ }
364470}
0 commit comments