@@ -74,11 +74,8 @@ struct TaylorTangentIndex <: TangentIndex
7474 i:: Int
7575end
7676
77- function Base. getindex (a:: AbstractTangentBundle , b:: TaylorTangentIndex )
78- error (" $(typeof (a)) is not taylor-like. Taylor indexing is ambiguous" )
79- end
80-
8177abstract type AbstractTangentSpace; end
78+ Base.:(== )(x:: AbstractTangentSpace , y:: AbstractTangentSpace ) = == (promote (x, y)... )
8279
8380"""
8481 struct ExplicitTangent{P}
@@ -89,13 +86,23 @@ represented by a vector of `2^N-1` partials.
8986struct ExplicitTangent{P <: Tuple } <: AbstractTangentSpace
9087 partials:: P
9188end
89+ Base.:(== )(a:: ExplicitTangent , b:: ExplicitTangent ) = a. partials == b. partials
90+ Base. hash (tt:: ExplicitTangent , h:: UInt64 ) = hash (tt. partials, h)
91+
92+ Base. getindex (tangent:: ExplicitTangent , b:: CanonicalTangentIndex ) = tangent. partials[b. i]
93+ function Base. getindex (tangent:: ExplicitTangent , b:: TaylorTangentIndex )
94+ if lastindex (tangent. partials) == exp2 (b. i) - 1
95+ return tangent. partials[end ]
96+ end
97+ # TODO : should we also allow other indexes if all the partials at that level are equal up regardless of order?
98+ throw (DomainError (b, " $(typeof (tangent)) is not taylor-like. Taylor indexing is ambiguous" ))
99+ end
100+
92101
93102@eval struct TaylorTangent{C <: Tuple } <: AbstractTangentSpace
94103 coeffs:: C
95104 TaylorTangent (coeffs) = $ (Expr (:new , :(TaylorTangent{typeof (coeffs)}), :coeffs ))
96105end
97- Base.:(== )(a:: TaylorTangent , b:: TaylorTangent ) = a. coeffs == b. coeffs
98- Base. hash (tt:: TaylorTangent , h:: UInt64 ) = hash (tt. coeffs, h)
99106
100107"""
101108 struct TaylorTangent{C}
@@ -122,6 +129,14 @@ by analogy with the (truncated) Taylor series
122129"""
123130TaylorTangent
124131
132+ Base.:(== )(a:: TaylorTangent , b:: TaylorTangent ) = a. coeffs == b. coeffs
133+ Base. hash (tt:: TaylorTangent , h:: UInt64 ) = hash (tt. coeffs, h)
134+
135+
136+ Base. getindex (tangent:: TaylorTangent , tti:: TaylorTangentIndex ) = tangent. coeffs[tti. i]
137+ Base. getindex (tangent:: TaylorTangent , tti:: CanonicalTangentIndex ) = tangent. coeffs[count_ones (tti. i)]
138+
139+
125140"""
126141 struct ProductTangent{T <: Tuple{Vararg{AbstractTangentSpace}}}
127142
@@ -141,6 +156,28 @@ useful for representing singleton values.
141156struct UniformTangent{U} <: AbstractTangentSpace
142157 val:: U
143158end
159+ Base. hash (t:: UniformTangent , h:: UInt64 ) = hash (t. val, h)
160+ Base.:(== )(t1:: UniformTangent , t2:: UniformTangent ) = t1. val == t2. val
161+
162+ Base. getindex (tangent:: UniformTangent , :: Any ) = tangent. val
163+
164+ # Conversion and promotion
165+ Base. promote_rule (et:: Type{<:ExplicitTangent} , :: Type{<:AbstractTangentSpace} ) = et
166+ Base. promote_rule (tt:: Type{<:TaylorTangent} , :: Type{<:AbstractTangentSpace} ) = tt
167+ Base. promote_rule (et:: Type{<:ExplicitTangent} , :: Type{<:TaylorTangent} ) = et
168+ Base. promote_rule (:: Type{<:TaylorTangent} , et:: Type{<:ExplicitTangent} ) = et
169+
170+ num_partials (:: Type{TaylorTangent{P}} ) where P = fieldcount (P)
171+ num_partials (:: Type{ExplicitTangent{P}} ) where P = fieldcount (P)
172+ Base. eltype (:: Type{TaylorTangent{P}} ) where P = eltype (P)
173+ Base. eltype (:: Type{ExplicitTangent{P}} ) where P = eltype (P)
174+ function Base. convert (:: Type{T} , ut:: UniformTangent ) where {T<: Union{TaylorTangent, ExplicitTangent} }
175+ # can't just use T to construct as the inner constructor doesn't accept type params. So get T_wrapper
176+ T_wrapper = T<: TaylorTangent ? TaylorTangent : ExplicitTangent
177+ T_wrapper (ntuple (_-> convert (eltype (T), ut. val), num_partials (T)))
178+ end
179+ Base. convert (T:: Type{<:ExplicitTangent} , tt:: TaylorTangent ) = ExplicitTangent (ntuple (i-> tt[CanonicalTangentIndex (i)], num_partials (T)))
180+ # TODO : Should we define the reverse: Explict->Taylor for the cases where that is actually defined?
144181
145182function _TangentBundle end
146183
@@ -162,7 +199,9 @@ TangentBundle{N}(primal::B, tangent::P) where {N, B, P<:AbstractTangentSpace} =
162199 _TangentBundle (Val {N} (), primal, tangent)
163200
164201Base. hash (tb:: TangentBundle , h:: UInt64 ) = hash (tb. primal, h)
165- Base.:(== )(a:: TangentBundle , b:: TangentBundle ) = (a. primal == b. primal) && (a. tangent == b. tangent)
202+ Base.:(== )(a:: TangentBundle , b:: TangentBundle ) = false # different orders
203+ Base.:(== )(a:: TangentBundle{N} , b:: TangentBundle{N} ) where {N} = (a. primal == b. primal) && (a. tangent == b. tangent)
204+ Base. getindex (tbun:: TangentBundle , x) = getindex (tbun. tangent, x)
166205
167206const ExplicitTangentBundle{N, B, P} = TangentBundle{N, B, ExplicitTangent{P}}
168207
@@ -197,12 +236,7 @@ function Base.show(io::IO, x::ExplicitTangentBundle)
197236 length (x. partials) >= 7 && print (io, " + " , x. partials[7 ], " ∂₁ ∂₂ ∂₃" )
198237end
199238
200- function Base. getindex (a:: ExplicitTangentBundle{N} , b:: TaylorTangentIndex ) where {N}
201- if b. i === N
202- return a. tangent. partials[end ]
203- end
204- error (" $(typeof (a)) is not taylor-like. Taylor indexing is ambiguous" )
205- end
239+
206240
207241const TaylorBundle{N, B, P} = TangentBundle{N, B, TaylorTangent{P}}
208242
@@ -233,11 +267,6 @@ function Base.show(io::IO, x::TaylorBundle{1})
233267 print (io, x. coeffs[1 ], " ∂₁" )
234268end
235269
236- Base. getindex (tb:: TaylorBundle , tti:: TaylorTangentIndex ) = tb. tangent. coeffs[tti. i]
237- function Base. getindex (tb:: TaylorBundle , tti:: CanonicalTangentIndex )
238- tb. tangent. coeffs[count_ones (tti. i)]
239- end
240-
241270" for a TaylorTangent{N, <:Tuple} this breaks it up unto 1 TaylorTangent{N} for each element of the primal tuple"
242271function destructure (r:: TaylorBundle{N, B} ) where {N, B<: Tuple }
243272 return ntuple (fieldcount (B)) do field_ii
@@ -307,8 +336,18 @@ function Base.show(io::IO, t::AbstractZeroBundle{N}) where N
307336 print (io, " )" )
308337end
309338
339+ # Conversion and promotion
340+ function Base. promote_rule (:: Type{TangentBundle{N, B, P1}} , :: Type{TangentBundle{N, B, P2}} ) where {N,B,P1,P2}
341+ return TangentBundle{N, B, promote_type (P1, P2)}
342+ end
343+
344+ function Base. convert (:: Type{T} , tbun:: TangentBundle{N, B} ) where {N, B, P, T<: TangentBundle{N,B,P} }
345+ the_primal = convert (B, primal (tbun))
346+ the_partials = convert (P, tbun. tangent)
347+ return _TangentBundle (Val {N} (), the_primal, the_partials)
348+ end
310349
311- Base . getindex (u :: UniformBundle , :: TaylorTangentIndex ) = u . tangent . val
350+ # StructureArrays helpers
312351
313352expand_singleton_to_array (asize, a:: AbstractZero ) = fill (a, asize... )
314353expand_singleton_to_array (asize, a:: AbstractArray ) = a
0 commit comments