
Commit c722e9f

Benchmark radon model function
1 parent a553050 commit c722e9f


2 files changed: +217 -3 lines


pytensor/compile/mode.py

Lines changed: 3 additions & 3 deletions
@@ -465,7 +465,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
 C_VM = Mode("cvm", "fast_run")
 
 NUMBA = Mode(
-    NumbaLinker(),
+    "numba",
     RewriteDatabaseQuery(
         include=["fast_run", "numba"],
         exclude=[
@@ -478,7 +478,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
 )
 
 JAX = Mode(
-    JAXLinker(),
+    "jax",
     RewriteDatabaseQuery(
         include=["fast_run", "jax"],
         exclude=[
@@ -494,7 +494,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
     ),
 )
 PYTORCH = Mode(
-    PytorchLinker(),
+    "pytorch",
     RewriteDatabaseQuery(
         include=["fast_run"],
         exclude=[
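
The three hunks above replace eagerly instantiated linkers (NumbaLinker(), JAXLinker(), PytorchLinker()) with the registered linker names "numba", "jax", and "pytorch", matching how C_VM is already declared with the string "cvm". A minimal sketch of the same pattern follows; the import paths and the variable name are assumptions for illustration, not part of this commit.

# Hedged sketch: declare a Mode by registered linker name rather than a linker
# instance, mirroring the post-change NUMBA definition shown above.
from pytensor.compile.mode import Mode
from pytensor.graph.rewriting.db import RewriteDatabaseQuery  # assumed import path

numba_like_mode = Mode(
    "numba",  # registered linker name (string), as with Mode("cvm", "fast_run")
    RewriteDatabaseQuery(include=["fast_run", "numba"]),
)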

tests/compile/function/test_types.py

Lines changed: 214 additions & 0 deletions
@@ -13,7 +13,10 @@
 from pytensor.compile.mode import Mode, get_default_mode
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Constant
+from pytensor.graph.replace import graph_replace
+from pytensor.graph.rewriting import rewrite_graph
 from pytensor.graph.rewriting.basic import PatternNodeRewriter, WalkingGraphRewriter
+from pytensor.graph.traversal import explicit_graph_inputs
 from pytensor.graph.utils import MissingInputError
 from pytensor.link.vm import VMLinker
 from pytensor.printing import debugprint
@@ -1357,3 +1360,214 @@ def test_minimal_random_function_call_benchmark(trust_input, benchmark):
 
     rng_val = np.random.default_rng()
     benchmark(f, rng_val)
+
+
+def create_radon_model(
+    intercept_dist="normal", sigma_dist="halfnormal", centered=False
+):
+    def halfnormal(name, *, sigma=1.0, model_logp):
+        log_value = pt.scalar(f"{name}_log")
+        value = pt.exp(log_value)
+
+        logp = (
+            -0.5 * ((value / sigma) ** 2) + pt.log(pt.sqrt(2.0 / np.pi)) - pt.log(sigma)
+        )
+        logp = pt.switch(value >= 0, logp, -np.inf)
+        model_logp.append(logp + value)
+        return value
+
+    def normal(name, *, mu=0.0, sigma=1.0, model_logp, observed=None):
+        value = pt.scalar(name) if observed is None else pt.as_tensor(observed)
+
+        logp = (
+            -0.5 * (((value - mu) / sigma) ** 2)
+            - pt.log(pt.sqrt(2.0 * np.pi))
+            - pt.log(sigma)
+        )
+        model_logp.append(logp)
+        return value
+
+    def lognormal(name, *, mu=0.0, sigma=1.0, model_logp):
+        value = normal(name, mu=mu, sigma=sigma, model_logp=model_logp)
+        return pt.exp(value)
+
+    def zerosumnormal(name, *, sigma=1.0, size, model_logp):
+        raw_value = pt.vector(f"{name}_zerosum", shape=(size - 1,))
+        n = raw_value.shape[0] + 1
+        sum_vals = raw_value.sum(0, keepdims=True)
+        norm = sum_vals / (pt.sqrt(n) + n)
+        fill_value = norm - sum_vals / pt.sqrt(n)
+        value = pt.concatenate([raw_value, fill_value]) - norm
+
+        shape = value.shape
+        _full_size = pt.prod(shape)
+        _degrees_of_freedom = pt.prod(shape[-1:].inc(-1))
+        logp = pt.sum(
+            -0.5 * ((value / sigma) ** 2)
+            - (pt.log(pt.sqrt(2.0 * np.pi)) + pt.log(sigma))
+            * (_degrees_of_freedom / _full_size)
+        )
+        model_logp.append(logp)
+        return value
+
+    dist_fn_map = {
+        fn.__name__: fn for fn in (halfnormal, normal, lognormal, zerosumnormal)
+    }
+
+    rng = np.random.default_rng(1)
+    n_counties = 85
+    county_idx = rng.integers(n_counties, size=919)
+    county_idx.sort()
+    floor = rng.binomial(n=1, p=0.5, size=919).astype(np.float64)
+    log_radon = rng.normal(size=919)
+
+    model_logp = []
+    intercept = dist_fn_map[intercept_dist](
+        "intercept", sigma=10, model_logp=model_logp
+    )
+
+    # County effects
+    county_sd = halfnormal("county_sd", model_logp=model_logp)
+    if centered:
+        county_effect = zerosumnormal(
+            "county_raw", sigma=county_sd, size=n_counties, model_logp=model_logp
+        )
+    else:
+        county_raw = zerosumnormal("county_raw", size=n_counties, model_logp=model_logp)
+        county_effect = county_raw * county_sd
+
+    # Global floor effect
+    floor_effect = normal("floor_effect", sigma=2, model_logp=model_logp)
+
+    county_floor_sd = halfnormal("county_floor_sd", model_logp=model_logp)
+    if centered:
+        county_floor_effect = zerosumnormal(
+            "county_floor_raw",
+            sigma=county_floor_sd,
+            size=n_counties,
+            model_logp=model_logp,
+        )
+    else:
+        county_floor_raw = zerosumnormal(
+            "county_floor_raw", size=n_counties, model_logp=model_logp
+        )
+        county_floor_effect = county_floor_raw * county_floor_sd
+
+    mu = (
+        intercept
+        + county_effect[county_idx]
+        + floor_effect * floor
+        + county_floor_effect[county_idx] * floor
+    )
+
+    sigma = dist_fn_map[sigma_dist]("sigma", model_logp=model_logp)
+    _ = normal(
+        "log_radon",
+        mu=mu,
+        sigma=sigma,
+        observed=log_radon,
+        model_logp=model_logp,
+    )
+
+    model_logp = pt.sum([logp.sum() for logp in model_logp])
+    model_logp = rewrite_graph(
+        model_logp, include=("canonicalize", "stabilize"), clone=False
+    )
+    params = list(explicit_graph_inputs(model_logp))
+    model_dlogp = pt.concatenate([term.ravel() for term in pt.grad(model_logp, params)])
+
+    size = sum(int(np.prod(p.type.shape)) for p in params)
+    joined_inputs = pt.vector("joined_inputs", shape=(size,))
+    idx = 0
+    replacement = {}
+    for param in params:
+        param_shape = param.type.shape
+        param_size = int(np.prod(param_shape))
+        replacement[param] = joined_inputs[idx : idx + param_size].reshape(param_shape)
+        idx += param_size
+    assert idx == joined_inputs.type.shape[0]
+
+    model_logp, model_dlogp = graph_replace([model_logp, model_dlogp], replacement)
+    return joined_inputs, [model_logp, model_dlogp]
+
+
+@pytest.fixture(scope="session")
+def radon_model():
+    return create_radon_model()
+
+
+@pytest.fixture(scope="session")
+def radon_model_variants():
+    # Convert to list comp
+    return [
+        create_radon_model(
+            intercept_dist=intercept_dist,
+            sigma_dist=sigma_dist,
+            centered=centered,
+        )
+        for centered in (True, False)
+        for intercept_dist in ("normal", "lognormal")
+        for sigma_dist in ("halfnormal", "lognormal")
+    ]
+
+
+@pytest.mark.parametrize("mode", ["C", "C_VM", "NUMBA"])
+def test_radon_model_compile_repeatedly_benchmark(mode, radon_model, benchmark):
+    joined_inputs, [model_logp, model_dlogp] = radon_model
+    rng = np.random.default_rng(1)
+    x = rng.normal(size=joined_inputs.type.shape).astype(config.floatX)
+
+    def compile_and_call_once():
+        fn = function(
+            [joined_inputs], [model_logp, model_dlogp], mode=mode, trust_input=True
+        )
+        fn(x)
+
+    benchmark(compile_and_call_once)
+
+
+@pytest.mark.parametrize("mode", ["C", "C_VM", "NUMBA"])
+def test_radon_model_compile_variants_benchmark(
+    mode, radon_model, radon_model_variants, benchmark
+):
+    """Test compilation speed when a slight variant of a function is compiled each time.
+
+    This test more realistically simulates a use case where a model is recompiled
+    multiple times with small changes, such as in an interactive environment.
+
+    NOTE: For this test to be meaningful on subsequent runs, the cache must be cleared.
+    """
+    joined_inputs, [model_logp, model_dlogp] = radon_model
+    rng = np.random.default_rng(1)
+    x = rng.normal(size=joined_inputs.type.shape).astype(config.floatX)
+
+    # Compile base function once to populate the cache
+    fn = function(
+        [joined_inputs], [model_logp, model_dlogp], mode=mode, trust_input=True
+    )
+    fn(x)
+
+    def compile_and_call_once():
+        for joined_inputs, [model_logp, model_dlogp] in radon_model_variants:
+            fn = function(
+                [joined_inputs], [model_logp, model_dlogp], mode=mode, trust_input=True
+            )
+            fn(x)
+
+    benchmark.pedantic(compile_and_call_once, rounds=1, iterations=1)
+
+
+@pytest.mark.parametrize("mode", ["C", "C_VM", "C_VM_NOGC", "NUMBA"])
+def test_radon_model_call_benchmark(mode, radon_model, benchmark):
+    joined_inputs, [model_logp, model_dlogp] = radon_model
+
+    real_mode = "C_VM" if mode == "C_VM_NOGC" else mode
+    fn = function(
+        [joined_inputs], [model_logp, model_dlogp], mode=real_mode, trust_input=True
+    )
+    if mode == "C_VM_NOGC":
+        fn.vm.allow_gc = False
+
+    rng = np.random.default_rng(1)
+    x = rng.normal(size=joined_inputs.type.shape).astype(config.floatX)
+    benchmark(fn, x)
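
For orientation, a hedged sketch of how the benchmarks above drive the model outside of pytest; it assumes create_radon_model from this diff is available in scope and is not itself part of the commit.

# Hedged sketch: build the radon logp/dlogp graph and compile/evaluate it the
# same way the benchmark tests above do.
import numpy as np

from pytensor import function
from pytensor.configdefaults import config

# create_radon_model is defined in the test module shown in the diff above
joined_inputs, (model_logp, model_dlogp) = create_radon_model()
fn = function(
    [joined_inputs], [model_logp, model_dlogp], mode="NUMBA", trust_input=True
)
x = np.random.default_rng(1).normal(size=joined_inputs.type.shape).astype(config.floatX)
logp_val, dlogp_val = fn(x)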
