1- # ================================================================================================
2- # # FactorOperationalMemory for parametric, TODO move back to FactorOperationalMemory.jl
3- # # ================================================================================================
4-
5-
6- # struct CalcFactorMahalanobis{CF<:CalcFactor, S<:Union{Nothing,AbstractMaxMixtureSolver}, N}
7- # calcfactor!::CF
8- # varOrder::Vector{Symbol}
9- # meas::NTuple{N, <:AbstractArray}
10- # iΣ::NTuple{N, Matrix{Float64}}
11- # specialAlg::S
12- # end
13-
141# ================================================================================================
152# # FlatVariables - used for packing variables for optimization
163# # ================================================================================================
5643 $SIGNATURES
5744
5845Returns the parametric measurement for a factor as a tuple (measurement, inverse covariance) for parametric inference (assuming Gaussian).
59- Defaults to find the parametric measurement at field `Z`, fields `Zij` and `z` are deprecated for standardization.
60-
46+ Defaults to find the parametric measurement at field `Z`.
6147Notes
6248- Users should overload this method should their factor not default to `.Z<:ParametricType`.
6349- First design choice was to restrict this function to returning coordinates
@@ -99,7 +85,6 @@ function getMeasurementParametric(s::AbstractFactor)
9985 if hasfield (typeof (s), :Z )
10086 Z = s. Z
10187 else
102- @warn " getMeasurementParametric falls back to using field `.Z` by default. Extend it for more complex factors."
10388 error (
10489 " getMeasurementParametric(::$(typeof (s)) ) not defined, please add it, or use non-parametric, or open an issue for help." ,
10590 )
11196getMeasurementParametric (fct:: DFGFactor ) = getMeasurementParametric (getFactorType (fct))
11297getMeasurementParametric (dfg:: AbstractDFG , flb:: Symbol ) = getMeasurementParametric (getFactor (dfg, flb))
11398
99+ # maybe rename getMeasurementParametric to something like getNormalDistributionParams or getMeanCov
100+
101+ # default to point on manifold
102+ function getFactorMeasurementParametric (fac:: AbstractPrior )
103+ M = getManifold (fac)
104+ ϵ = getPointIdentity (M)
105+ dims = manifold_dimension (M)
106+ Xc, iΣ = getMeasurementParametric (fac)
107+ X = get_vector (M, ϵ, Xc, DefaultOrthogonalBasis ())
108+ meas = convert (typeof (ϵ), exp (M, ϵ, X))
109+ iΣ = convert (SMatrix{dims, dims}, iΣ)
110+ meas, iΣ
111+ end
112+ # default to point on tangent vector
113+ function getFactorMeasurementParametric (fac:: AbstractRelative )
114+ M = getManifold (fac)
115+ ϵ = getPointIdentity (M)
116+ dims = manifold_dimension (M)
117+ Xc, iΣ = getMeasurementParametric (fac)
118+ measX = convert (typeof (ϵ), get_vector (M, ϵ, Xc, DefaultOrthogonalBasis ()))
119+ iΣ = convert (SMatrix{dims, dims}, iΣ)
120+ measX, iΣ
121+ end
122+
123+ getFactorMeasurementParametric (fct:: DFGFactor ) = getFactorMeasurementParametric (getFactorType (fct))
124+ getFactorMeasurementParametric (dfg:: AbstractDFG , flb:: Symbol ) = getFactorMeasurementParametric (getFactor (dfg, flb))
125+
114126# # ================================================================================================
115127# # Parametric solve with Mahalanobis distance - CalcFactor
116128# # ================================================================================================
@@ -124,41 +136,18 @@ function CalcFactorMahalanobis(fg, fct::DFGFactor)
124136 varOrder = getVariableOrder (fct)
125137
126138 # NOTE, use getMeasurementParametric on DFGFactor{<:CCW} to allow special cases like OAS factors
127- _meas, _iΣ = getMeasurementParametric (fct) # fac_func
128- M = getManifold (getFactorType (fct))
129- dims = manifold_dimension (M)
130- ϵ = getPointIdentity (M)
131-
132- _measX = if typeof (_meas) <: Tuple
133- # TODO perhaps better consolidate manifold prior
134- map (m -> hat (M, ϵ, m), _meas)
135- elseif fac_func isa ManifoldPrior
136- (_meas,)
137- else
138- (convert (typeof (ϵ), get_vector (M, ϵ, _meas, DefaultOrthogonalBasis ())),)
139- end
140-
141- meas = fac_func isa AbstractPrior ? map (X -> exp (M, ϵ, X), _measX) : _measX
142-
143- iΣ = convert .(SMatrix{dims, dims}, typeof (_iΣ) <: Tuple ? _iΣ : (_iΣ,))
139+ _meas, _iΣ = getFactorMeasurementParametric (fct) # fac_func
140+
141+ # make sure its a tuple TODO Fix with mixture rework #1504
142+ meas = typeof (_meas) <: Tuple ? _meas : (_meas,)
143+ iΣ = typeof (_iΣ) <: Tuple ? _iΣ : (_iΣ,)
144144
145145 cache = preambleCache (fg, getVariable .(fg, varOrder), getFactorType (fct))
146146
147- calcf = CalcFactor (
148- getFactorMechanics (fac_func),
149- 0 ,
150- nothing ,
151- true ,
152- cache,
153- (), # DFGVariable[],
154- 0 ,
155- getManifold (_getCCW (fct)) # getManifold(fac_func)
156- )
157-
158147 multihypo = getSolverData (fct). multihypo
159148 nullhypo = getSolverData (fct). nullhypo
160149
161- # FIXME , type instability, use dispatch instead of if-else
150+ # FIXME , type instability
162151 if length (multihypo) > 0
163152 special = MaxMultihypo (multihypo)
164153 elseif nullhypo > 0
@@ -169,16 +158,22 @@ function CalcFactorMahalanobis(fg, fct::DFGFactor)
169158 special = nothing
170159 end
171160
172- return CalcFactorMahalanobis (fct. label, calcf , varOrder, meas, iΣ, special)
161+ return CalcFactorMahalanobis (fct. label, getFactorMechanics (fac_func), cache , varOrder, meas, iΣ, special)
173162end
174163
175164# This is where the actual parametric calculation happens, CalcFactor equivalent for parametric
176- @inline function (cfp:: CalcFactorMahalanobis{1, D, L, Nothing} )(variables... ) where {D, L}# AbstractArray{T} where T <: Real
177- # call the user function
178- res = cfp. calcfactor! (cfp. meas... , variables... )
179- # 1/2*log(1/( sqrt(det(Σ)*(2pi)^k) )) ## k = dim(μ)
180- return res' * cfp. iΣ[1 ] * res
181- end
165+ # function (cfp::CalcFactorMahalanobis{FT, 1, C, MEAS, D, L, Nothing})(variables...) where {FT, C, MEAS, D, L, Nothing}# AbstractArray{T} where T <: Real
166+ # # call the user function
167+ # res = cfp.calcfactor!(cfp.meas..., variables...)
168+ # # 1/2*log(1/( sqrt(det(Σ)*(2pi)^k) )) ## k = dim(μ)
169+ # return res' * cfp.iΣ[1] * res
170+ # end
171+
172+ # function (cfm::CalcFactorMahalanobis)(variables...)
173+ # meas = cfm.meas
174+ # points = map(idx->p[idx], cfm.varOrderIdxs)
175+ # return cfm.sqrt_iΣ * cfm(meas, points...)
176+ # end
182177
183178function calcFactorMahalanobisDict (fg)
184179 calcFactors = OrderedDict {Symbol, CalcFactorMahalanobis} ()
@@ -190,20 +185,39 @@ function calcFactorMahalanobisDict(fg)
190185 return calcFactors
191186end
192187
193- # Base.eltype(::Type{<:CalcFactorMahalanobis}) = CalcFactorMahalanobis
188+ function getFactorTypesCount (facs:: Vector{<:DFGFactor} )
189+ typedict = OrderedDict {DataType, Int} ()
190+ alltypes = OrderedDict {DataType, Vector{Symbol}} ()
191+ for f in facs
192+ facType = typeof (getFactorType (f))
193+ cnt = get! (typedict, facType, 0 )
194+ typedict[facType] = cnt + 1
194195
195- # function calcFactorMahalanobisArray(fg)
196- # cfps = map(getFactors(fg)) do fct
197- # CalcFactorMahalanobis(fg, fct)
198- # end
199- # types = collect(Set(typeof.(cfps)))
200- # cfparr = ArrayPartition(map(x->Vector{x}(), types)...)
201- # for cfp in cfps
202- # idx = findfirst(==(typeof(cfp)), types)
203- # push!(cfparr.x[idx], cfp)
204- # end
205- # return cfparr
206- # end
196+ dt = get! (alltypes, facType, Symbol[])
197+ push! (dt, f. label)
198+ end
199+ # TODO tuple or vector?
200+ # vartypes = tuple(keys(typedict)...)
201+ factypes:: Vector{DataType} = collect (keys (typedict))
202+ return factypes, typedict, alltypes
203+ end
204+
205+ function calcFactorMahalanobisVec (fg)
206+ factypes, typedict, alltypes = getFactorTypesCount (getFactors (fg))
207+
208+ # skip non-numeric prior (MetaPrior)
209+ # TODO test... remove MetaPrior{T} something like this
210+ metaPriorKeys = filter (k-> contains (string (k), " MetaPrior" ), collect (keys (alltypes)))
211+ delete! .(Ref (alltypes), metaPriorKeys)
212+
213+ parts = map (values (alltypes)) do labels
214+ map (getFactor .(fg, labels)) do fct
215+ CalcFactorMahalanobis (fg, fct)
216+ end
217+ end
218+ parts_tuple = (parts... ,)
219+ return ArrayPartition {CalcFactorMahalanobis, typeof(parts_tuple)} (parts_tuple)
220+ end
207221
208222# # ================================================================================================
209223# # ================================================================================================
@@ -265,8 +279,10 @@ function getVariableTypesCount(vars::Vector{<:DFGVariable})
265279 return vartypes, typedict, alltypes
266280end
267281
268- function buildGraphSolveManifold (fg:: AbstractDFG )
269- vartypes, vartypecount, vartypeslist = getVariableTypesCount (fg)
282+ buildGraphSolveManifold (fg:: AbstractDFG ) = buildGraphSolveManifold (getVariables (fg))
283+
284+ function buildGraphSolveManifold (vars:: Vector{<:DFGVariable} )
285+ vartypes, vartypecount, vartypeslist = getVariableTypesCount (vars)
270286
271287 PMs = map (vartypes) do vartype
272288 N = vartypecount[vartype]
@@ -294,34 +310,32 @@ function GraphSolveBuffers(@nospecialize(M), ::Type{T}) where {T}
294310 return GraphSolveBuffers (ϵ, p, X, Xc)
295311end
296312
297- struct GraphSolveContainer
313+ struct GraphSolveContainer{CFT}
298314 M:: AbstractManifold # ProductManifold or ProductGroup
299315 buffers:: OrderedDict{DataType, GraphSolveBuffers}
300316 varTypes:: Vector{DataType}
301317 varTypesIds:: OrderedDict{DataType, Vector{Symbol}}
302- cfdict:: OrderedDict{Symbol, CalcFactorMahalanobis}
303318 varOrderDict:: OrderedDict{Symbol, Tuple{Int, Vararg{Int}}}
304- # cfarr::AbstractVector # TODO maybe <: AbstractVector( CalcFactorMahalanobis)
319+ cfv :: ArrayPartition{ CalcFactorMahalanobis, CFT}
305320end
306321
307322function GraphSolveContainer (fg)
308323 M, varTypes, varTypesIds = buildGraphSolveManifold (fg)
309324 varTypesIndexes = ArrayPartition (values (varTypesIds)... )
310325 buffs = OrderedDict {DataType, GraphSolveBuffers} ()
311- cfd = calcFactorMahalanobisDict (fg)
326+ cfvec = calcFactorMahalanobisVec (fg)
312327
313328 varOrderDict = OrderedDict {Symbol, Tuple{Int, Vararg{Int}}} ()
314- for (fid, cfp) in cfd
329+ for cfp in cfvec
330+ fid = cfp. faclbl
315331 varOrder = cfp. varOrder
316332 var_idx = map (varOrder) do v
317333 return findfirst (== (v), varTypesIndexes)
318334 end
319335 varOrderDict[fid] = tuple (var_idx... )
320336 end
321337
322- # cfarr = calcFactorMahalanobisArray(fg)
323- # return GraphSolveContainer(M, buffs, varTypes, varTypesIds, cfd, varOrderDict, cfarr)
324- return GraphSolveContainer (M, buffs, varTypes, varTypesIds, cfd, varOrderDict)
338+ return GraphSolveContainer (M, buffs, varTypes, varTypesIds, varOrderDict, cfvec)
325339end
326340
327341function getGraphSolveCache! (gsc:: GraphSolveContainer , :: Type{T} ) where {T <: Real }
@@ -348,11 +362,15 @@ function _toPoints2!(
348362end
349363
350364function cost_cfp (
351- @nospecialize ( cfp:: CalcFactorMahalanobis ) ,
352- @nospecialize ( p:: AbstractArray ) ,
365+ cfp:: CalcFactorMahalanobis ,
366+ p:: AbstractArray{T} ,
353367 vi:: NTuple{N, Int} ,
354- ) where N
355- cfp (map (v-> p[v],vi)... )
368+ ) where {T,N}
369+ # cfp(map(v->p[v],vi)...)
370+ res = cfp (cfp. meas... , map (v-> p[v],vi)... )
371+ # 1/2*log(1/( sqrt(det(Σ)*(2pi)^k) )) ## k = dim(μ)
372+ return res' * cfp. iΣ[1 ] * res
373+
356374end
357375# function cost_cfp(
358376# @nospecialize(cfp::CalcFactorMahalanobis),
@@ -403,17 +421,19 @@ function (gsc::GraphSolveContainer)(Xc::Vector{T}) where {T <: Real}
403421 #
404422 buffs = getGraphSolveCache! (gsc, T)
405423
406- cfdict = gsc. cfdict
407424 varOrderDict = gsc. varOrderDict
408425
409426 M = gsc. M
410427
411428 p = _toPoints2! (M, buffs, Xc)
412-
413- obj = mapreduce (+ , cfdict) do (fid, cfp)
414- varOrder_idx = varOrderDict[fid]
415- # call the user function
416- return cost_cfp (cfp, p, varOrder_idx)
429+
430+ obj = mapreduce (+ , eachindex (gsc. cfv)) do i
431+ cfp = gsc. cfv[i]
432+ varOrder_idx = varOrderDict[cfp. faclbl]
433+ # # call the user function
434+ cost:: T = cost_cfp (cfp, p, varOrder_idx)
435+
436+ return cost
417437 end
418438
419439 return obj / 2
@@ -507,6 +527,7 @@ function solveGraphParametric(
507527 autodiff = :forward ,
508528 algorithm = Optim. BFGS,
509529 algorithmkwargs = (), # add manifold to overwrite computed one
530+ # algorithmkwargs = (linesearch=Optim.BackTracking(),), # add manifold to overwrite computed one
510531 options = Optim. Options (;
511532 allow_f_increases = true ,
512533 time_limit = 100 ,
@@ -539,22 +560,7 @@ function solveGraphParametric(
539560
540561 # optim setup and solve
541562 alg = algorithm (; algorithmkwargs... )
542- # alg = NewtonTrustRegion(;
543- # initial_delta = 1.0,
544- # delta_hat = 100.0,
545- # eta = 0.1,
546- # rho_lower = 0.25,
547- # rho_upper = 0.75
548- # )
549- # alg = LBFGS(;
550- # m = 10,
551- # alphaguess = LineSearches.InitialStatic(),
552- # linesearch = LineSearches.HagerZhang(),
553- # P = nothing,
554- # precondprep = (P, x) -> nothing,
555- # manifold = Flat(),
556- # scaleinvH0::Bool = true && (typeof(P) <: Nothing)
557- # )
563+
558564 tdtotalCost = Optim. TwiceDifferentiable (gsc, initValues; autodiff = autodiff)
559565
560566 result = Optim. optimize (tdtotalCost, initValues, alg, options)
@@ -609,10 +615,10 @@ function _totalCost(fg, cfdict::OrderedDict{Symbol, <:CalcFactorMahalanobis}, fl
609615 ]
610616
611617 # call the user function
612- retval = cfp (Xparams... )
613-
618+ # retval = cfp(Xparams...)
619+ res = cfp (cfp . meas ... , Xparams ... )
614620 # 1/2*log(1/( sqrt(det(Σ)*(2pi)^k) )) ## k = dim(μ)
615- obj += 1 / 2 * retval
621+ obj += 1 / 2 * res ' * cfp . iΣ[ 1 ] * res
616622 end
617623
618624 return obj
890896 $SIGNATURES
891897Update the fg from solution in vardict and add MeanMaxPPE (all just mean). Usefull for plotting
892898"""
893- function updateParametricSolution! (sfg, vardict; solveKey:: Symbol = :parametric )
899+ function updateParametricSolution! (sfg, vardict:: AbstractDict ; solveKey:: Symbol = :parametric )
894900 for (v, val) in vardict
895901 vnd = getSolverData (getVariable (sfg, v), solveKey)
896902 # Update the variable node data value and covariance
@@ -902,6 +908,18 @@ function updateParametricSolution!(sfg, vardict; solveKey::Symbol = :parametric)
902908 end
903909end
904910
911+ function updateParametricSolution! (sfg, labels:: AbstractArray{Symbol} , vals; solveKey:: Symbol = :parametric )
912+ for (v, val) in zip (labels, vals)
913+ vnd = getSolverData (getVariable (sfg, v), solveKey)
914+ # Update the variable node data value and covariance
915+ updateSolverDataParametric! (vnd, val, vnd. bw)# FIXME add cov
916+ # fill in ppe as mean
917+ Xc = collect (getCoordinates (getVariableType (sfg, v), val))
918+ ppe = MeanMaxPPE (solveKey, Xc, Xc, Xc)
919+ getPPEDict (getVariable (sfg, v))[solveKey] = ppe
920+ end
921+ end
922+
905923function createMvNormal (val, cov)
906924 # TODO do something better for properly formed covariance, but for now just a hack...FIXME
907925 if all (diag (cov) .> 0.001 ) && isapprox (cov, transpose (cov); rtol = 1e-4 )
@@ -939,9 +957,10 @@ function autoinitParametric!(
939957 reinit = false ,
940958 algorithm = Optim. NelderMead,
941959 algorithmkwargs = (initial_simplex = Optim. AffineSimplexer (0.025 , 0.1 ),),
960+ kwargs...
942961)
943962 @showprogress for vIdx in varorderIds
944- autoinitParametric! (fg, vIdx; reinit, algorithm, algorithmkwargs)
963+ autoinitParametric! (fg, vIdx; reinit, algorithm, algorithmkwargs, kwargs ... )
945964 end
946965 return nothing
947966end
0 commit comments