Skip to content

Commit 001d418

Browse files
Add dims/coords consistency validation before sampling (issue #7891)
This commit implements validation of InferenceData dimensions and coordinates consistency before pm.sample() executes, preventing cryptic shape mismatch errors during sampling. Changes: - Add pymc/model/validation.py with validation functions: * validate_dims_coords_consistency(): Main validation entry point * check_dims_exist(): Verify referenced dims exist in model.coords * check_shape_dims_match(): Verify variable shapes match declared dims * check_coord_lengths(): Verify coordinate lengths match dimension sizes - Integrate validation into pymc/sampling/mcmc.py: * Added validation call early in sample() function, before sampling setup * Provides clear, actionable error messages to guide users - Add comprehensive tests in tests/model/test_dims_coords_validation.py: * Test missing coord detection * Test shape-dims mismatch detection * Test coordinate length validation * Test MutableData, observed data, and Deterministic variables * Test edge cases and complex models Fixes #7891
1 parent a6e6fa8 commit 001d418

File tree

3 files changed

+591
-0
lines changed

3 files changed

+591
-0
lines changed

pymc/model/validation.py

Lines changed: 341 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,341 @@
1+
# Copyright 2024 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Validation utilities for PyMC models.
16+
17+
This module provides functions to validate that model dimensions and coordinates
18+
are consistent before sampling begins, preventing cryptic shape mismatch errors.
19+
"""
20+
21+
from __future__ import annotations
22+
23+
import numpy as np
24+
import pytensor.tensor as pt
25+
from pytensor.graph.basic import Variable
26+
from pytensor.tensor.variable import TensorVariable, TensorConstant
27+
28+
try:
29+
unused = TYPE_CHECKING
30+
except NameError:
31+
from typing import TYPE_CHECKING
32+
33+
if TYPE_CHECKING:
34+
from pymc.model.core import Model
35+
36+
__all__ = ["validate_dims_coords_consistency"]
37+
38+
39+
def validate_dims_coords_consistency(model: Model) -> None:
40+
"""Validate that all dims and coords are consistent before sampling.
41+
42+
This function performs comprehensive validation to ensure that:
43+
- All dims referenced in model variables exist in model.coords
44+
- Variable shapes match their declared dimensions
45+
- Coordinate lengths match the corresponding dimension sizes
46+
- MutableData variables have consistent dims when specified
47+
- No conflicting dimension specifications exist across variables
48+
49+
Parameters
50+
----------
51+
model : pm.Model
52+
The PyMC model to validate
53+
54+
Raises
55+
------
56+
ValueError
57+
If inconsistencies are found with detailed error messages that guide
58+
users on how to fix the issues.
59+
"""
60+
errors = []
61+
62+
# Check 1: Verify all referenced dims exist in coords
63+
dims_errors = check_dims_exist(model)
64+
errors.extend(dims_errors)
65+
66+
# Check 2: Verify shape-dim consistency for all model variables
67+
shape_errors = check_shape_dims_match(model)
68+
errors.extend(shape_errors)
69+
70+
# Check 3: Check coordinate length matches dimension size
71+
coord_length_errors = check_coord_lengths(model)
72+
errors.extend(coord_length_errors)
73+
74+
# If any errors were found, raise a comprehensive ValueError
75+
if errors:
76+
error_msg = "\n\n".join(errors)
77+
raise ValueError(
78+
"Model dimension and coordinate inconsistencies detected:\n\n"
79+
+ error_msg
80+
+ "\n\n"
81+
+ "Please fix the above issues before sampling. "
82+
"You may need to add missing coordinates to model.coords, "
83+
"adjust variable shapes, or ensure coordinate values match dimension sizes."
84+
)
85+
86+
87+
def check_dims_exist(model: Model) -> list[str]:
88+
"""Check that all dims referenced in variables exist in model.coords.
89+
90+
Parameters
91+
----------
92+
model : Model
93+
The PyMC model to check
94+
95+
Returns
96+
-------
97+
list[str]
98+
List of error messages (empty if no errors)
99+
"""
100+
errors = []
101+
all_referenced_dims = set()
102+
103+
# Collect all dims referenced across all variables
104+
for var_name, dims in model.named_vars_to_dims.items():
105+
if dims is not None:
106+
for dim in dims:
107+
if dim is not None:
108+
all_referenced_dims.add(dim)
109+
110+
# Check each referenced dim exists in model.coords
111+
missing_dims = all_referenced_dims - set(model.coords.keys())
112+
113+
if missing_dims:
114+
# Group variables by missing dims for better error messages
115+
dim_to_vars = {}
116+
for var_name, dims in model.named_vars_to_dims.items():
117+
if dims is not None:
118+
for dim in dims:
119+
if dim in missing_dims:
120+
dim_to_vars.setdefault(dim, []).append(var_name)
121+
122+
for dim in sorted(missing_dims):
123+
var_names = sorted(set(dim_to_vars[dim]))
124+
var_list = ", ".join([f"'{v}'" for v in var_names])
125+
errors.append(
126+
f"Dimension '{dim}' is referenced by variable(s) {var_list}, "
127+
f"but it is not defined in model.coords. "
128+
f"Add '{dim}' to model.coords, for example:\n"
129+
f" model.add_coord('{dim}', values=range(n)) # or specific coordinate values"
130+
)
131+
132+
return errors
133+
134+
135+
def check_shape_dims_match(model: Model) -> list[str]:
136+
"""Check that variable shapes match their declared dims.
137+
138+
This checks that if a variable declares dims, its shape matches the
139+
sizes of those dimensions as defined in model.coords.
140+
141+
Parameters
142+
----------
143+
model : Model
144+
The PyMC model to check
145+
146+
Returns
147+
-------
148+
list[str]
149+
List of error messages (empty if no errors)
150+
"""
151+
errors = []
152+
153+
for var_name, dims in model.named_vars_to_dims.items():
154+
if dims is None or not dims:
155+
continue
156+
157+
var = model.named_vars.get(var_name)
158+
if var is None:
159+
continue
160+
161+
# Skip if variable doesn't have shape (e.g., scalars)
162+
if not hasattr(var, "shape") or not hasattr(var, "ndim"):
163+
continue
164+
165+
# Get expected shape from dims
166+
expected_shape = []
167+
dim_names = []
168+
for d, dim_name in enumerate(dims):
169+
if dim_name is None:
170+
# If dim is None, we can't validate against coords
171+
# This is valid for variables with mixed dims/None
172+
continue
173+
174+
if dim_name not in model.coords:
175+
# Already reported by check_dims_exist, skip here
176+
continue
177+
178+
# Get dimension length
179+
coord = model.coords[dim_name]
180+
if coord is not None:
181+
dim_length = len(coord)
182+
else:
183+
# Symbolic dimension - get from dim_lengths
184+
dim_length_var = model.dim_lengths.get(dim_name)
185+
if dim_length_var is not None:
186+
try:
187+
# Try to evaluate if it's a constant
188+
if isinstance(dim_length_var, pt.TensorConstant):
189+
dim_length = int(dim_length_var.data)
190+
else:
191+
# Symbolic, skip this check
192+
continue
193+
except (AttributeError, TypeError, ValueError):
194+
# Can't evaluate, skip
195+
continue
196+
else:
197+
continue
198+
199+
expected_shape.append(dim_length)
200+
dim_names.append(dim_name)
201+
202+
if not expected_shape:
203+
# Couldn't determine expected shape, skip
204+
continue
205+
206+
# For variables with symbolic shapes, we need to try to evaluate
207+
try:
208+
actual_shape = var.shape
209+
if isinstance(actual_shape, (list, tuple)):
210+
# Replace symbolic shape elements if possible
211+
evaluated_shape = []
212+
shape_idx = 0
213+
for dim_name in dims:
214+
if dim_name is None:
215+
# Skip None dims
216+
if shape_idx < len(actual_shape):
217+
evaluated_shape.append(actual_shape[shape_idx])
218+
shape_idx += 1
219+
continue
220+
221+
if dim_name not in model.coords:
222+
if shape_idx < len(actual_shape):
223+
shape_idx += 1
224+
continue
225+
226+
if shape_idx < len(actual_shape):
227+
shape_elem = actual_shape[shape_idx]
228+
# Try to evaluate if symbolic
229+
if isinstance(shape_elem, pt.TensorConstant):
230+
evaluated_shape.append(int(shape_elem.data))
231+
elif isinstance(shape_elem, Variable):
232+
try:
233+
evaluated = shape_elem.eval()
234+
if np.isscalar(evaluated):
235+
evaluated_shape.append(int(evaluated))
236+
else:
237+
evaluated_shape.append(None) # Can't validate
238+
except Exception:
239+
evaluated_shape.append(None) # Can't validate
240+
else:
241+
evaluated_shape.append(int(shape_elem) if shape_elem is not None else None)
242+
shape_idx += 1
243+
244+
# Compare only elements we could evaluate
245+
if len(evaluated_shape) != len(expected_shape):
246+
# Different number of dimensions, skip
247+
continue
248+
249+
mismatches = []
250+
for i, (actual, expected) in enumerate(zip(evaluated_shape, expected_shape)):
251+
if actual is not None and actual != expected:
252+
mismatches.append(
253+
f" dimension {i} (dim='{dim_names[i]}'): got {actual}, expected {expected}"
254+
)
255+
256+
if mismatches:
257+
errors.append(
258+
f"Variable '{var_name}' declares dims {dims} but its shape "
259+
f"does not match the coordinate lengths:\n"
260+
+ "\n".join(mismatches)
261+
)
262+
except Exception:
263+
# If we can't evaluate the shape, skip this check
264+
# The shape might be symbolic and resolve at runtime
265+
pass
266+
267+
return errors
268+
269+
270+
def check_coord_lengths(model: Model) -> list[str]:
271+
"""Check that coordinate arrays match their dimension sizes.
272+
273+
This validates that when coordinates have values, their length matches
274+
the dimension length. For symbolic dimensions (like MutableData), this
275+
check may be skipped.
276+
277+
Parameters
278+
----------
279+
model : Model
280+
The PyMC model to check
281+
282+
Returns
283+
-------
284+
list[str]
285+
List of error messages (empty if no errors)
286+
"""
287+
errors = []
288+
289+
for dim_name, coord_values in model.coords.items():
290+
if coord_values is None:
291+
# Symbolic dimension, skip
292+
continue
293+
294+
dim_length_var = model.dim_lengths.get(dim_name)
295+
if dim_length_var is None:
296+
continue
297+
298+
try:
299+
# Get actual coordinate length
300+
coord_length = len(coord_values) if coord_values is not None else None
301+
302+
# Get expected dimension length
303+
if isinstance(dim_length_var, pt.TensorConstant):
304+
expected_length = int(dim_length_var.data)
305+
elif isinstance(dim_length_var, Variable):
306+
try:
307+
eval_result = dim_length_var.eval()
308+
if np.isscalar(eval_result):
309+
expected_length = int(eval_result)
310+
else:
311+
# Can't compare, might be symbolic
312+
continue
313+
except Exception:
314+
# Can't evaluate, might be symbolic (e.g., MutableData)
315+
continue
316+
else:
317+
expected_length = int(dim_length_var)
318+
319+
# Compare lengths
320+
if coord_length is not None and coord_length != expected_length:
321+
# Find which variables use this dimension
322+
using_vars = []
323+
for var_name, dims in model.named_vars_to_dims.items():
324+
if dims is not None and dim_name in dims:
325+
using_vars.append(var_name)
326+
327+
var_list = ", ".join([f"'{v}'" for v in sorted(using_vars)]) if using_vars else "variables"
328+
329+
errors.append(
330+
f"Dimension '{dim_name}' has coordinate values of length {coord_length}, "
331+
f"but the dimension size is {expected_length}. "
332+
f"This affects variable(s): {var_list}. "
333+
f"Update the coordinate values to match the dimension size, "
334+
f"or adjust the dimension size to match the coordinates."
335+
)
336+
except Exception:
337+
# If evaluation fails, skip (might be symbolic)
338+
pass
339+
340+
return errors
341+

pymc/sampling/mcmc.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
from pymc.exceptions import SamplingError
5555
from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain
5656
from pymc.model import Model, modelcontext
57+
from pymc.model.validation import validate_dims_coords_consistency
5758
from pymc.progress_bar import ProgressBarManager, ProgressBarType, default_progress_theme
5859
from pymc.sampling.parallel import Draw, _cpu_count
5960
from pymc.sampling.population import _sample_population
@@ -716,6 +717,8 @@ def sample(
716717
progress_bool = bool(progressbar)
717718

718719
model = modelcontext(model)
720+
# Validate dims/coords consistency before sampling
721+
validate_dims_coords_consistency(model)
719722
if not model.free_RVs:
720723
raise SamplingError(
721724
"Cannot sample from the model, since the model does not contain any free variables."

0 commit comments

Comments
 (0)