1- from collections .abc import Mapping , Sequence
1+ from collections .abc import Callable , Mapping , Sequence
22
33import numpy as np
4+ import keras
45import matplotlib .pyplot as plt
56
67from ...utils .plot_utils import prepare_plot_data , add_titles_and_labels , prettify_subplots
@@ -13,6 +14,7 @@ def calibration_ecdf(
1314 targets : Mapping [str , np .ndarray ] | np .ndarray ,
1415 variable_keys : Sequence [str ] = None ,
1516 variable_names : Sequence [str ] = None ,
17+ test_quantities : dict [str , Callable ] = None ,
1618 difference : bool = False ,
1719 stacked : bool = False ,
1820 rank_type : str | np .ndarray = "fractional" ,
@@ -78,6 +80,18 @@ def calibration_ecdf(
7880 variable_names : list or None, optional, default: None
7981 The parameter names for nice plot titles.
8082 Inferred if None. Only relevant if `stacked=False`.
83+ test_quantities : dict or None, optional, default: None
84+ A dict that maps plot titles to functions that compute
85+ test quantities based on estimate/target draws.
86+
87+ The dict keys are automatically added to ``variable_keys``
88+ and ``variable_names``.
89+ Test quantity functions are expected to accept a dict of draws with
90+ shape ``(batch_size, ...)`` as the first (typically only)
91+ positional argument and return an NumPy array of shape
92+ ``(batch_size,)``.
93+ The functions do not have to deal with an additional
94+ sample dimension, as appropriate reshaping is done internally.
8195 figsize : tuple or None, optional, default: None
8296 The figure size passed to the matplotlib constructor.
8397 Inferred if None.
@@ -120,6 +134,36 @@ def calibration_ecdf(
120134 If an unknown `rank_type` is passed.
121135 """
122136
137+ # Optionally, compute and prepend test quantities from draws
138+ if test_quantities is not None :
139+ test_quantities_estimates = {}
140+ test_quantities_targets = {}
141+
142+ for key , test_quantity_fn in test_quantities .items ():
143+ # Apply test_quantity_func to ground-truths
144+ tq_targets = test_quantity_fn (data = targets )
145+ test_quantities_targets [key ] = np .expand_dims (tq_targets , axis = 1 )
146+
147+ # # Flatten estimates for batch processing in test_quantity_fn, apply function, and restore shape
148+ num_conditions , num_samples = next (iter (estimates .values ())).shape [:2 ]
149+ flattened_estimates = keras .tree .map_structure (lambda t : np .reshape (t , (- 1 , * t .shape [2 :])), estimates )
150+ flat_tq_estimates = test_quantity_fn (data = flattened_estimates )
151+ test_quantities_estimates [key ] = np .reshape (flat_tq_estimates , (num_conditions , num_samples , 1 ))
152+
153+ # Add custom test quantities to variable keys and names for plotting
154+ # keys and names are set to the test_quantities dict keys
155+ test_quantities_names = list (test_quantities .keys ())
156+
157+ if variable_keys is None :
158+ variable_keys = list (estimates .keys ())
159+
160+ if isinstance (variable_names , list ):
161+ variable_names = test_quantities_names + variable_names
162+
163+ variable_keys = test_quantities_names + variable_keys
164+ estimates = test_quantities_estimates | estimates
165+ targets = test_quantities_targets | targets
166+
123167 plot_data = prepare_plot_data (
124168 estimates = estimates ,
125169 targets = targets ,
0 commit comments