|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | | -from collections.abc import Mapping, MutableMapping, Sequence |
| 14 | +from collections.abc import Callable, Mapping, MutableMapping, Sequence |
15 | 15 | from typing import Any |
16 | 16 |
|
17 | 17 | import arviz as az |
@@ -91,10 +91,11 @@ def __init__( |
91 | 91 | vars: Sequence[TensorVariable] | None = None, |
92 | 92 | test_point: dict[str, np.ndarray] | None = None, |
93 | 93 | draws_per_chunk: int = 1, |
| 94 | + fn: Callable | None = None, |
94 | 95 | ): |
95 | 96 | if not _zarr_available: |
96 | 97 | raise RuntimeError("You must install zarr to be able to create ZarrChain instances") |
97 | | - super().__init__(name="zarr", model=model, vars=vars, test_point=test_point) |
| 98 | + super().__init__(name="zarr", model=model, vars=vars, test_point=test_point, fn=fn) |
98 | 99 | self._step_method: BlockedStep | CompoundStep | None = None |
99 | 100 | self.unconstrained_variables = { |
100 | 101 | var.name for var in self.vars if is_transformed_name(var.name) |
@@ -168,7 +169,7 @@ def record( |
168 | 169 | :meth:`~ZarrChain.flush` |
169 | 170 | """ |
170 | 171 | unconstrained_variables = self.unconstrained_variables |
171 | | - for var_name, var_value in zip(self.varnames, self.fn(draw)): |
| 172 | + for var_name, var_value in zip(self.varnames, self.fn(**draw)): |
172 | 173 | if var_name in unconstrained_variables: |
173 | 174 | self.buffer(group="unconstrained_posterior", var_name=var_name, value=var_value) |
174 | 175 | else: |
@@ -452,13 +453,18 @@ def init_trace( |
452 | 453 | ) |
453 | 454 | self.vars = [var for var in vars if var.name in self.varnames] |
454 | 455 |
|
455 | | - self.fn = model.compile_fn(self.vars, inputs=model.value_vars, on_unused_input="ignore") |
| 456 | + self.fn = model.compile_fn( |
| 457 | + self.vars, |
| 458 | + inputs=model.value_vars, |
| 459 | + on_unused_input="ignore", |
| 460 | + point_fn=False, |
| 461 | + ) |
456 | 462 |
|
457 | 463 | # Get variable shapes. Most backends will need this |
458 | 464 | # information. |
459 | 465 | if test_point is None: |
460 | 466 | test_point = model.initial_point() |
461 | | - var_values = list(zip(self.varnames, self.fn(test_point))) |
| 467 | + var_values = list(zip(self.varnames, self.fn(**test_point))) |
462 | 468 | self.var_dtype_shapes = { |
463 | 469 | var: (value.dtype, value.shape) |
464 | 470 | for var, value in var_values |
@@ -528,6 +534,7 @@ def init_trace( |
528 | 534 | test_point=test_point, |
529 | 535 | stats_bijection=StatsBijection(step.stats_dtypes), |
530 | 536 | draws_per_chunk=self.draws_per_chunk, |
| 537 | + fn=self.fn, |
531 | 538 | ) |
532 | 539 | for _ in range(chains) |
533 | 540 | ] |
|
0 commit comments