|
13 | 13 | from pytensor.compile.mode import Mode, get_default_mode |
14 | 14 | from pytensor.configdefaults import config |
15 | 15 | from pytensor.graph.basic import Constant |
| 16 | +from pytensor.graph.replace import graph_replace |
| 17 | +from pytensor.graph.rewriting import rewrite_graph |
16 | 18 | from pytensor.graph.rewriting.basic import PatternNodeRewriter, WalkingGraphRewriter |
| 19 | +from pytensor.graph.traversal import explicit_graph_inputs |
17 | 20 | from pytensor.graph.utils import MissingInputError |
18 | 21 | from pytensor.link.vm import VMLinker |
19 | 22 | from pytensor.printing import debugprint |
@@ -1357,3 +1360,216 @@ def test_minimal_random_function_call_benchmark(trust_input, benchmark): |
1357 | 1360 |
|
1358 | 1361 | rng_val = np.random.default_rng() |
1359 | 1362 | benchmark(f, rng_val) |
| 1363 | + |
| 1364 | + |
def create_radon_model(
    intercept_dist="normal", sigma_dist="halfnormal", centered=False
):
    """Build the logp and dlogp graphs of a hierarchical radon-style regression.

    The model is assembled by hand from small density helpers (no PyMC
    dependency) on synthetic data drawn from a fixed-seed generator, so the
    resulting graph is reproducible across runs.

    Parameters
    ----------
    intercept_dist : str
        Name of the helper used for the intercept prior. It is called with
        ``sigma=10``, so it must be one of ``"normal"``, ``"lognormal"`` or
        ``"halfnormal"``.
    sigma_dist : str
        Name of the helper used for the observation noise prior
        (``"halfnormal"``, ``"normal"`` or ``"lognormal"``).
    centered : bool
        If True, the county-level effects are drawn directly with their
        standard deviation; otherwise unit-scale raw effects are multiplied by
        the standard deviation afterwards.

    Returns
    -------
    joined_inputs : TensorVariable
        A single vector input that holds every free parameter, flattened and
        concatenated.
    outputs : list
        ``[model_logp, model_dlogp]`` expressed in terms of ``joined_inputs``;
        ``model_dlogp`` is the concatenation of the raveled gradients of
        ``model_logp`` with respect to each parameter.
    """

    def halfnormal(name, *, sigma=1.0, model_logp):
        # The free variable lives on the log scale; `value` is its exponential.
        log_value = pt.scalar(f"{name}_log")
        value = pt.exp(log_value)

        logp = (
            -0.5 * ((value / sigma) ** 2) + pt.log(pt.sqrt(2.0 / np.pi)) - pt.log(sigma)
        )
        # value = exp(log_value), so this guard is expected to always hold;
        # it is kept so the graph still contains the support check.
        logp = pt.switch(value >= 0, logp, -np.inf)
        # Change of variables: log|d value / d log_value| = log(exp(log_value))
        # = log_value, so the Jacobian term is `log_value`, not `value`.
        model_logp.append(logp + log_value)
        return value

    def normal(name, *, mu=0.0, sigma=1.0, model_logp, observed=None):
        # A free scalar unless observed data is supplied.
        value = pt.scalar(name) if observed is None else pt.as_tensor(observed)

        logp = (
            -0.5 * (((value - mu) / sigma) ** 2)
            - pt.log(pt.sqrt(2.0 * np.pi))
            - pt.log(sigma)
        )
        model_logp.append(logp)
        return value

    def lognormal(name, *, mu=0.0, sigma=1.0, model_logp):
        # Normal density on the free variable, exponentiated on the way out.
        value = normal(name, mu=mu, sigma=sigma, model_logp=model_logp)
        return pt.exp(value)

    def zerosumnormal(name, *, sigma=1.0, size, model_logp):
        # `size - 1` free values are extended by one fill value and shifted so
        # that the resulting `size` entries sum to zero.
        raw_value = pt.vector(f"{name}_zerosum", shape=(size - 1,))
        n = raw_value.shape[0] + 1
        sum_vals = raw_value.sum(0, keepdims=True)
        norm = sum_vals / (pt.sqrt(n) + n)
        fill_value = norm - sum_vals / pt.sqrt(n)
        value = pt.concatenate([raw_value, fill_value]) - norm

        shape = value.shape
        _full_size = pt.prod(shape)
        _degrees_of_freedom = pt.prod(shape[-1:].inc(-1))
        # The normalising constant is scaled by (n - 1) / n because it is
        # summed over n entries.
        logp = pt.sum(
            -0.5 * ((value / sigma) ** 2)
            - (pt.log(pt.sqrt(2.0 * np.pi)) + pt.log(sigma))
            * (_degrees_of_freedom / _full_size)
        )
        model_logp.append(logp)
        return value

    dist_fn_map = {
        fn.__name__: fn for fn in (halfnormal, normal, lognormal, zerosumnormal)
    }

    # Synthetic data from a fixed seed; the order of the draws is significant.
    n_obs = 919
    rng = np.random.default_rng(1)
    n_counties = 85
    county_idx = rng.integers(n_counties, size=n_obs)
    county_idx.sort()
    floor = rng.binomial(n=1, p=0.5, size=n_obs).astype(np.float64)
    log_radon = rng.normal(size=n_obs)

    # Each helper appends its log-density term to this list.
    model_logp = []
    intercept = dist_fn_map[intercept_dist](
        "intercept", sigma=10, model_logp=model_logp
    )

    # County effects
    county_sd = halfnormal("county_sd", model_logp=model_logp)
    if centered:
        county_effect = zerosumnormal(
            "county_raw", sigma=county_sd, size=n_counties, model_logp=model_logp
        )
    else:
        county_raw = zerosumnormal("county_raw", size=n_counties, model_logp=model_logp)
        county_effect = county_raw * county_sd

    # Global floor effect
    floor_effect = normal("floor_effect", sigma=2, model_logp=model_logp)

    # County-specific floor effects
    county_floor_sd = halfnormal("county_floor_sd", model_logp=model_logp)
    if centered:
        county_floor_effect = zerosumnormal(
            "county_floor_raw",
            sigma=county_floor_sd,
            size=n_counties,
            model_logp=model_logp,
        )
    else:
        county_floor_raw = zerosumnormal(
            "county_floor_raw", size=n_counties, model_logp=model_logp
        )
        county_floor_effect = county_floor_raw * county_floor_sd

    mu = (
        intercept
        + county_effect[county_idx]
        + floor_effect * floor
        + county_floor_effect[county_idx] * floor
    )

    sigma = dist_fn_map[sigma_dist]("sigma", model_logp=model_logp)
    # Observed likelihood; only its logp term is needed.
    _ = normal(
        "log_radon",
        mu=mu,
        sigma=sigma,
        observed=log_radon,
        model_logp=model_logp,
    )

    # Collapse all terms into a scalar and simplify before differentiating.
    model_logp = pt.sum([logp.sum() for logp in model_logp])
    model_logp = rewrite_graph(
        model_logp, include=("canonicalize", "stabilize"), clone=False
    )
    params = list(explicit_graph_inputs(model_logp))
    model_dlogp = pt.concatenate([term.ravel() for term in pt.grad(model_logp, params)])

    # Replace every parameter by a reshaped slice of one flat input vector.
    param_sizes = [int(np.prod(param.type.shape)) for param in params]
    size = sum(param_sizes)
    joined_inputs = pt.vector("joined_inputs", shape=(size,))
    idx = 0
    replacement = {}
    for param, param_size in zip(params, param_sizes):
        replacement[param] = joined_inputs[idx : idx + param_size].reshape(
            param.type.shape
        )
        idx += param_size
    assert idx == joined_inputs.type.shape[0]

    model_logp, model_dlogp = graph_replace([model_logp, model_dlogp], replacement)
    return joined_inputs, [model_logp, model_dlogp]
