@@ -26,111 +26,115 @@ import Swift
2626#error("Unsupported platform")
2727#endif
2828
29- @usableFromInline
29+ % for T in [ 'T', 'Double'] : # Prevents name collisions with system math library
30+ % generic_signature = '< T: FloatingPoint & Differentiable> ' if T == 'T' else ''
31+ % constraint = 'where T == T. TangentVector' if T == 'T' else ''
32+ @inlinable
3033@derivative( of: fma)
31- func _jvpFma< T : FloatingPoint & Differentiable > (
32- _ x: T ,
33- _ y: T ,
34- _ z: T
35- ) -> ( value: T , differential: ( T , T , T ) -> T ) where T == T . TangentVector {
34+ func _jvpFma$ { generic_signature } (
35+ _ x: $ { T } ,
36+ _ y: $ { T } ,
37+ _ z: $ { T }
38+ ) - > ( value: $ { T } , differential: ( $ { T } , $ { T } , $ { T } ) - > $ { T } ) $ { constraint } {
3639 return ( fma ( x, y, z) , { ( dx, dy, dz) in dx * y + dy * x + dz } )
3740}
3841
39- @usableFromInline
42+ @inlinable
4043@derivative( of: fma)
41- func _vjpFma< T : FloatingPoint & Differentiable > (
42- _ x: T ,
43- _ y: T ,
44- _ z: T
45- ) -> ( value: T , pullback: ( T ) -> ( T , T , T ) ) where T == T . TangentVector {
44+ func _vjpFma$ { generic_signature } (
45+ _ x: $ { T } ,
46+ _ y: $ { T } ,
47+ _ z: $ { T }
48+ ) - > ( value: $ { T } , pullback: ( $ { T } ) - > ( $ { T } , $ { T } , $ { T } ) ) $ { constraint } {
4649 return ( fma ( x, y, z) , { v in ( v * y, v * x, v) } )
4750}
4851
49- @usableFromInline
52+ @inlinable
5053@derivative( of: remainder)
51- func _jvpRemainder< T : FloatingPoint & Differentiable > (
52- _ x: T ,
53- _ y: T
54- ) -> ( value: T , differential: ( T , T ) -> T ) where T == T . TangentVector {
54+ func _jvpRemainder$ { generic_signature } (
55+ _ x: $ { T } ,
56+ _ y: $ { T }
57+ ) - > ( value: $ { T } , differential: ( $ { T } , $ { T } ) - > $ { T } ) $ { constraint } {
5558 fatalError ( """
5659 Unimplemented JVP for 'remainder(_:)'. \
5760 https://bugs.swift.org/browse/TF-1108 tracks this issue
5861 """ )
5962}
6063
61- @usableFromInline
64+ @inlinable
6265@derivative( of: remainder)
63- func _vjpRemainder< T : FloatingPoint & Differentiable > (
64- _ x: T ,
65- _ y: T
66- ) -> ( value: T , pullback: ( T ) -> ( T , T ) ) where T == T . TangentVector {
66+ func _vjpRemainder$ { generic_signature } (
67+ _ x: $ { T } ,
68+ _ y: $ { T }
69+ ) - > ( value: $ { T } , pullback: ( $ { T } ) - > ( $ { T } , $ { T } ) ) $ { constraint } {
6770 return ( remainder ( x, y) , { v in ( v, - v * ( ( x / y) . rounded ( . toNearestOrEven) ) ) } )
6871}
6972
70- @usableFromInline
73+ @inlinable
7174@derivative( of: fmod)
72- func _jvpFmod< T : FloatingPoint & Differentiable > (
73- _ x: T ,
74- _ y: T
75- ) -> ( value: T , differential: ( T , T ) -> T ) where T == T . TangentVector {
75+ func _jvpFmod$ { generic_signature } (
76+ _ x: $ { T } ,
77+ _ y: $ { T }
78+ ) - > ( value: $ { T } , differential: ( $ { T } , $ { T } ) - > $ { T } ) $ { constraint } {
7679 fatalError ( """
7780 Unimplemented JVP for 'fmod(_:)'. \
7881 https://bugs.swift.org/browse/TF-1108 tracks this issue
7982 """ )
8083}
8184
82- @usableFromInline
85+ @inlinable
8386@derivative( of: fmod)
84- func _vjpFmod< T : FloatingPoint & Differentiable > (
85- _ x: T ,
86- _ y: T
87- ) -> ( value: T , pullback: ( T ) -> ( T , T ) ) where T == T . TangentVector {
87+ func _vjpFmod$ { generic_signature } (
88+ _ x: $ { T } ,
89+ _ y: $ { T }
90+ ) - > ( value: $ { T } , pullback: ( $ { T } ) - > ( $ { T } , $ { T } ) ) $ { constraint } {
8891 return ( fmod ( x, y) , { v in ( v, - v * ( ( x / y) . rounded ( . towardZero) ) ) } )
8992}
9093
91- % for derivative_kind in [ 'jvp', 'vjp'] :
92- % linear_map_kind = 'differential' if derivative_kind == 'jvp' else 'pullback'
93- @usableFromInline
94+ % for derivative_kind in [ 'jvp', 'vjp'] :
95+ % linear_map_kind = 'differential' if derivative_kind == 'jvp' else 'pullback'
96+ @inlinable
9497@derivative( of: sqrt)
95- func _${ derivative_kind} Sqrt< T : FloatingPoint & Differentiable > (
96- _ x: T
97- ) - > ( value: T , ${ linear_map_kind} : ( T ) - > T ) where T == T . TangentVector {
98+ func _${ derivative_kind} Sqrt$ { generic_signature } (
99+ _ x: $ { T }
100+ ) - > ( value: $ { T } , ${ linear_map_kind} : ( $ { T } ) - > $ { T } ) $ { constraint } {
98101 let value = sqrt ( x)
99102 return ( value, { v in v / ( 2 * value) } )
100103}
101104
102- @usableFromInline
105+ @inlinable
103106@derivative( of: ceil)
104- func _${ derivative_kind} Ceil< T : FloatingPoint & Differentiable > (
105- _ x: T
106- ) - > ( value: T , ${ linear_map_kind} : ( T ) - > T ) where T == T . TangentVector {
107+ func _${ derivative_kind} Ceil$ { generic_signature } (
108+ _ x: $ { T }
109+ ) - > ( value: $ { T } , ${ linear_map_kind} : ( $ { T } ) - > $ { T } ) $ { constraint } {
107110 return ( ceil ( x) , { v in 0 } )
108111}
109112
110- @usableFromInline
113+ @inlinable
111114@derivative( of: floor)
112- func _${ derivative_kind} Floor< T : FloatingPoint & Differentiable > (
113- _ x: T
114- ) - > ( value: T , ${ linear_map_kind} : ( T ) - > T ) where T == T . TangentVector {
115+ func _${ derivative_kind} Floor$ { generic_signature } (
116+ _ x: $ { T }
117+ ) - > ( value: $ { T } , ${ linear_map_kind} : ( $ { T } ) - > $ { T } ) $ { constraint } {
115118 return ( floor ( x) , { v in 0 } )
116119}
117120
118- @usableFromInline
121+ @inlinable
119122@derivative( of: round)
120- func _${ derivative_kind} Round< T : FloatingPoint & Differentiable > (
121- _ x: T
122- ) - > ( value: T , ${ linear_map_kind} : ( T ) - > T ) where T == T . TangentVector {
123+ func _${ derivative_kind} Round$ { generic_signature } (
124+ _ x: $ { T }
125+ ) - > ( value: $ { T } , ${ linear_map_kind} : ( $ { T } ) - > $ { T } ) $ { constraint } {
123126 return ( round ( x) , { v in 0 } )
124127}
125128
126- @usableFromInline
129+ @inlinable
127130@derivative( of: trunc)
128- func _${ derivative_kind} Trunc< T : FloatingPoint & Differentiable > (
129- _ x: T
130- ) - > ( value: T , ${ linear_map_kind} : ( T ) - > T ) where T == T . TangentVector {
131+ func _${ derivative_kind} Trunc$ { generic_signature } (
132+ _ x: $ { T }
133+ ) - > ( value: $ { T } , ${ linear_map_kind} : ( $ { T } ) - > $ { T } ) $ { constraint } {
131134 return ( trunc ( x) , { v in 0 } )
132135}
133- % end # for derivative_kind in [ 'jvp', 'vjp'] :
136+ % end # for derivative_kind in [ 'jvp', 'vjp'] :
137+ % end # for T in [ 'T', 'Double'] :
134138
135139// Unary functions
136140% for derivative_kind in [ 'jvp', 'vjp'] :
@@ -276,7 +280,7 @@ func _${derivative_kind}Erfc(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${
276280% end # for derivative_kind in [ 'jvp', 'vjp'] :
277281
278282// Binary functions
279- % for T in [ 'Float', 'Float80 '] :
283+ % for T in [ 'Float', 'Double' , ' Float80 '] :
280284% if T == 'Float80 ':
281285#if !(os(Windows) || os(Android)) && (arch(i386) || arch(x86_64))
282286% end
@@ -300,4 +304,4 @@ func _jvpPow(_ x: ${T}, _ y: ${T}) -> (value: ${T}, differential: (${T}, ${T}) -
300304% if T == 'Float80 ':
301305#endif
302306% end # if T == 'Float80 ':
303- % end # for T in [ 'Float', 'Float80 '] :
307+ % end # for T in [ 'Float', 'Double' , ' Float80 '] :
0 commit comments