Skip to content

Commit 9b7bbc8

Browse files
committed
fix: pymc with jax backend was broken with some shapes
1 parent 9c69449 commit 9b7bbc8

File tree

3 files changed

+21
-7
lines changed

3 files changed

+21
-7
lines changed

python/nutpie/compile_pymc.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -445,8 +445,10 @@ def make_expand_func(seed1, seed2, chain):
445445
def expand(_x, **shared):
446446
values = expand_fn(_x, *[shared[name] for name in expand_shared_names])
447447
return {
448-
name: np.asarray(val, order="C", dtype=dtype).ravel()
449-
for name, val, dtype in zip(names, values, dtypes, strict=True)
448+
name: np.asarray(val, order="C", dtype=dtype).reshape(shape)
449+
for name, val, dtype, shape in zip(
450+
names, values, dtypes, shapes, strict=True
451+
)
450452
}
451453

452454
return expand

src/pyfunc.rs

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -289,17 +289,27 @@ impl CpuLogpFunc for PyDensity {
289289
))
290290
})?;
291291
if !arr.is_c_contiguous() {
292-
return Err(nuts_rs::CpuMathError::ExpandError(
293-
"not c contiguous".into(),
294-
));
292+
return Err(nuts_rs::CpuMathError::ExpandError(format!(
293+
"not c contiguous: {}",
294+
var.name
295+
)));
296+
}
297+
if arr.shape().len() != var.shape.as_slice().len() {
298+
return Err(nuts_rs::CpuMathError::ExpandError(format!(
299+
"unexpected number of dimensions for variable {}",
300+
var.name
301+
)));
295302
}
296303
if !arr
297304
.shape()
298305
.iter()
299306
.zip(var.shape.as_slice())
300307
.all(|(a, &b)| *a as u64 == b)
301308
{
302-
return Err(nuts_rs::CpuMathError::ExpandError("upected shape".into()));
309+
return Err(nuts_rs::CpuMathError::ExpandError(format!(
310+
"unexpected shape for variable {}",
311+
var.name
312+
)));
303313
}
304314
Ok(arr)
305315
}

tests/test_pymc.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@ def test_low_rank(backend, gradient_backend):
100100
@parameterize_backends
101101
def test_low_rank_half_normal(backend, gradient_backend):
102102
with pm.Model() as model:
103-
pm.HalfNormal("a", shape=13)
103+
pm.HalfNormal("a", shape=(13, 3))
104+
pm.HalfNormal("b", shape=())
105+
pm.HalfNormal("c", shape=(5,))
104106

105107
compiled = nutpie.compile_pymc_model(
106108
model, backend=backend, gradient_backend=gradient_backend

0 commit comments

Comments
 (0)