| 1492 | + |
| 1493 | + |
@pytest.fixture(scope="session")
def radon_model():
    """Session-scoped radon model built with the default options."""
    default_model = create_radon_model()
    return default_model
| 1497 | + |
| 1498 | + |
@pytest.fixture(scope="session")
def radon_model_variants():
    """Session-scoped list of radon models, one per combination of options.

    The combinations are enumerated with ``centered`` varying slowest and
    ``sigma_dist`` varying fastest.
    """
    option_grid = [
        {
            "intercept_dist": intercept_dist,
            "sigma_dist": sigma_dist,
            "centered": centered,
        }
        for centered in (True, False)
        for intercept_dist in ("normal", "lognormal")
        for sigma_dist in ("halfnormal", "lognormal")
    ]
    return [create_radon_model(**options) for options in option_grid]
| 1512 | + |
| 1513 | + |
@pytest.mark.parametrize("mode", ["C", "C_VM", "NUMBA"])
def test_radon_model_compile_repeatedly_benchmark(mode, radon_model, benchmark):
    """Benchmark compiling the same radon graph and calling it once, repeatedly."""
    flat_input, (logp, dlogp) = radon_model
    point = (
        np.random.default_rng(1)
        .normal(size=flat_input.type.shape)
        .astype(config.floatX)
    )

    def compile_and_call_once():
        compiled = function([flat_input], [logp, dlogp], mode=mode, trust_input=True)
        compiled(point)

    # Five rounds of a single compile-and-call each.
    benchmark.pedantic(compile_and_call_once, rounds=5, iterations=1)
