|
80 | 80 |
|
81 | 81 | abstract type AbstractTangentSpace; end |
82 | 82 |
|
83 | | -""" |
84 | | - struct ExplicitTangent{P} |
85 | | -
|
86 | | -A fully explicit coordinate representation of the tangent space, |
87 | | -represented by a vector of `2^(N-1)` partials. |
88 | | -""" |
89 | | -struct ExplicitTangent{P <: Tuple} <: AbstractTangentSpace |
90 | | - partials::P |
91 | | -end |
92 | | - |
93 | 83 | struct TaylorTangent{C <: Tuple} <: AbstractTangentSpace |
94 | 84 | coeffs::C |
95 | 85 | end |
@@ -151,46 +141,9 @@ struct TangentBundle{N, B, P <: AbstractTangentSpace} <: AbstractTangentBundle{N |
151 | 141 | TangentBundle{N}(B, P) where {N} = new{N, typeof(B), typeof(P)}(B,P) |
152 | 142 | end |
153 | 143 |
|
154 | | -const ExplicitTangentBundle{N, B, P} = TangentBundle{N, B, ExplicitTangent{P}} |
155 | | - |
156 | 144 | check_tangent_invariant(lp, N) = @assert lp == 2^N - 1 |
157 | 145 | @ChainRulesCore.non_differentiable check_tangent_invariant(lp, N) |
158 | 146 |
|
159 | | -function ExplicitTangentBundle{N}(primal::B, partials::P) where {N, B, P} |
160 | | - check_tangent_invariant(length(partials), N) |
161 | | - TangentBundle{N}(primal, ExplicitTangent{P}(partials)) |
162 | | -end |
163 | | - |
164 | | -function ExplicitTangentBundle{N,B}(primal::B, partials::P) where {N, B, P} |
165 | | - check_tangent_invariant(length(partials), N) |
166 | | - TangentBundle{N}(primal, ExplicitTangent{P}(partials)) |
167 | | -end |
168 | | - |
169 | | -function ExplicitTangentBundle{N,B,P}(primal::B, partials::P) where {N, B, P} |
170 | | - check_tangent_invariant(length(partials), N) |
171 | | - TangentBundle{N}(primal, ExplicitTangent{P}(partials)) |
172 | | -end |
173 | | - |
174 | | -function Base.show(io::IO, x::ExplicitTangentBundle) |
175 | | - print(io, x.primal) |
176 | | - print(io, " + ") |
177 | | - x = x.tangent |
178 | | - print(io, x.partials[1], " ∂₁") |
179 | | - length(x.partials) >= 2 && print(io, " + ", x.partials[2], " ∂₂") |
180 | | - length(x.partials) >= 3 && print(io, " + ", x.partials[3], " ∂₁ ∂₂") |
181 | | - length(x.partials) >= 4 && print(io, " + ", x.partials[4], " ∂₃") |
182 | | - length(x.partials) >= 5 && print(io, " + ", x.partials[5], " ∂₁ ∂₃") |
183 | | - length(x.partials) >= 6 && print(io, " + ", x.partials[6], " ∂₂ ∂₃") |
184 | | - length(x.partials) >= 7 && print(io, " + ", x.partials[7], " ∂₁ ∂₂ ∂₃") |
185 | | -end |
186 | | - |
187 | | -function Base.getindex(a::ExplicitTangentBundle{N}, b::TaylorTangentIndex) where {N} |
188 | | - if b.i === N |
189 | | - return a.tangent.partials[end] |
190 | | - end |
191 | | - error("$(typeof(a)) is not taylor-like. Taylor indexing is ambiguous") |
192 | | -end |
193 | | - |
194 | 147 | const TaylorBundle{N, B, P} = TangentBundle{N, B, TaylorTangent{P}} |
195 | 148 |
|
196 | 149 | function TaylorBundle{N, B}(primal::B, coeffs) where {N, B} |
|
268 | 221 | expand_singleton_to_array(asize, a::AbstractZero) = fill(a, asize...) |
269 | 222 | expand_singleton_to_array(asize, a::AbstractArray) = a |
270 | 223 |
|
271 | | -function unbundle(atb::ExplicitTangentBundle{Order, A}) where {Order, Dim, T, A<:AbstractArray{T, Dim}} |
272 | | - asize = size(atb.primal) |
273 | | - StructArray{ExplicitTangentBundle{Order, T}}((atb.primal, map(a->expand_singleton_to_array(asize, a), atb.tangent.partials)...)) |
274 | | -end |
275 | | - |
276 | | -function StructArrays.staticschema(::Type{<:ExplicitTangentBundle{N, B, T}}) where {N, B, T} |
277 | | - Tuple{B, T.parameters...} |
278 | | -end |
279 | | - |
280 | | -function StructArrays.component(m::ExplicitTangentBundle{N, B, T}, i::Int) where {N, B, T} |
281 | | - i == 1 && return m.primal |
282 | | - return m.tangent.partials[i - 1] |
283 | | -end |
284 | | - |
285 | | -function StructArrays.createinstance(T::Type{<:ExplicitTangentBundle}, args...) |
286 | | - T(first(args), Base.tail(args)) |
287 | | -end |
288 | | - |
289 | 224 | function unbundle(atb::TaylorBundle{Order, A}) where {Order, Dim, T, A<:AbstractArray{T, Dim}} |
290 | 225 | StructArray{TaylorBundle{Order, T}}((atb.primal, atb.tangent.coeffs...)) |
291 | 226 | end |
@@ -323,14 +258,6 @@ function StructArrays.createinstance(T::Type{<:ZeroBundle}, args...) |
323 | 258 | T(args[1], args[2]) |
324 | 259 | end |
325 | 260 |
|
326 | | -function rebundle(A::AbstractArray{<:ExplicitTangentBundle{N}}) where {N} |
327 | | - ExplicitTangentBundle{N}( |
328 | | - map(x->x.primal, A), |
329 | | - ntuple(2^N-1) do i |
330 | | - map(x->x.tangent.partials[i], A) |
331 | | - end) |
332 | | -end |
333 | | - |
334 | 261 | function rebundle(A::AbstractArray{<:TaylorBundle{N}}) where {N} |
335 | 262 | TaylorBundle{N}( |
336 | 263 | map(x->x.primal, A), |
|
0 commit comments