|
15 | 15 | from pymc.util import RandomState |
16 | 16 | from pytensor import Variable, graph_replace |
17 | 17 | from pytensor.compile import get_mode |
| 18 | +from rich.box import SIMPLE_HEAD |
| 19 | +from rich.console import Console |
| 20 | +from rich.table import Table |
18 | 21 |
|
19 | 22 | from pymc_extras.statespace.core.representation import PytensorRepresentation |
20 | 23 | from pymc_extras.statespace.filters import ( |
@@ -254,53 +257,72 @@ def __init__( |
254 | 257 | self.kalman_smoother = KalmanSmoother() |
255 | 258 | self.make_symbolic_graph() |
256 | 259 |
|
257 | | - if verbose: |
258 | | - # These are split into separate try-except blocks, because it will be quite rare of models to implement |
259 | | - # _print_data_requirements, but we still want to print the prior requirements. |
260 | | - try: |
261 | | - self._print_prior_requirements() |
262 | | - except NotImplementedError: |
263 | | - pass |
264 | | - try: |
265 | | - self._print_data_requirements() |
266 | | - except NotImplementedError: |
267 | | - pass |
268 | | - |
269 | | - def _print_prior_requirements(self) -> None: |
270 | | - """ |
271 | | - Prints a short report to the terminal about the priors needed for the model, including their names, |
| 260 | + self.requirement_table = None |
| 261 | + self._populate_prior_requirements() |
| 262 | + self._populate_data_requirements() |
| 263 | + |
| 264 | + if verbose and self.requirement_table: |
| 265 | + console = Console() |
| 266 | + console.print(self.requirement_table) |
| 267 | + |
| 268 | + def _populate_prior_requirements(self) -> None: |
| 269 | + """ |
| 270 | + Add requirements about priors needed for the model to a rich table, including their names, |
272 | 271 | shapes, named dimensions, and any parameter constraints. |
273 | 272 | """ |
274 | | - out = "" |
275 | | - for param, info in self.param_info.items(): |
276 | | - out += f'\t{param} -- shape: {info["shape"]}, constraints: {info["constraints"]}, dims: {info["dims"]}\n' |
277 | | - out = out.rstrip() |
| 273 | + # Check that the param_info class is implemented, and also that it's a dictionary. We can't proceed if either |
| 274 | + # is not true. |
| 275 | + try: |
| 276 | + if not isinstance(self.param_info, dict): |
| 277 | + return |
| 278 | + except NotImplementedError: |
| 279 | + return |
278 | 280 |
|
279 | | - _log.info( |
280 | | - "The following parameters should be assigned priors inside a PyMC " |
281 | | - f"model block: \n" |
282 | | - f"{out}" |
283 | | - ) |
| 281 | + if self.requirement_table is None: |
| 282 | + self._initialize_requirement_table() |
284 | 283 |
|
285 | | - def _print_data_requirements(self) -> None: |
| 284 | + for param, info in self.param_info.items(): |
| 285 | + self.requirement_table.add_row( |
| 286 | + param, str(info["shape"]), info["constraints"], str(info["dims"]) |
| 287 | + ) |
| 288 | + |
| 289 | + def _populate_data_requirements(self) -> None: |
286 | 290 | """ |
287 | | - Prints a short report to the terminal about the data needed for the model, including their names, shapes, |
288 | | - and named dimensions. |
| 291 | + Add requirements about the data needed for the model, including their names, shapes, and named dimensions. |
289 | 292 | """ |
290 | | - if not self.data_info: |
| 293 | + try: |
| 294 | + if not isinstance(self.data_info, dict): |
| 295 | + return |
| 296 | + except NotImplementedError: |
291 | 297 | return |
292 | 298 |
|
293 | | - out = "" |
| 299 | + if self.requirement_table is None: |
| 300 | + self._initialize_requirement_table() |
| 301 | + else: |
| 302 | + self.requirement_table.add_section() |
| 303 | + |
294 | 304 | for data, info in self.data_info.items(): |
295 | | - out += f'\t{data} -- shape: {info["shape"]}, dims: {info["dims"]}\n' |
296 | | - out = out.rstrip() |
| 305 | + self.requirement_table.add_row(data, str(info["shape"]), "pm.Data", str(info["dims"])) |
| 306 | + |
| 307 | + def _initialize_requirement_table(self) -> None: |
| 308 | + self.requirement_table = Table( |
| 309 | + show_header=True, |
| 310 | + show_edge=True, |
| 311 | + box=SIMPLE_HEAD, |
| 312 | + highlight=True, |
| 313 | + ) |
297 | 314 |
|
298 | | - _log.info( |
299 | | - "The following Data variables should be assigned to the model inside a PyMC " |
300 | | - f"model block: \n" |
301 | | - f"{out}" |
| 315 | + self.requirement_table.title = "Model Requirements" |
| 316 | + self.requirement_table.caption = ( |
| 317 | + "These parameters should be assigned priors inside a PyMC model block before " |
| 318 | + "calling the build_statespace_graph method." |
302 | 319 | ) |
303 | 320 |
|
| 321 | + self.requirement_table.add_column("Variable", justify="left") |
| 322 | + self.requirement_table.add_column("Shape", justify="left") |
| 323 | + self.requirement_table.add_column("Constraints", justify="left") |
| 324 | + self.requirement_table.add_column("Dimensions", justify="right") |
| 325 | + |
304 | 326 | def _unpack_statespace_with_placeholders( |
305 | 327 | self, |
306 | 328 | ) -> tuple[ |
|
0 commit comments