| 1527 | + |
| 1528 | + |
@pytest.mark.parametrize("mode", ["C", "C_VM", "NUMBA"])
def test_radon_model_compile_variants_benchmark(
    mode, radon_model, radon_model_variants, benchmark
):
    """Test compilation speed when a slight variant of a function is compiled each time.

    This test more realistically simulates a use case where a model is recompiled
    multiple times with small changes, such as in an interactive environment.

    NOTE: For this test to be meaningful on subsequent runs, the cache must be cleared
    """
    base_input, (base_logp, base_dlogp) = radon_model
    point = (
        np.random.default_rng(1)
        .normal(size=base_input.type.shape)
        .astype(config.floatX)
    )

    # Compile base function once to populate the cache
    base_fn = function(
        [base_input], [base_logp, base_dlogp], mode=mode, trust_input=True
    )
    base_fn(point)

    def compile_and_call_once():
        # The same input vector is fed to every variant; this assumes all
        # variants share the joined input length of the base model.
        for variant_input, (variant_logp, variant_dlogp) in radon_model_variants:
            variant_fn = function(
                [variant_input],
                [variant_logp, variant_dlogp],
                mode=mode,
                trust_input=True,
            )
            variant_fn(point)

    benchmark.pedantic(compile_and_call_once, rounds=1, iterations=1)
| 1558 | + |
| 1559 | + |
@pytest.mark.parametrize("mode", ["C", "C_VM", "C_VM_NOGC", "NUMBA"])
def test_radon_model_call_benchmark(mode, radon_model, benchmark):
    """Benchmark calling an already compiled radon logp/dlogp function."""
    flat_input, (logp, dlogp) = radon_model

    # "C_VM_NOGC" is not a compile mode of its own: compile with "C_VM" and
    # switch off `allow_gc` on the resulting VM afterwards.
    disable_gc = mode == "C_VM_NOGC"
    compiled = function(
        [flat_input],
        [logp, dlogp],
        mode="C_VM" if disable_gc else mode,
        trust_input=True,
    )
    if disable_gc:
        compiled.vm.allow_gc = False

    point = (
        np.random.default_rng(1)
        .normal(size=flat_input.type.shape)
        .astype(config.floatX)
    )
    compiled(point)  # warmup

    benchmark(compiled, point)
0 commit comments