1- # TODO : Compare with ChangesOfVariables.jl
1+ """
2+ abstract type PushFwdStyle
3+
4+ Provides the behavior of a measure's [`rootmeasure`](@ref) under a
5+ pushforward. Either [`AdaptRootMeasure()`](@ref) or
6+ [`PushfwdRootMeasure()`](@ref)
7+ """
8+ abstract type PushFwdStyle end
9+ export PushFwdStyle
10+
11+ const TransformVolCorr = PushFwdStyle
12+
13+ """
14+ AdaptRootMeasure()
15+
16+ Indicates that when applying a pushforward to a measure, it's
17+ [`rootmeasure`](@ref) not not be pushed forward. Instead, the root measure
18+ should be kept just "reshaped" to the new measurable space if necessary.
19+
20+ Density calculations for pushforward measures constructed with
21+ `AdaptRootMeasure()` will take take the volume element of variate
22+ transform (typically via the log-abs-det-Jacobian of the transform) into
23+ account.
24+ """
25+ struct AdaptRootMeasure <: TransformVolCorr end
26+ export AdaptRootMeasure
27+
28+ const WithVolCorr = AdaptRootMeasure
229
3- using InverseFunctions: FunctionWithInverse
30+ """
31+ PushfwdRootMeasure()
32+
33+ Indicates than when applying a pushforward to a measure, it's
34+ [`rootmeasure`](@ref) should be pushed forward with the same function.
35+
36+ Density calculations for pushforward measures constructed with
37+ `PushfwdRootMeasure()` will ignore the volume element of the variate
38+ transform.
39+ """
40+ struct PushfwdRootMeasure <: TransformVolCorr end
41+ export PushfwdRootMeasure
42+
43+ const NoVolCorr = PushfwdRootMeasure
444
545abstract type AbstractTransformedMeasure <: AbstractMeasure end
646
@@ -19,23 +59,42 @@ function parent(::AbstractTransformedMeasure) end
1959export PushforwardMeasure
2060
2161"""
22- struct PushforwardMeasure{F,I,M,VC<:TransformVolCorr } <: AbstractPushforward
62+ struct PushforwardMeasure{F,I,M,S<:PushFwdStyle } <: AbstractPushforward
2363 f :: F
2464 finv :: I
2565 origin :: M
26- volcorr :: VC
66+ style :: S
2767 end
2868
2969 Users should not call `PushforwardMeasure` directly. Instead call or add
3070 methods to `pushfwd`.
3171"""
32- struct PushforwardMeasure{F,I,M,VC <: TransformVolCorr } <: AbstractPushforward
72+ struct PushforwardMeasure{F,I,M,S <: PushFwdStyle } <: AbstractPushforward
3373 f:: F
3474 finv:: I
3575 origin:: M
36- volcorr:: VC
76+ style:: S
77+
78+ function PushforwardMeasure {F,I,M,S} (
79+ f:: F ,
80+ finv:: I ,
81+ origin:: M ,
82+ style:: S ,
83+ ) where {F,I,M,S<: PushFwdStyle }
84+ new {F,I,M,S} (f, finv, origin, style)
85+ end
86+
87+ function PushforwardMeasure (f, finv, origin:: M , style:: S ) where {M,S<: PushFwdStyle }
88+ new {Core.Typeof(f),Core.Typeof(finv),M,S} (f, finv, origin, style)
89+ end
3790end
3891
92+ const _NonBijectivePusfwdMeasure{M<: PushforwardMeasure ,S<: PushFwdStyle } = Union{
93+ PushforwardMeasure{<: Any ,<: NoInverse ,M,S},
94+ PushforwardMeasure{<: NoInverse ,<: Any ,M,S},
95+ PushforwardMeasure{<: NoInverse ,<: NoInverse ,M,S},
96+ }
97+
3998gettransform (ν:: PushforwardMeasure ) = ν. f
4099parent (ν:: PushforwardMeasure ) = ν. origin
41100
45104
46105# TODO : THIS IS ALMOST CERTAINLY WRONG
47106# @inline function logdensity_rel(
48- # ν::PushforwardMeasure{FF1,IF1,M1,<:WithVolCorr },
49- # β::PushforwardMeasure{FF2,IF2,M2,<:WithVolCorr },
107+ # ν::PushforwardMeasure{FF1,IF1,M1,<:AdaptRootMeasure },
108+ # β::PushforwardMeasure{FF2,IF2,M2,<:AdaptRootMeasure },
50109# y,
51110# ) where {FF1,IF1,M1,FF2,IF2,M2}
52111# x = β.inv_f(y)
53112# f = ν.inv_f ∘ β.f
54113# inv_f = β.inv_f ∘ ν.f
55- # logdensity_rel(pushfwd(f, inv_f, ν.origin, WithVolCorr ()), β.origin, x)
114+ # logdensity_rel(pushfwd(f, inv_f, ν.origin, AdaptRootMeasure ()), β.origin, x)
56115# end
57116
117+ # TODO : Would profit from custom pullback:
118+ function _combine_logd_with_ladj (logd_orig:: Real , ladj:: Real )
119+ logd_result = logd_orig + ladj
120+ R = typeof (logd_result)
121+
122+ if isnan (logd_result) && isneginf (logd_orig) && isposinf (ladj)
123+ # Zero μ wins against infinite volume:
124+ R (- Inf ):: R
125+ elseif isfinite (logd_orig) && isneginf (ladj)
126+ # Maybe also for isneginf(logd_orig) && isfinite(ladj) ?
127+ # Return constant -Inf to prevent problems with ForwardDiff:
128+ # R(-Inf)
129+ near_neg_inf (R):: R # Avoids AdvancedHMC warnings
130+ else
131+ logd_result:: R
132+ end
133+ end
134+
135+ function logdensityof (
136+ @nospecialize (μ:: _NonBijectivePusfwdMeasure{M,<:PushfwdRootMeasure} ),
137+ @nospecialize (v:: Any )
138+ ) where {M}
139+ throw (
140+ ArgumentError (
141+ " Can't calculate densities for non-bijective pushforward measure $(nameof (M)) " ,
142+ ),
143+ )
144+ end
145+
146+ function logdensityof (
147+ @nospecialize (μ:: _NonBijectivePusfwdMeasure{M,<:AdaptRootMeasure} ),
148+ @nospecialize (v:: Any )
149+ ) where {M}
150+ throw (
151+ ArgumentError (
152+ " Can't calculate densities for non-bijective pushforward measure $(nameof (M)) " ,
153+ ),
154+ )
155+ end
156+
58157for func in [:logdensityof , :logdensity_def ]
59- @eval @inline function $func (ν:: PushforwardMeasure{F,I,M,<:WithVolCorr} , y) where {F,I,M}
60- f = ν. f
61- finv = ν. finv
62- x_orig, inv_ladj = with_logabsdet_jacobian (unwrap (finv), y)
63- logd_orig = $ func (ν. origin, x_orig)
64- logd = float (logd_orig + inv_ladj)
65- neginf = oftype (logd, - Inf )
66- return ifelse (
67- # Zero density wins against infinite volume:
68- (isnan (logd) && logd_orig == - Inf && inv_ladj == + Inf ) ||
69- # Maybe also for (logd_orig == -Inf) && isfinite(inv_ladj) ?
70- # Return constant -Inf to prevent problems with ForwardDiff:
71- (isfinite (logd_orig) && (inv_ladj == - Inf )),
72- neginf,
73- logd,
74- )
158+ @eval function $func (ν:: PushforwardMeasure{F,I,M,<:AdaptRootMeasure} , y) where {F,I,M}
159+ f_inv = unwrap (ν. finv)
160+ x, inv_ladj = with_logabsdet_jacobian (f_inv, y)
161+ logd_orig = $ func (ν. origin, x)
162+ return _combine_logd_with_ladj (logd_orig, inv_ladj)
75163 end
76164
77- @eval @inline function $func (ν:: PushforwardMeasure{F,I,M,<:NoVolCorr} , y) where {F,I,M}
78- x = ν. finv (y)
79- return $ func (ν. origin, x)
165+ @eval function $func (ν:: PushforwardMeasure{F,I,M,<:PushfwdRootMeasure} , y) where {F,I,M}
166+ f_inv = unwrap (ν. finv)
167+ x = f_inv (y)
168+ logd_orig = $ func (ν. origin, x)
169+ return logd_orig
80170 end
81171end
82172
83- insupport (ν :: PushforwardMeasure , y ) = insupport (ν . origin, ν . finv (y ))
173+ insupport (m :: PushforwardMeasure , x ) = insupport (transport_origin (m), to_origin (m, x ))
84174
85175function testvalue (:: Type{T} , ν:: PushforwardMeasure ) where {T}
86176 ν. f (testvalue (T, parent (ν)))
87177end
88178
89179@inline function basemeasure (ν:: PushforwardMeasure )
90- pushfwd (ν. f, basemeasure (parent (ν)), NoVolCorr ())
180+ pushfwd (ν. f, basemeasure (parent (ν)), PushfwdRootMeasure ())
181+ end
182+
183+ function rootmeasure (m:: PushforwardMeasure{F,I,M,PushfwdRootMeasure} ) where {F,I,M}
184+ pushfwd (m. f, rootmeasure (m. origin))
185+ end
186+ function rootmeasure (m:: PushforwardMeasure{F,I,M,AdaptRootMeasure} ) where {F,I,M}
187+ rootmeasure (m. origin)
91188end
92189
93190_pushfwd_dof (:: Type{MU} , :: Type , dof) where {MU} = NoDOF {MU} ()
94191_pushfwd_dof (:: Type{MU} , :: Type{<:Tuple{Any,Real}} , dof) where {MU} = dof
95192
96193@inline getdof (ν:: MU ) where {MU<: PushforwardMeasure } = getdof (ν. origin)
194+ @inline getdof (m:: _NonBijectivePusfwdMeasure ) = MeasureBase. NoDOF {typeof(m)} ()
97195
98196# Bypass `checked_arg`, would require potentially costly transformation:
99197@inline checked_arg (:: PushforwardMeasure , x) = x
@@ -102,47 +200,53 @@ _pushfwd_dof(::Type{MU}, ::Type{<:Tuple{Any,Real}}, dof) where {MU} = dof
102200@inline from_origin (ν:: PushforwardMeasure , x) = ν. f (x)
103201@inline to_origin (ν:: PushforwardMeasure , y) = ν. finv (y)
104202
203+ massof (m:: PushforwardMeasure ) = massof (transport_origin (m))
204+
105205function Base. rand (rng:: AbstractRNG , :: Type{T} , ν:: PushforwardMeasure ) where {T}
106- return ν. f (rand (rng, T, parent (ν) ))
206+ return ν. f (rand (rng, T, ν . origin ))
107207end
108208
109209# ##############################################################################
110210# pushfwd
111211
112- export pushfwd
113-
114212"""
115- pushfwd(f, μ, volcorr = WithVolCorr ())
213+ pushfwd(f, μ, style = AdaptRootMeasure ())
116214
117215Return the [pushforward
118216measure](https://en.wikipedia.org/wiki/Pushforward_measure) from `μ` the
119217[measurable function](https://en.wikipedia.org/wiki/Measurable_function) `f`.
120218
121219To manually specify an inverse, call
122- `pushfwd(InverseFunctions.setinverse(f, finv), μ, volcorr )`.
220+ `pushfwd(InverseFunctions.setinverse(f, finv), μ, style )`.
123221"""
124- function pushfwd (f, μ, volcorr:: TransformVolCorr = WithVolCorr ())
125- PushforwardMeasure (f, inverse (f), μ, volcorr)
126- end
127-
128- function pushfwd (f, μ:: PushforwardMeasure , volcorr:: TransformVolCorr = WithVolCorr ())
129- _pushfwd_of_pushfwd (f, μ, μ. volcorr, volcorr)
130- end
222+ function pushfwd end
223+ export pushfwd
131224
132- # Either both WithVolCorr or both NoVolCorr, so we can merge them
133- function _pushfwd_of_pushfwd (f, μ:: PushforwardMeasure , :: V , v:: V ) where {V}
134- pushfwd (fchain ((μ. f, f)), μ. origin, v)
225+ @inline pushfwd (f, μ) = _pushfwd_impl (f, μ, AdaptRootMeasure ())
226+ @inline pushfwd (f, μ, style:: AdaptRootMeasure ) = _pushfwd_impl (f, μ, style)
227+ @inline pushfwd (f, μ, style:: PushfwdRootMeasure ) = _pushfwd_impl (f, μ, style)
228+
229+ _pushfwd_impl (f, μ, style) = PushforwardMeasure (f, inverse (f), μ, style)
230+
231+ function _pushfwd_impl (
232+ f,
233+ μ:: PushforwardMeasure{F,I,M,S} ,
234+ style:: S ,
235+ ) where {F,I,M,S<: PushFwdStyle }
236+ orig_μ = μ. origin
237+ new_f = fcomp (f, μ. f)
238+ new_f_inv = fcomp (μ. finv, inverse (f))
239+ PushforwardMeasure (new_f, new_f_inv, orig_μ, style)
135240end
136241
137- function _pushfwd_of_pushfwd (f, μ:: PushforwardMeasure , _, v)
138- PushforwardMeasure (f, inverse (f), μ, v)
139- end
242+ _pushfwd_impl (:: typeof (identity), μ, :: AdaptRootMeasure ) = μ
243+ _pushfwd_impl (:: typeof (identity), μ, :: PushfwdRootMeasure ) = μ
140244
141245# ##############################################################################
142246# pullback
143247
144248"""
145- pullback (f, μ, volcorr = WithVolCorr ())
249+ pullbck (f, μ, style = AdaptRootMeasure ())
146250
147251A _pullback_ is a dual concept to a _pushforward_. While a pushforward needs a
148252map _from_ the support of a measure, a pullback requires a map _into_ the
@@ -154,8 +258,17 @@ in terms of the inverse function; the "forward" function is not used at all. In
154258some cases, we may be focusing on log-density (and not, for example, sampling).
155259
156260To manually specify an inverse, call
157- `pullback (InverseFunctions.setinverse(f, finv), μ, volcorr )`.
261+ `pullbck (InverseFunctions.setinverse(f, finv), μ, style )`.
158262"""
159- function pullback (f, μ, volcorr:: TransformVolCorr = WithVolCorr ())
160- pushfwd (setinverse (inverse (f), f), μ, volcorr)
263+ function pullbck end
264+ export pullbck
265+
266+ @inline pullbck (f, μ) = _pullback_impl (f, μ, AdaptRootMeasure ())
267+ @inline pullbck (f, μ, style:: AdaptRootMeasure ) = _pullback_impl (f, μ, style)
268+ @inline pullbck (f, μ, style:: PushfwdRootMeasure ) = _pullback_impl (f, μ, style)
269+
270+ function _pullback_impl (f, μ, style = AdaptRootMeasure ())
271+ pushfwd (setinverse (inverse (f), f), μ, style)
161272end
273+
274+ @deprecate pullback (f, μ, style:: PushFwdStyle = AdaptRootMeasure ()) pullbck (f, μ, style)
0 commit comments