Skip to content

Commit 973a5e4

Browse files
Merge pull request #241 from lxvm/infstability
Make frontend type stable
2 parents 278d70f + e6e9a5b commit 973a5e4

21 files changed

+524
-367
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ HCubature = "19dc6840-f33b-545b-b366-655c7e3ffd49"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1010
MonteCarloIntegration = "4886b29c-78c9-11e9-0a6e-41e1f4161f7b"
1111
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
12+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1213
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1314
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1415

@@ -46,8 +47,8 @@ HCubature = "1.5.2"
4647
LinearAlgebra = "1.10"
4748
MCIntegration = "0.4.2"
4849
MonteCarloIntegration = "0.2"
49-
Pkg = "1.10"
5050
QuadGK = "2.9"
51+
Random = "1.10"
5152
Reexport = "1.0"
5253
SafeTestsets = "0.1"
5354
SciMLBase = "2.24"
@@ -67,11 +68,10 @@ FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
6768
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
6869
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
6970
MCIntegration = "ea1e2de9-7db7-4b42-91ee-0cd1bf6df167"
70-
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
7171
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
7272
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
7373
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
7474
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
7575

7676
[targets]
77-
test = ["Aqua", "Arblib", "StaticArrays", "FiniteDiff", "Pkg", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore", "FastGaussQuadrature", "Cuba", "Cubature", "MCIntegration"]
77+
test = ["Aqua", "Arblib", "StaticArrays", "FiniteDiff", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore", "FastGaussQuadrature", "Cuba", "Cubature", "MCIntegration"]

docs/src/basics/solve.md

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,3 @@
33
```@docs
44
solve(prob::IntegralProblem, alg::SciMLBase.AbstractIntegralAlgorithm)
55
```
6-
7-
Additionally, the extra keyword arguments are splatted to the library calls, so
8-
see the documentation of the integrator library for all the extra details.
9-
These extra keyword arguments are not guaranteed to act uniformly.

ext/IntegralsArblibExt.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,19 @@ function Integrals.__solvebp_call(
1515

1616
if isinplace(prob)
1717
res = Acb(0)
18-
y_ = similar(prob.f.integrand_prototype, typeof(res))
18+
@assert res isa eltype(prob.f.integrand_prototype) "Arblib require inplace prototypes to store Acb elements"
19+
y_ = similar(prob.f.integrand_prototype)
1920
f_ = (y, x; kws...) -> (prob.f(y_, x, p; kws...); Arblib.set!(y, only(y_)))
2021
val = Arblib.integrate!(f_, res, lb, ub, atol = abstol, rtol = reltol,
2122
check_analytic = alg.check_analytic, take_prec = alg.take_prec,
2223
warn_on_no_convergence = alg.warn_on_no_convergence, opts = alg.opts)
23-
SciMLBase.build_solution(
24-
prob, alg, val, get_radius(val), retcode = ReturnCode.Success)
2524
else
2625
f_ = (x; kws...) -> only(prob.f(x, p; kws...))
2726
val = Arblib.integrate(f_, lb, ub, atol = abstol, rtol = reltol,
2827
check_analytic = alg.check_analytic, take_prec = alg.take_prec,
2928
warn_on_no_convergence = alg.warn_on_no_convergence, opts = alg.opts)
30-
SciMLBase.build_solution(
31-
prob, alg, val, get_radius(val), retcode = ReturnCode.Success)
3229
end
30+
SciMLBase.build_solution(prob, alg, val, get_radius(val), retcode = ReturnCode.Success)
3331
end
3432

3533
function get_radius(ball)

ext/IntegralsCubaExt.jl

Lines changed: 38 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,12 @@ function Integrals.__solvebp_call(prob::IntegralProblem, alg::AbstractCubaAlgori
1818
throw(ArgumentError("Cuba.jl only supports real-valued integrands"))
1919
# we could support other types by multiplying by the jacobian determinant at the end
2020

21-
if prob.f isa BatchIntegralFunction
22-
nvec = min(maxiters, prob.f.max_batch)
21+
f = prob.f
22+
prototype = Integrals.get_prototype(prob)
23+
if f isa BatchIntegralFunction
24+
fsize = size(prototype)[begin:(end - 1)]
25+
ncomp = prod(fsize)
26+
nvec = min(maxiters, f.max_batch)
2327
# nvec == 1 in Cuba will change vectors to matrices, so we won't support it when
2428
# batching
2529
nvec > 1 ||
@@ -33,24 +37,21 @@ function Integrals.__solvebp_call(prob::IntegralProblem, alg::AbstractCubaAlgori
3337
scale = x -> scale_x!(view(_x, :, 1:size(x, 2)), ub, lb, x)
3438
end
3539

36-
if isinplace(prob)
37-
fsize = size(prob.f.integrand_prototype)[begin:(end - 1)]
38-
y = similar(prob.f.integrand_prototype, fsize..., nvec)
39-
ax = map(_ -> (:), fsize)
40-
f = function (x, dx)
41-
dy = @view(y[ax..., begin:(begin + size(dx, 2) - 1)])
42-
prob.f(dy, scale(x), p)
43-
dx .= reshape(dy, :, size(dx, 2)) .* vol
40+
if isinplace(f)
41+
ax = ntuple(_ -> (:), length(fsize))
42+
_f = let y_ = similar(prototype, fsize..., nvec)
43+
function (u, _y)
44+
y = @view(y_[ax..., begin:(begin + size(_y, 2) - 1)])
45+
f(y, scale(u), p)
46+
_y .= reshape(y, size(_y)) .* vol
47+
end
4448
end
4549
else
46-
y = mid isa Number ? prob.f(typeof(mid)[], p) :
47-
prob.f(Matrix{typeof(mid)}(undef, length(mid), 0), p)
48-
fsize = size(y)[begin:(end - 1)]
49-
f = (x, dx) -> dx .= reshape(prob.f(scale(x), p), :, size(dx, 2)) .* vol
50+
_f = (u, y) -> y .= reshape(f(scale(u), p), size(y)) .* vol
5051
end
51-
ncomp = prod(fsize)
5252
else
5353
nvec = 1
54+
ncomp = length(prototype)
5455

5556
if mid isa Real
5657
scale = x -> scale_x(ub, lb, only(x))
@@ -59,58 +60,60 @@ function Integrals.__solvebp_call(prob::IntegralProblem, alg::AbstractCubaAlgori
5960
scale = x -> scale_x!(_x, ub, lb, x)
6061
end
6162

62-
if isinplace(prob)
63-
y = similar(prob.f.integrand_prototype)
64-
f = (x, dx) -> dx .= vec(prob.f(y, scale(x), p)) .* vol
63+
if isinplace(f)
64+
_f = let y = similar(prototype)
65+
(u, _y) -> begin
66+
f(y, scale(u), p)
67+
_y .= vec(y) .* vol
68+
end
69+
end
6570
else
66-
y = prob.f(mid, p)
67-
f = (x, dx) -> dx .= Iterators.flatten(prob.f(scale(x), p)) .* vol
71+
_f = (u, y) -> y .= Iterators.flatten(f(scale(u), p)) .* vol
6872
end
69-
ncomp = length(y)
7073
end
7174

72-
if alg isa CubaVegas
73-
out = Cuba.vegas(f, ndim, ncomp; rtol = reltol,
75+
out = if alg isa CubaVegas
76+
Cuba.vegas(_f, ndim, ncomp; rtol = reltol,
7477
atol = abstol, nvec = nvec,
7578
maxevals = maxiters,
7679
flags = alg.flags, seed = alg.seed, minevals = alg.minevals,
7780
nstart = alg.nstart, nincrease = alg.nincrease,
7881
gridno = alg.gridno)
7982
elseif alg isa CubaSUAVE
80-
out = Cuba.suave(f, ndim, ncomp; rtol = reltol,
83+
Cuba.suave(_f, ndim, ncomp; rtol = reltol,
8184
atol = abstol, nvec = nvec,
8285
maxevals = maxiters,
8386
flags = alg.flags, seed = alg.seed, minevals = alg.minevals,
8487
nnew = alg.nnew, nmin = alg.nmin, flatness = alg.flatness)
8588
elseif alg isa CubaDivonne
86-
out = Cuba.divonne(f, ndim, ncomp; rtol = reltol,
89+
Cuba.divonne(_f, ndim, ncomp; rtol = reltol,
8790
atol = abstol, nvec = nvec,
8891
maxevals = maxiters,
8992
flags = alg.flags, seed = alg.seed, minevals = alg.minevals,
9093
key1 = alg.key1, key2 = alg.key2, key3 = alg.key3,
9194
maxpass = alg.maxpass, border = alg.border,
9295
maxchisq = alg.maxchisq, mindeviation = alg.mindeviation)
9396
elseif alg isa CubaCuhre
94-
out = Cuba.cuhre(f, ndim, ncomp; rtol = reltol,
97+
Cuba.cuhre(_f, ndim, ncomp; rtol = reltol,
9598
atol = abstol, nvec = nvec,
9699
maxevals = maxiters,
97100
flags = alg.flags, minevals = alg.minevals, key = alg.key)
98101
end
99102

100103
# out.integral is a Vector{Float64}, but we want to return it to the shape of the integrand
101-
if prob.f isa BatchIntegralFunction
102-
if y isa AbstractVector
103-
val = out.integral[1]
104+
val = if f isa BatchIntegralFunction
105+
if prototype isa AbstractVector
106+
out.integral[1]
104107
else
105-
val = reshape(out.integral, fsize)
108+
reshape(out.integral, fsize)
106109
end
107110
else
108-
if y isa Real
109-
val = out.integral[1]
110-
elseif y isa AbstractVector
111-
val = out.integral
111+
if prototype isa Real
112+
out.integral[1]
113+
elseif prototype isa AbstractVector
114+
out.integral
112115
else
113-
val = reshape(out.integral, size(y))
116+
reshape(out.integral, size(prototype))
114117
end
115118
end
116119

ext/IntegralsCubatureExt.jl

Lines changed: 50 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -13,143 +13,141 @@ function Integrals.__solvebp_call(prob::IntegralProblem,
1313
mid = (lb + ub) / 2
1414

1515
# we get to pick fdim or not based on the IntegralFunction and its output dimensions
16-
y = if prob.f isa BatchIntegralFunction
17-
isinplace(prob.f) ? prob.f.integrand_prototype :
18-
mid isa Number ? prob.f(eltype(mid)[], p) :
19-
prob.f(Matrix{eltype(mid)}(undef, length(mid), 0), p)
20-
else
21-
# we evaluate the oop function to decide whether the output should be vectorized
22-
isinplace(prob.f) ? prob.f.integrand_prototype : prob.f(mid, p)
23-
end
16+
f = prob.f
17+
prototype = Integrals.get_prototype(prob)
2418

25-
@assert eltype(y)<:Real "Cubature.jl is only compatible with real-valued integrands"
19+
@assert eltype(prototype)<:Real "Cubature.jl is only compatible with real-valued integrands"
2620

27-
if prob.f isa BatchIntegralFunction
28-
if y isa AbstractVector # this branch could be omitted since the following one should work similarly
29-
if isinplace(prob)
21+
if f isa BatchIntegralFunction
22+
if prototype isa AbstractVector # this branch could be omitted since the following one should work similarly
23+
if isinplace(f)
3024
# dx is a Vector, but we provide the integrand a vector of the same type as
3125
# y, which needs to be resized since the number of batch points changes.
32-
dy = similar(y)
33-
f = (x, dx) -> begin
34-
resize!(dy, length(dx))
35-
prob.f(dy, x, p)
36-
dx .= dy
26+
_f = let y = similar(prototype)
27+
(u, v) -> begin
28+
resize!(y, length(v))
29+
f(y, u, p)
30+
v .= y
31+
end
3732
end
3833
else
39-
f = (x, dx) -> (dx .= prob.f(x, p))
34+
_f = (u, v) -> (v .= f(u, p))
4035
end
4136
if mid isa Number
4237
if alg isa CubatureJLh
43-
val, err = Cubature.hquadrature_v(f, lb, ub;
38+
val, err = Cubature.hquadrature_v(_f, lb, ub;
4439
reltol = reltol, abstol = abstol,
4540
maxevals = maxiters)
4641
else
47-
val, err = Cubature.pquadrature_v(f, lb, ub;
42+
val, err = Cubature.pquadrature_v(_f, lb, ub;
4843
reltol = reltol, abstol = abstol,
4944
maxevals = maxiters)
5045
end
5146
else
5247
if alg isa CubatureJLh
53-
val, err = Cubature.hcubature_v(f, lb, ub;
48+
val, err = Cubature.hcubature_v(_f, lb, ub;
5449
reltol = reltol, abstol = abstol,
5550
maxevals = maxiters)
5651
else
57-
val, err = Cubature.pcubature_v(f, lb, ub;
52+
val, err = Cubature.pcubature_v(_f, lb, ub;
5853
reltol = reltol, abstol = abstol,
5954
maxevals = maxiters)
6055
end
6156
end
62-
elseif y isa AbstractArray
63-
bfsize = size(y)[begin:(end - 1)]
64-
bfdim = prod(bfsize)
65-
if isinplace(prob)
57+
elseif prototype isa AbstractArray
58+
fsize = size(prototype)[begin:(end - 1)]
59+
fdim = prod(fsize)
60+
if isinplace(f)
6661
# dx is a Matrix, but to provide a buffer of the same type as y, we make
6762
# would like to make views of a larger buffer, but CubatureJL doesn't set
6863
# a hard limit for max_batch, so we allocate a new buffer with the needed size
69-
f = (x, dx) -> begin
70-
dy = similar(y, bfsize..., size(dx, 2))
71-
prob.f(dy, x, p)
72-
dx .= reshape(dy, bfdim, size(dx, 2))
64+
_f = let fsize = fsize
65+
(u, v) -> begin
66+
y = similar(prototype, fsize..., size(v, 2))
67+
f(y, u, p)
68+
v .= reshape(y, fdim, size(v, 2))
69+
end
7370
end
7471
else
75-
f = (x, dx) -> (dx .= reshape(prob.f(x, p), bfdim, size(dx, 2)))
72+
_f = (u, v) -> (v .= reshape(f(u, p), fdim, size(v, 2)))
7673
end
7774
if mid isa Number
7875
if alg isa CubatureJLh
79-
val_, err = Cubature.hquadrature_v(bfdim, f, lb, ub;
76+
val_, err = Cubature.hquadrature_v(fdim, _f, lb, ub;
8077
reltol = reltol, abstol = abstol,
8178
maxevals = maxiters, error_norm = alg.error_norm)
8279
else
83-
val_, err = Cubature.pquadrature_v(bfdim, f, lb, ub;
80+
val_, err = Cubature.pquadrature_v(fdim, _f, lb, ub;
8481
reltol = reltol, abstol = abstol,
8582
maxevals = maxiters, error_norm = alg.error_norm)
8683
end
8784
else
8885
if alg isa CubatureJLh
89-
val_, err = Cubature.hcubature_v(bfdim, f, lb, ub;
86+
val_, err = Cubature.hcubature_v(fdim, _f, lb, ub;
9087
reltol = reltol, abstol = abstol,
9188
maxevals = maxiters, error_norm = alg.error_norm)
9289
else
93-
val_, err = Cubature.pcubature_v(bfdim, f, lb, ub;
90+
val_, err = Cubature.pcubature_v(fdim, _f, lb, ub;
9491
reltol = reltol, abstol = abstol,
9592
maxevals = maxiters, error_norm = alg.error_norm)
9693
end
9794
end
98-
val = reshape(val_, bfsize...)
95+
val = reshape(val_, fsize...)
9996
else
10097
error("BatchIntegralFunction integrands must be arrays for Cubature.jl")
10198
end
10299
else
103-
if y isa Real
100+
if prototype isa Real
104101
# no inplace in this case, since the integrand_prototype would be mutable
105-
f = x -> prob.f(x, p)
102+
_f = u -> f(u, p)
106103
if lb isa Number
107104
if alg isa CubatureJLh
108-
val, err = Cubature.hquadrature(f, lb, ub;
105+
val, err = Cubature.hquadrature(_f, lb, ub;
109106
reltol = reltol, abstol = abstol,
110107
maxevals = maxiters)
111108
else
112-
val, err = Cubature.pquadrature(f, lb, ub;
109+
val, err = Cubature.pquadrature(_f, lb, ub;
113110
reltol = reltol, abstol = abstol,
114111
maxevals = maxiters)
115112
end
116113
else
117114
if alg isa CubatureJLh
118-
val, err = Cubature.hcubature(f, lb, ub;
115+
val, err = Cubature.hcubature(_f, lb, ub;
119116
reltol = reltol, abstol = abstol,
120117
maxevals = maxiters)
121118
else
122-
val, err = Cubature.pcubature(f, lb, ub;
119+
val, err = Cubature.pcubature(_f, lb, ub;
123120
reltol = reltol, abstol = abstol,
124121
maxevals = maxiters)
125122
end
126123
end
127-
elseif y isa AbstractArray
128-
fsize = size(y)
129-
fdim = length(y)
124+
elseif prototype isa AbstractArray
125+
fsize = size(prototype)
126+
fdim = length(prototype)
130127
if isinplace(prob)
131-
dy = similar(y)
132-
f = (x, v) -> (prob.f(dy, x, p); v .= vec(dy))
128+
_f = let y = similar(prototype)
129+
(u, v) -> (f(y, u, p); v .= vec(y))
130+
end
133131
else
134-
f = (x, v) -> (v .= vec(prob.f(x, p)))
132+
_f = (u, v) -> (v .= vec(f(u, p)))
135133
end
136134
if mid isa Number
137135
if alg isa CubatureJLh
138-
val_, err = Cubature.hquadrature(fdim, f, lb, ub;
136+
val_, err = Cubature.hquadrature(fdim, _f, lb, ub;
139137
reltol = reltol, abstol = abstol,
140138
maxevals = maxiters, error_norm = alg.error_norm)
141139
else
142-
val_, err = Cubature.pquadrature(fdim, f, lb, ub;
140+
val_, err = Cubature.pquadrature(fdim, _f, lb, ub;
143141
reltol = reltol, abstol = abstol,
144142
maxevals = maxiters, error_norm = alg.error_norm)
145143
end
146144
else
147145
if alg isa CubatureJLh
148-
val_, err = Cubature.hcubature(fdim, f, lb, ub;
146+
val_, err = Cubature.hcubature(fdim, _f, lb, ub;
149147
reltol = reltol, abstol = abstol,
150148
maxevals = maxiters, error_norm = alg.error_norm)
151149
else
152-
val_, err = Cubature.pcubature(fdim, f, lb, ub;
150+
val_, err = Cubature.pcubature(fdim, _f, lb, ub;
153151
reltol = reltol, abstol = abstol,
154152
maxevals = maxiters, error_norm = alg.error_norm)
155153
end

0 commit comments

Comments
 (0)