
Commit 641ee23

Benchmark radon model function
1 parent 78125ed commit 641ee23

5 files changed: +294 -1 lines changed

conftest.py

Lines changed: 4 additions & 0 deletions
@@ -3,6 +3,10 @@
 import pytest
 
 
+# Using pytest_plugins causes `tests/link/c/test_cmodule.py::test_cache_versioning` to fail
+# pytest_plugins = ["tests.fixtures"]
+
+
 def pytest_sessionstart(session):
     os.environ["PYTENSOR_FLAGS"] = ",".join(
         [
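Because the global `pytest_plugins` registration is commented out, each test module that needs the shared fixtures pulls them in with a star import instead (see the additions to tests/compile/function/test_types.py and tests/link/numba/test_performance.py below). A minimal, hypothetical sketch of that pattern (the test below is illustrative only, not part of this commit):

    # Hypothetical test module: the star import puts the session-scoped fixtures
    # (radon_model, radon_model_variants) into this module's namespace, so pytest
    # can discover them without a global pytest_plugins entry.
    from tests.fixtures import *  # noqa: F403


    def test_uses_radon_model(radon_model):
        joined_inputs, (model_logp, model_dlogp) = radon_model
        assert joined_inputs.type.shape is not None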

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -150,7 +150,7 @@ lines-after-imports = 2
 "pytensor/misc/check_duplicate_key.py" = ["T201"]
 "pytensor/misc/check_blas.py" = ["T201"]
 "pytensor/bin/pytensor_cache.py" = ["T201"]
-# For the tests we skip because `pytest.importorskip` is used:
+# For the tests we skip `E402` because `pytest.importorskip` is used:
 "tests/link/jax/test_scalar.py" = ["E402"]
 "tests/link/jax/test_tensor_basic.py" = ["E402"]
 "tests/link/numba/test_basic.py" = ["E402"]

tests/compile/function/test_types.py

Lines changed: 65 additions & 0 deletions
@@ -33,6 +33,7 @@
     scalars,
     vector,
 )
+from tests.fixtures import *  # noqa: F403


 pytestmark = pytest.mark.filterwarnings("error")
@@ -1357,3 +1358,67 @@ def test_minimal_random_function_call_benchmark(trust_input, benchmark):
 
     rng_val = np.random.default_rng()
     benchmark(f, rng_val)
+
+
+@pytest.mark.parametrize("mode", ["C", "C_VM"])
+def test_radon_model_compile_repeatedly_benchmark(mode, radon_model, benchmark):
+    joined_inputs, [model_logp, model_dlogp] = radon_model
+    rng = np.random.default_rng(1)
+    x = rng.normal(size=joined_inputs.type.shape).astype(config.floatX)
+
+    def compile_and_call_once():
+        fn = function(
+            [joined_inputs], [model_logp, model_dlogp], mode=mode, trust_input=True
+        )
+        fn(x)
+
+    benchmark.pedantic(compile_and_call_once, rounds=5, iterations=1)
+
+
+@pytest.mark.parametrize("mode", ["C", "C_VM"])
+def test_radon_model_compile_variants_benchmark(
+    mode, radon_model, radon_model_variants, benchmark
+):
+    """Test compilation speed when a slight variant of the function is compiled each time.
+
+    This test more realistically simulates a use case where a model is recompiled
+    multiple times with small changes, such as in an interactive environment.
+
+    NOTE: For this test to be meaningful on subsequent runs, the cache must be cleared.
+    """
+    joined_inputs, [model_logp, model_dlogp] = radon_model
+    rng = np.random.default_rng(1)
+    x = rng.normal(size=joined_inputs.type.shape).astype(config.floatX)
+
+    # Compile base function once to populate the cache
+    fn = function(
+        [joined_inputs], [model_logp, model_dlogp], mode=mode, trust_input=True
+    )
+    fn(x)
+
+    def compile_and_call_once():
+        for joined_inputs, [model_logp, model_dlogp] in radon_model_variants:
+            fn = function(
+                [joined_inputs], [model_logp, model_dlogp], mode=mode, trust_input=True
+            )
+            fn(x)
+
+    benchmark.pedantic(compile_and_call_once, rounds=1, iterations=1)
+
+
+@pytest.mark.parametrize("mode", ["C", "C_VM", "C_VM_NOGC"])
+def test_radon_model_call_benchmark(mode, radon_model, benchmark):
+    joined_inputs, [model_logp, model_dlogp] = radon_model
+
+    real_mode = "C_VM" if mode == "C_VM_NOGC" else mode
+    fn = function(
+        [joined_inputs], [model_logp, model_dlogp], mode=real_mode, trust_input=True
+    )
+    if mode == "C_VM_NOGC":
+        fn.vm.allow_gc = False
+
+    rng = np.random.default_rng(1)
+    x = rng.normal(size=joined_inputs.type.shape).astype(config.floatX)
+    fn(x)  # warmup
+
+    benchmark(fn, x)
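The NOTE in `test_radon_model_compile_variants_benchmark` means the variant numbers are only comparable across runs if the compilation cache is emptied first; this commit does not automate that. One possible way to do it by hand (an assumption, not part of the diff) is to delete the directory named by `config.compiledir`, which is the same cache that the `pytensor/bin/pytensor_cache.py` script listed in pyproject.toml manages:

    # Assumed snippet, not part of this commit: wipe the C compilation cache so the
    # *_compile_variants_* benchmarks compile every variant from scratch again.
    import shutil

    from pytensor import config

    shutil.rmtree(config.compiledir, ignore_errors=True)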

tests/fixtures.py

Lines changed: 156 additions & 0 deletions
@@ -0,0 +1,156 @@
+import numpy as np
+import pytest
+
+import pytensor.tensor as pt
+from pytensor.graph.replace import graph_replace
+from pytensor.graph.rewriting import rewrite_graph
+from pytensor.graph.traversal import explicit_graph_inputs
+
+
+def create_radon_model(
+    intercept_dist="normal", sigma_dist="halfnormal", centered=False
+):
+    def halfnormal(name, *, sigma=1.0, model_logp):
+        log_value = pt.scalar(f"{name}_log")
+        value = pt.exp(log_value)
+
+        logp = (
+            -0.5 * ((value / sigma) ** 2) + pt.log(pt.sqrt(2.0 / np.pi)) - pt.log(sigma)
+        )
+        logp = pt.switch(value >= 0, logp, -np.inf)
+        model_logp.append(logp + value)
+        return value
+
+    def normal(name, *, mu=0.0, sigma=1.0, model_logp, observed=None):
+        value = pt.scalar(name) if observed is None else pt.as_tensor(observed)
+
+        logp = (
+            -0.5 * (((value - mu) / sigma) ** 2)
+            - pt.log(pt.sqrt(2.0 * np.pi))
+            - pt.log(sigma)
+        )
+        model_logp.append(logp)
+        return value
+
+    def lognormal(name, *, mu=0.0, sigma=1.0, model_logp):
+        value = normal(name, mu=mu, sigma=sigma, model_logp=model_logp)
+        return pt.exp(value)
+
+    def zerosumnormal(name, *, sigma=1.0, size, model_logp):
+        raw_value = pt.vector(f"{name}_zerosum", shape=(size - 1,))
+        n = raw_value.shape[0] + 1
+        sum_vals = raw_value.sum(0, keepdims=True)
+        norm = sum_vals / (pt.sqrt(n) + n)
+        fill_value = norm - sum_vals / pt.sqrt(n)
+        value = pt.concatenate([raw_value, fill_value]) - norm
+
+        shape = value.shape
+        _full_size = pt.prod(shape)
+        _degrees_of_freedom = pt.prod(shape[-1:].inc(-1))
+        logp = pt.sum(
+            -0.5 * ((value / sigma) ** 2)
+            - (pt.log(pt.sqrt(2.0 * np.pi)) + pt.log(sigma))
+            * (_degrees_of_freedom / _full_size)
+        )
+        model_logp.append(logp)
+        return value
+
+    dist_fn_map = {
+        fn.__name__: fn for fn in (halfnormal, normal, lognormal, zerosumnormal)
+    }
+
+    rng = np.random.default_rng(1)
+    n_counties = 85
+    county_idx = rng.integers(n_counties, size=919)
+    county_idx.sort()
+    floor = rng.binomial(n=1, p=0.5, size=919).astype(np.float64)
+    log_radon = rng.normal(size=919)
+
+    model_logp = []
+    intercept = dist_fn_map[intercept_dist](
+        "intercept", sigma=10, model_logp=model_logp
+    )
+
+    # County effects
+    county_sd = halfnormal("county_sd", model_logp=model_logp)
+    if centered:
+        county_effect = zerosumnormal(
+            "county_raw", sigma=county_sd, size=n_counties, model_logp=model_logp
+        )
+    else:
+        county_raw = zerosumnormal("county_raw", size=n_counties, model_logp=model_logp)
+        county_effect = county_raw * county_sd
+
+    # Global floor effect
+    floor_effect = normal("floor_effect", sigma=2, model_logp=model_logp)
+
+    county_floor_sd = halfnormal("county_floor_sd", model_logp=model_logp)
+    if centered:
+        county_floor_effect = zerosumnormal(
+            "county_floor_raw",
+            sigma=county_floor_sd,
+            size=n_counties,
+            model_logp=model_logp,
+        )
+    else:
+        county_floor_raw = zerosumnormal(
+            "county_floor_raw", size=n_counties, model_logp=model_logp
+        )
+        county_floor_effect = county_floor_raw * county_floor_sd
+
+    mu = (
+        intercept
+        + county_effect[county_idx]
+        + floor_effect * floor
+        + county_floor_effect[county_idx] * floor
+    )
+
+    sigma = dist_fn_map[sigma_dist]("sigma", model_logp=model_logp)
+    _ = normal(
+        "log_radon",
+        mu=mu,
+        sigma=sigma,
+        observed=log_radon,
+        model_logp=model_logp,
+    )
+
+    model_logp = pt.sum([logp.sum() for logp in model_logp])
+    model_logp = rewrite_graph(
+        model_logp, include=("canonicalize", "stabilize"), clone=False
+    )
+    params = list(explicit_graph_inputs(model_logp))
+    model_dlogp = pt.concatenate([term.ravel() for term in pt.grad(model_logp, params)])
+
+    size = sum(int(np.prod(p.type.shape)) for p in params)
+    joined_inputs = pt.vector("joined_inputs", shape=(size,))
+    idx = 0
+    replacement = {}
+    for param in params:
+        param_shape = param.type.shape
+        param_size = int(np.prod(param_shape))
+        replacement[param] = joined_inputs[idx : idx + param_size].reshape(param_shape)
+        idx += param_size
+    assert idx == joined_inputs.type.shape[0]
+
+    model_logp, model_dlogp = graph_replace([model_logp, model_dlogp], replacement)
+    return joined_inputs, [model_logp, model_dlogp]
+
+
+@pytest.fixture(scope="session")
+def radon_model():
+    return create_radon_model()
+
+
+@pytest.fixture(scope="session")
+def radon_model_variants():
+    return [
+        create_radon_model(
+            intercept_dist=intercept_dist,
+            sigma_dist=sigma_dist,
+            centered=centered,
+        )
+        for centered in (True, False)
+        for intercept_dist in ("normal", "lognormal")
+        for sigma_dist in ("halfnormal", "lognormal")
+    ]
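For orientation, the factory returns a flat input vector together with the model's log-probability and gradient graphs, and the benchmarks above compile those into a single function. A standalone sketch of the same usage, assuming it runs from the repository root so `tests.fixtures` is importable:

    import numpy as np

    from pytensor import config, function
    from tests.fixtures import create_radon_model

    joined_inputs, (model_logp, model_dlogp) = create_radon_model()
    fn = function(
        [joined_inputs], [model_logp, model_dlogp], mode="C_VM", trust_input=True
    )

    x = np.zeros(joined_inputs.type.shape, dtype=config.floatX)
    logp_val, dlogp_val = fn(x)  # scalar logp and a gradient vector as long as x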

tests/link/numba/test_performance.py

Lines changed: 68 additions & 0 deletions
@@ -13,6 +13,7 @@
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
 from pytensor.link.numba.linker import NumbaLinker
 from pytensor.tensor.math import Max
+from tests.fixtures import *  # noqa: F403


 opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
@@ -75,3 +76,70 @@ def test_careduce_performance(careduce_fn, numpy_fn, axis, inputs, input_vals):
 
     # FIXME: Why are we asserting >=? Numba could be doing worse than numpy!
     assert mean_numba_time / mean_numpy_time >= 0.75
+
+
+@pytest.mark.parametrize("cache", (False, True))
+def test_radon_model_compile_repeatedly_numba_benchmark(cache, radon_model, benchmark):
+    joined_inputs, [model_logp, model_dlogp] = radon_model
+    rng = np.random.default_rng(1)
+    x = rng.normal(size=joined_inputs.type.shape).astype(config.floatX)
+
+    def compile_and_call_once():
+        with config.change_flags(numba__cache=cache):
+            fn = function(
+                [joined_inputs],
+                [model_logp, model_dlogp],
+                mode="NUMBA",
+                trust_input=True,
+            )
+            fn(x)
+
+    benchmark.pedantic(compile_and_call_once, rounds=5, iterations=1)
+
+
+@pytest.mark.parametrize("cache", (False, True))
+def test_radon_model_compile_variants_numba_benchmark(
+    cache, radon_model, radon_model_variants, benchmark
+):
+    """Test compilation speed when a slight variant of the function is compiled each time.
+
+    This test more realistically simulates a use case where a model is recompiled
+    multiple times with small changes, such as in an interactive environment.
+
+    NOTE: For this test to be meaningful on subsequent runs, the cache must be cleared.
+    """
+    joined_inputs, [model_logp, model_dlogp] = radon_model
+    rng = np.random.default_rng(1)
+    x = rng.normal(size=joined_inputs.type.shape).astype(config.floatX)
+
+    # Compile base function once to populate the cache
+    fn = function(
+        [joined_inputs], [model_logp, model_dlogp], mode="NUMBA", trust_input=True
+    )
+    fn(x)
+
+    def compile_and_call_once():
+        with config.change_flags(numba__cache=cache):
+            for joined_inputs, [model_logp, model_dlogp] in radon_model_variants:
+                fn = function(
+                    [joined_inputs],
+                    [model_logp, model_dlogp],
+                    mode="NUMBA",
+                    trust_input=True,
+                )
+                fn(x)
+
+    benchmark.pedantic(compile_and_call_once, rounds=1, iterations=1)
+
+
+def test_radon_model_call_numba_benchmark(radon_model, benchmark):
+    joined_inputs, [model_logp, model_dlogp] = radon_model
+
+    fn = function(
+        [joined_inputs], [model_logp, model_dlogp], mode="NUMBA", trust_input=True
+    )
+    rng = np.random.default_rng(1)
+    x = rng.normal(size=joined_inputs.type.shape).astype(config.floatX)
+    fn(x)  # warmup
+
+    benchmark(fn, x)
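A possible way to run just these benchmarks from Python, assuming pytest-benchmark is installed (it supplies the `benchmark` fixture used throughout and the `--benchmark-only` option):

    # Hypothetical driver: collect only the radon benchmarks and skip ordinary tests.
    import pytest

    pytest.main(
        [
            "tests/compile/function/test_types.py",
            "tests/link/numba/test_performance.py",
            "-k", "radon",
            "--benchmark-only",
        ]
    )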
