File tree Expand file tree Collapse file tree 4 files changed +34
-0
lines changed
stdlib/public/Differentiation Expand file tree Collapse file tree 4 files changed +34
-0
lines changed Original file line number Diff line number Diff line change @@ -110,6 +110,12 @@ where Element: Differentiable {
110110 }
111111}
112112
113+ extension Array . DifferentiableView : CustomReflectable {
114+ public var customMirror : Mirror {
115+ return base. customMirror
116+ }
117+ }
118+
113119/// Makes `Array.DifferentiableView` additive as the product space.
114120///
115121/// Note that `Array.DifferentiableView([])` is the zero in the product spaces
Original file line number Diff line number Diff line change @@ -57,3 +57,9 @@ extension Optional: Differentiable where Wrapped: Differentiable {
5757 }
5858 }
5959}
60+
61+ extension Optional . TangentVector : CustomReflectable {
62+ public var customMirror : Mirror {
63+ return value. customMirror
64+ }
65+ }
Original file line number Diff line number Diff line change @@ -76,4 +76,15 @@ OptionalDifferentiationTests.test("Optional.TangentVector operations") {
7676 }
7777}
7878
79+ OptionalDifferentiationTests . test ( " Optional.TangentVector reflection " ) {
80+ let tan = Optional< Float> . TangentVector( 42 )
81+ let children = Array ( Mirror ( reflecting: tan) . children)
82+ expectEqual ( 1 , children. count)
83+ // We test `==` first because `as?` will flatten optionals.
84+ expectTrue ( type ( of: children [ 0 ] . value) == Float . self)
85+ if let child = expectNotNil ( children [ 0 ] . value as? Float ) {
86+ expectEqual ( 42 , child)
87+ }
88+ }
89+
7990runAllTests ( )
Original file line number Diff line number Diff line change @@ -454,4 +454,15 @@ ArrayAutoDiffTests.test("Array.DifferentiableView.move") {
454454 expectEqual ( z, [ ] )
455455}
456456
457+ ArrayAutoDiffTests . test ( " Array.DifferentiableView reflection " ) {
458+ let tan = [ Float ] . DifferentiableView ( [ 41 , 42 ] )
459+ let children = Array ( Mirror ( reflecting: tan) . children)
460+ expectEqual ( 2 , children. count)
461+ if let child1 = expectNotNil ( children [ 0 ] . value as? Float ) ,
462+ let child2 = expectNotNil ( children [ 1 ] . value as? Float ) {
463+ expectEqual ( 41 , child1)
464+ expectEqual ( 42 , child2)
465+ }
466+ }
467+
457468runAllTests ( )
You can’t perform that action at this time.
0 commit comments