Skip to content

Commit 3786b8d

Browse files
committed
Benchmark radon model function
1 parent 78125ed commit 3786b8d

File tree

5 files changed

+293
-3
lines changed

5 files changed

+293
-3
lines changed

conftest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
import pytest
44

55

6+
pytest_plugins = ["tests.shared_fixtures"]
7+
8+
69
def pytest_sessionstart(session):
710
os.environ["PYTENSOR_FLAGS"] = ",".join(
811
[

pytensor/compile/mode.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
465465
C_VM = Mode("cvm", "fast_run")
466466

467467
NUMBA = Mode(
468-
NumbaLinker(),
468+
"numba",
469469
RewriteDatabaseQuery(
470470
include=["fast_run", "numba"],
471471
exclude=[
@@ -478,7 +478,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
478478
)
479479

480480
JAX = Mode(
481-
JAXLinker(),
481+
"jax",
482482
RewriteDatabaseQuery(
483483
include=["fast_run", "jax"],
484484
exclude=[
@@ -494,7 +494,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
494494
),
495495
)
496496
PYTORCH = Mode(
497-
PytorchLinker(),
497+
"pytorch",
498498
RewriteDatabaseQuery(
499499
include=["fast_run"],
500500
exclude=[

tests/compile/function/test_types.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1357,3 +1357,67 @@ def test_minimal_random_function_call_benchmark(trust_input, benchmark):
13571357

13581358
rng_val = np.random.default_rng()
13591359
benchmark(f, rng_val)
1360+
1361+
1362+
@pytest.mark.parametrize("mode", ["C", "C_VM"])
def test_radon_model_compile_repeatedly_benchmark(mode, radon_model, benchmark):
    """Benchmark compiling and calling the very same radon function repeatedly.

    Every timed round builds a new function from the identical logp/dlogp
    graph in the given ``mode`` and evaluates it once.
    """
    flat_params, (logp, dlogp) = radon_model
    test_point = (
        np.random.default_rng(1)
        .normal(size=flat_params.type.shape)
        .astype(config.floatX)
    )

    def compile_and_call_once():
        compiled = function([flat_params], [logp, dlogp], mode=mode, trust_input=True)
        compiled(test_point)

    # One iteration per round, so each measurement is a single compile + call
    benchmark.pedantic(compile_and_call_once, rounds=5, iterations=1)
1375+
1376+
1377+
@pytest.mark.parametrize("mode", ["C", "C_VM"])
def test_radon_model_compile_variants_benchmark(
    mode, radon_model, radon_model_variants, benchmark
):
    """Benchmark compilation when a slight variant of a function is compiled each time.

    This simulates, more realistically, a use case where a model is recompiled
    multiple times with small changes, such as in an interactive environment.

    NOTE: For this test to be meaningful on subsequent runs, the cache must be
    cleared beforehand.
    """
    base_inputs, (base_logp, base_dlogp) = radon_model
    test_point = (
        np.random.default_rng(1)
        .normal(size=base_inputs.type.shape)
        .astype(config.floatX)
    )

    # Compile and run the base function once, outside the timed region,
    # to populate the cache
    base_fn = function(
        [base_inputs], [base_logp, base_dlogp], mode=mode, trust_input=True
    )
    base_fn(test_point)

    def compile_and_call_once():
        # The test point drawn for the base model is fed to every variant
        for variant_inputs, (variant_logp, variant_dlogp) in radon_model_variants:
            variant_fn = function(
                [variant_inputs],
                [variant_logp, variant_dlogp],
                mode=mode,
                trust_input=True,
            )
            variant_fn(test_point)

    benchmark.pedantic(compile_and_call_once, rounds=1, iterations=1)
1406+
1407+
1408+
@pytest.mark.parametrize("mode", ["C", "C_VM", "C_VM_NOGC"])
def test_radon_model_call_benchmark(mode, radon_model, benchmark):
    """Benchmark calling the already-compiled radon logp/dlogp function.

    ``"C_VM_NOGC"`` is not passed to ``function`` as a mode: the function is
    compiled with ``"C_VM"`` and ``allow_gc`` is then switched off on its VM.
    """
    flat_params, (logp, dlogp) = radon_model

    disable_gc = mode == "C_VM_NOGC"
    compile_mode = "C_VM" if disable_gc else mode
    compiled = function(
        [flat_params], [logp, dlogp], mode=compile_mode, trust_input=True
    )
    if disable_gc:
        compiled.vm.allow_gc = False

    test_point = (
        np.random.default_rng(1)
        .normal(size=flat_params.type.shape)
        .astype(config.floatX)
    )
    # Warm-up call, kept out of the timed runs
    compiled(test_point)

    benchmark(compiled, test_point)

tests/link/numba/test_performance.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,70 @@ def test_careduce_performance(careduce_fn, numpy_fn, axis, inputs, input_vals):
7575

7676
# FIXME: Why are we asserting >=? Numba could be doing worse than numpy!
7777
assert mean_numba_time / mean_numpy_time >= 0.75
78+
79+
80+
@pytest.mark.parametrize("cache", (False, True))
def test_radon_model_compile_repeatedly_numba_benchmark(cache, radon_model, benchmark):
    """Benchmark compiling and calling the same radon function with the NUMBA mode.

    Every timed round builds a new function from the identical logp/dlogp
    graph and evaluates it once, with the ``numba__cache`` flag set to ``cache``.
    """
    flat_params, (logp, dlogp) = radon_model
    test_point = (
        np.random.default_rng(1)
        .normal(size=flat_params.type.shape)
        .astype(config.floatX)
    )

    def compile_and_call_once():
        # Both the compilation and the first call run under the cache flag
        with config.change_flags(numba__cache=cache):
            compiled = function(
                [flat_params], [logp, dlogp], mode="NUMBA", trust_input=True
            )
            compiled(test_point)

    # One iteration per round, so each measurement is a single compile + call
    benchmark.pedantic(compile_and_call_once, rounds=5, iterations=1)
97+
98+
99+
@pytest.mark.parametrize("cache", (False, True))
def test_radon_model_compile_variants_numba_benchmark(
    cache, radon_model, radon_model_variants, benchmark
):
    """Benchmark compilation when a slight variant of a function is compiled each time.

    This simulates, more realistically, a use case where a model is recompiled
    multiple times with small changes, such as in an interactive environment.

    NOTE: For this test to be meaningful on subsequent runs, the cache must be
    cleared beforehand.
    """
    base_inputs, (base_logp, base_dlogp) = radon_model
    test_point = (
        np.random.default_rng(1)
        .normal(size=base_inputs.type.shape)
        .astype(config.floatX)
    )

    # Compile and run the base function once, outside the timed region and
    # outside the ``numba__cache`` override, to populate the cache
    base_fn = function(
        [base_inputs], [base_logp, base_dlogp], mode="NUMBA", trust_input=True
    )
    base_fn(test_point)

    def compile_and_call_once():
        with config.change_flags(numba__cache=cache):
            # The test point drawn for the base model is fed to every variant
            for variant_inputs, (variant_logp, variant_dlogp) in radon_model_variants:
                variant_fn = function(
                    [variant_inputs],
                    [variant_logp, variant_dlogp],
                    mode="NUMBA",
                    trust_input=True,
                )
                variant_fn(test_point)

    benchmark.pedantic(compile_and_call_once, rounds=1, iterations=1)
132+
133+
134+
def test_radon_model_call_numba_benchmark(radon_model, benchmark):
    """Benchmark calling the radon logp/dlogp function compiled with the NUMBA mode."""
    flat_params, (logp, dlogp) = radon_model

    compiled = function([flat_params], [logp, dlogp], mode="NUMBA", trust_input=True)
    test_point = (
        np.random.default_rng(1)
        .normal(size=flat_params.type.shape)
        .astype(config.floatX)
    )
    # Warm-up call, kept out of the timed runs
    compiled(test_point)

    benchmark(compiled, test_point)

tests/shared_fixtures.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
import numpy as np
2+
import pytest
3+
4+
import pytensor.tensor as pt
5+
from pytensor.graph.replace import graph_replace
6+
from pytensor.graph.rewriting import rewrite_graph
7+
from pytensor.graph.traversal import explicit_graph_inputs
8+
9+
10+
def create_radon_model(
11+
intercept_dist="normal", sigma_dist="halfnormal", centered=False
12+
):
13+
def halfnormal(name, *, sigma=1.0, model_logp):
14+
log_value = pt.scalar(f"{name}_log")
15+
value = pt.exp(log_value)
16+
17+
logp = (
18+
-0.5 * ((value / sigma) ** 2) + pt.log(pt.sqrt(2.0 / np.pi)) - pt.log(sigma)
19+
)
20+
logp = pt.switch(value >= 0, logp, -np.inf)
21+
model_logp.append(logp + value)
22+
return value
23+
24+
def normal(name, *, mu=0.0, sigma=1.0, model_logp, observed=None):
25+
value = pt.scalar(name) if observed is None else pt.as_tensor(observed)
26+
27+
logp = (
28+
-0.5 * (((value - mu) / sigma) ** 2)
29+
- pt.log(pt.sqrt(2.0 * np.pi))
30+
- pt.log(sigma)
31+
)
32+
model_logp.append(logp)
33+
return value
34+
35+
def lognormal(name, *, mu=0.0, sigma=1.0, model_logp):
36+
value = normal(name, mu=mu, sigma=sigma, model_logp=model_logp)
37+
return pt.exp(value)
38+
39+
def zerosumnormal(name, *, sigma=1.0, size, model_logp):
40+
raw_value = pt.vector(f"{name}_zerosum", shape=(size - 1,))
41+
n = raw_value.shape[0] + 1
42+
sum_vals = raw_value.sum(0, keepdims=True)
43+
norm = sum_vals / (pt.sqrt(n) + n)
44+
fill_value = norm - sum_vals / pt.sqrt(n)
45+
value = pt.concatenate([raw_value, fill_value]) - norm
46+
47+
shape = value.shape
48+
_full_size = pt.prod(shape)
49+
_degrees_of_freedom = pt.prod(shape[-1:].inc(-1))
50+
logp = pt.sum(
51+
-0.5 * ((value / sigma) ** 2)
52+
- (pt.log(pt.sqrt(2.0 * np.pi)) + pt.log(sigma))
53+
* (_degrees_of_freedom / _full_size)
54+
)
55+
model_logp.append(logp)
56+
return value
57+
58+
dist_fn_map = {
59+
fn.__name__: fn for fn in (halfnormal, normal, lognormal, zerosumnormal)
60+
}
61+
62+
rng = np.random.default_rng(1)
63+
n_counties = 85
64+
county_idx = rng.integers(n_counties, size=919)
65+
county_idx.sort()
66+
floor = rng.binomial(n=1, p=0.5, size=919).astype(np.float64)
67+
log_radon = rng.normal(size=919)
68+
69+
model_logp = []
70+
intercept = dist_fn_map[intercept_dist](
71+
"intercept", sigma=10, model_logp=model_logp
72+
)
73+
74+
# County effects
75+
county_sd = halfnormal("county_sd", model_logp=model_logp)
76+
if centered:
77+
county_effect = zerosumnormal(
78+
"county_raw", sigma=county_sd, size=n_counties, model_logp=model_logp
79+
)
80+
else:
81+
county_raw = zerosumnormal("county_raw", size=n_counties, model_logp=model_logp)
82+
county_effect = county_raw * county_sd
83+
84+
# Global floor effect
85+
floor_effect = normal("floor_effect", sigma=2, model_logp=model_logp)
86+
87+
county_floor_sd = halfnormal("county_floor_sd", model_logp=model_logp)
88+
if centered:
89+
county_floor_effect = zerosumnormal(
90+
"county_floor_raw",
91+
sigma=county_floor_sd,
92+
size=n_counties,
93+
model_logp=model_logp,
94+
)
95+
else:
96+
county_floor_raw = zerosumnormal(
97+
"county_floor_raw", size=n_counties, model_logp=model_logp
98+
)
99+
county_floor_effect = county_floor_raw * county_floor_sd
100+
101+
mu = (
102+
intercept
103+
+ county_effect[county_idx]
104+
+ floor_effect * floor
105+
+ county_floor_effect[county_idx] * floor
106+
)
107+
108+
sigma = dist_fn_map[sigma_dist]("sigma", model_logp=model_logp)
109+
_ = normal(
110+
"log_radon",
111+
mu=mu,
112+
sigma=sigma,
113+
observed=log_radon,
114+
model_logp=model_logp,
115+
)
116+
117+
model_logp = pt.sum([logp.sum() for logp in model_logp])
118+
model_logp = rewrite_graph(
119+
model_logp, include=("canonicalize", "stabilize"), clone=False
120+
)
121+
params = list(explicit_graph_inputs(model_logp))
122+
model_dlogp = pt.concatenate([term.ravel() for term in pt.grad(model_logp, params)])
123+
124+
size = sum(int(np.prod(p.type.shape)) for p in params)
125+
joined_inputs = pt.vector("joined_inputs", shape=(size,))
126+
idx = 0
127+
replacement = {}
128+
for param in params:
129+
param_shape = param.type.shape
130+
param_size = int(np.prod(param_shape))
131+
replacement[param] = joined_inputs[idx : idx + param_size].reshape(param_shape)
132+
idx += param_size
133+
assert idx == joined_inputs.type.shape[0]
134+
135+
model_logp, model_dlogp = graph_replace([model_logp, model_dlogp], replacement)
136+
return joined_inputs, [model_logp, model_dlogp]
137+
138+
139+
@pytest.fixture(scope="session")
def radon_model():
    """Session-scoped radon model built with ``create_radon_model``'s default arguments."""
    default_model = create_radon_model()
    return default_model
142+
143+
144+
@pytest.fixture(scope="session")
def radon_model_variants():
    """Session-scoped list of radon models, one per combination of the options below.

    The list is ordered with ``centered`` varying slowest and ``sigma_dist``
    varying fastest.
    """
    centered_options = (True, False)
    intercept_options = ("normal", "lognormal")
    sigma_options = ("halfnormal", "lognormal")
    return [
        create_radon_model(
            intercept_dist=intercept_choice,
            sigma_dist=sigma_choice,
            centered=centered_choice,
        )
        for centered_choice in centered_options
        for intercept_choice in intercept_options
        for sigma_choice in sigma_options
    ]

0 commit comments

Comments
 (0)