@@ -2,6 +2,8 @@ module IncrInfrDiffEqFactorExt
22
33@info " IncrementalInference.jl is loading extensions related to DifferentialEquations.jl"
44
5+ import Base: show
6+
57using DifferentialEquations
68import DifferentialEquations: solve
79
@@ -15,6 +17,7 @@ using DocStringExtensions
1517
1618export DERelative
1719
20+ import Manifolds: allocate
1821
1922
2023getManifold (de:: DERelative{T} ) where {T} = getManifold (de. domain)
@@ -100,11 +103,11 @@ function _solveFactorODE!(measArr, prob, u0pts, Xtra...)
100103 # happens when more variables (n-ary) must be included in DE solve
101104 for (xid, xtra) in enumerate (Xtra)
102105 # update the data register before ODE solver calls the function
103- prob. p[xid + 1 ][:] = xtra[:]
106+ prob. p[xid + 1 ][:] = xtra[:] # FIXME , unlikely to work with ArrayPartition, maybe use MArray and `.=`
104107 end
105108
106109 # set the initial condition
107- prob. u0 = u0pts
110+ prob. u0 . = u0pts
108111
109112 sol = DifferentialEquations. solve (prob)
110113
@@ -250,8 +253,10 @@ function IncrementalInference.sampleFactor(cf::CalcFactor{<:DERelative}, N::Int
250253 oder = cf. factor
251254
252255 # how many trajectories to propagate?
253- # @show getLabel(cf.fullvariables[2]), getDimension(cf.fullvariables[2])
254- meas = [zeros (getDimension (cf. fullvariables[2 ])) for _ = 1 : N]
256+ #
257+ v2T = getVariableType (cf. fullvariables[2 ])
258+ meas = [allocate (getPointIdentity (v2T)) for _ = 1 : N]
259+ # meas = [zeros(getDimension(cf.fullvariables[2])) for _ = 1:N]
255260
256261 # pick forward or backward direction
257262 # set boundary condition
288293
289294
290295
296+ function Base. show (io:: IO , :: Union{<:DERelative{T,O},Type{<:DERelative{T,O}}} ) where {T,O}
297+ println (io, " DERelative{" )
298+ println (io, " " , T)
299+ println (io, " " , O. name. name)
300+ println (io, " }" )
301+ nothing
302+ end
291303
304+ Base. show (io:: IO , :: MIME"text/plain" , der:: DERelative ) = show (io, der)
292305
293306# # the function
294307# ode.problem.f.f
0 commit comments