@@ -74,11 +74,8 @@ struct TaylorTangentIndex <: TangentIndex
7474 i:: Int
7575end
7676
77- function Base. getindex (a:: AbstractTangentBundle , b:: TaylorTangentIndex )
78- error (" $(typeof (a)) is not taylor-like. Taylor indexing is ambiguous" )
79- end
80-
8177abstract type AbstractTangentSpace; end
78+ Base.:(== )(x:: AbstractTangentSpace , y:: AbstractTangentSpace ) = == (promote (x, y)... )
8279
8380"""
8481 struct ExplicitTangent{P}
@@ -89,13 +86,23 @@ represented by a vector of `2^N-1` partials.
8986struct ExplicitTangent{P <: Tuple } <: AbstractTangentSpace
9087 partials:: P
9188end
89+ Base.:(== )(a:: ExplicitTangent , b:: ExplicitTangent ) = a. partials == b. partials
90+ Base. hash (tt:: ExplicitTangent , h:: UInt64 ) = hash (tt. partials, h)
91+
92+ Base. getindex (tangent:: ExplicitTangent , b:: CanonicalTangentIndex ) = tangent. partials[b. i]
93+ function Base. getindex (tangent:: ExplicitTangent , b:: TaylorTangentIndex )
94+ if lastindex (tangent. partials) == exp2 (b. i) - 1
95+ return tangent. partials[end ]
96+ end
97+ # TODO : should we also allow other indexes if all the partials at that level are equal up regardless of order?
98+ throw (DomainError (b, " $(typeof (tangent)) is not taylor-like. Taylor indexing is ambiguous" ))
99+ end
100+
92101
93102@eval struct TaylorTangent{C <: Tuple } <: AbstractTangentSpace
94103 coeffs:: C
95104 TaylorTangent (coeffs) = $ (Expr (:new , :(TaylorTangent{typeof (coeffs)}), :coeffs ))
96105end
97- Base.:(== )(a:: TaylorTangent , b:: TaylorTangent ) = a. coeffs == b. coeffs
98- Base. hash (tt:: TaylorTangent , h:: UInt64 ) = hash (tt. coeffs, h)
99106
100107"""
101108 struct TaylorTangent{C}
@@ -122,15 +129,13 @@ by analogy with the (truncated) Taylor series
122129"""
123130TaylorTangent
124131
125- """
126- struct ProductTangent{T <: Tuple{Vararg{AbstractTangentSpace}}}
132+ Base.:(== )(a:: TaylorTangent , b:: TaylorTangent ) = a. coeffs == b. coeffs
133+ Base. hash (tt:: TaylorTangent , h:: UInt64 ) = hash (tt. coeffs, h)
134+
135+
136+ Base. getindex (tangent:: TaylorTangent , tti:: TaylorTangentIndex ) = tangent. coeffs[tti. i]
137+ Base. getindex (tangent:: TaylorTangent , tti:: CanonicalTangentIndex ) = tangent. coeffs[count_ones (tti. i)]
127138
128- Represents the product space of the given representations of the
129- tangent space.
130- """
131- struct ProductTangent{T <: Tuple } <: AbstractTangentSpace
132- factors:: T
133- end
134139
135140"""
136141 struct UniformTangent
@@ -141,6 +146,28 @@ useful for representing singleton values.
141146struct UniformTangent{U} <: AbstractTangentSpace
142147 val:: U
143148end
149+ Base. hash (t:: UniformTangent , h:: UInt64 ) = hash (t. val, h)
150+ Base.:(== )(t1:: UniformTangent , t2:: UniformTangent ) = t1. val == t2. val
151+
152+ Base. getindex (tangent:: UniformTangent , :: Any ) = tangent. val
153+
154+ # Conversion and promotion
155+ Base. promote_rule (et:: Type{<:ExplicitTangent} , :: Type{<:AbstractTangentSpace} ) = et
156+ Base. promote_rule (tt:: Type{<:TaylorTangent} , :: Type{<:AbstractTangentSpace} ) = tt
157+ Base. promote_rule (et:: Type{<:ExplicitTangent} , :: Type{<:TaylorTangent} ) = et
158+ Base. promote_rule (:: Type{<:TaylorTangent} , et:: Type{<:ExplicitTangent} ) = et
159+
160+ num_partials (:: Type{TaylorTangent{P}} ) where P = fieldcount (P)
161+ num_partials (:: Type{ExplicitTangent{P}} ) where P = fieldcount (P)
162+ Base. eltype (:: Type{TaylorTangent{P}} ) where P = eltype (P)
163+ Base. eltype (:: Type{ExplicitTangent{P}} ) where P = eltype (P)
164+ function Base. convert (:: Type{T} , ut:: UniformTangent ) where {T<: Union{TaylorTangent, ExplicitTangent} }
165+ # can't just use T to construct as the inner constructor doesn't accept type params. So get T_wrapper
166+ T_wrapper = T<: TaylorTangent ? TaylorTangent : ExplicitTangent
167+ T_wrapper (ntuple (_-> convert (eltype (T), ut. val), num_partials (T)))
168+ end
169+ Base. convert (T:: Type{<:ExplicitTangent} , tt:: TaylorTangent ) = ExplicitTangent (ntuple (i-> tt[CanonicalTangentIndex (i)], num_partials (T)))
170+ # TODO : Should we define the reverse: Explict->Taylor for the cases where that is actually defined?
144171
145172function _TangentBundle end
146173
@@ -154,15 +181,17 @@ end
154181 struct TangentBundle{N, B, P}
155182
156183Represents a tangent bundle as an explicit primal together
157- with some representation of (potentially a product of) the tangent space.
184+ with some representation of the tangent space.
158185"""
159186TangentBundle
160187
161188TangentBundle {N} (primal:: B , tangent:: P ) where {N, B, P<: AbstractTangentSpace } =
162189 _TangentBundle (Val {N} (), primal, tangent)
163190
164191Base. hash (tb:: TangentBundle , h:: UInt64 ) = hash (tb. primal, h)
165- Base.:(== )(a:: TangentBundle , b:: TangentBundle ) = (a. primal == b. primal) && (a. tangent == b. tangent)
192+ Base.:(== )(a:: TangentBundle , b:: TangentBundle ) = false # different orders
193+ Base.:(== )(a:: TangentBundle{N} , b:: TangentBundle{N} ) where {N} = (a. primal == b. primal) && (a. tangent == b. tangent)
194+ Base. getindex (tbun:: TangentBundle , x) = getindex (tbun. tangent, x)
166195
167196const ExplicitTangentBundle{N, B, P} = TangentBundle{N, B, ExplicitTangent{P}}
168197
@@ -197,12 +226,7 @@ function Base.show(io::IO, x::ExplicitTangentBundle)
197226 length (x. partials) >= 7 && print (io, " + " , x. partials[7 ], " ∂₁ ∂₂ ∂₃" )
198227end
199228
200- function Base. getindex (a:: ExplicitTangentBundle{N} , b:: TaylorTangentIndex ) where {N}
201- if b. i === N
202- return a. tangent. partials[end ]
203- end
204- error (" $(typeof (a)) is not taylor-like. Taylor indexing is ambiguous" )
205- end
229+
206230
207231const TaylorBundle{N, B, P} = TangentBundle{N, B, TaylorTangent{P}}
208232
@@ -233,11 +257,6 @@ function Base.show(io::IO, x::TaylorBundle{1})
233257 print (io, x. coeffs[1 ], " ∂₁" )
234258end
235259
236- Base. getindex (tb:: TaylorBundle , tti:: TaylorTangentIndex ) = tb. tangent. coeffs[tti. i]
237- function Base. getindex (tb:: TaylorBundle , tti:: CanonicalTangentIndex )
238- tb. tangent. coeffs[count_ones (tti. i)]
239- end
240-
241260" for a TaylorTangent{N, <:Tuple} this breaks it up unto 1 TaylorTangent{N} for each element of the primal tuple"
242261function destructure (r:: TaylorBundle{N, B} ) where {N, B<: Tuple }
243262 return ntuple (fieldcount (B)) do field_ii
@@ -307,8 +326,18 @@ function Base.show(io::IO, t::AbstractZeroBundle{N}) where N
307326 print (io, " )" )
308327end
309328
329+ # Conversion and promotion
330+ function Base. promote_rule (:: Type{TangentBundle{N, B, P1}} , :: Type{TangentBundle{N, B, P2}} ) where {N,B,P1,P2}
331+ return TangentBundle{N, B, promote_type (P1, P2)}
332+ end
333+
334+ function Base. convert (:: Type{T} , tbun:: TangentBundle{N, B} ) where {N, B, P, T<: TangentBundle{N,B,P} }
335+ the_primal = convert (B, primal (tbun))
336+ the_partials = convert (P, tbun. tangent)
337+ return _TangentBundle (Val {N} (), the_primal, the_partials)
338+ end
310339
311- Base . getindex (u :: UniformBundle , :: TaylorTangentIndex ) = u . tangent . val
340+ # StructureArrays helpers
312341
313342expand_singleton_to_array (asize, a:: AbstractZero ) = fill (a, asize... )
314343expand_singleton_to_array (asize, a:: AbstractArray ) = a
0 commit comments