2121from arviz .data .base import CoordSpec , DimSpec , dict_to_dataset , requires
2222from pytensor .graph .basic import Constant
2323from pytensor .tensor .sharedvar import SharedVariable
24- from pytensor .tensor .subtensor import AdvancedIncSubtensor , AdvancedIncSubtensor1
2524
2625import pymc
2726
@@ -153,7 +152,7 @@ def __init__(
153152 trace = None ,
154153 prior = None ,
155154 posterior_predictive = None ,
156- log_likelihood = True ,
155+ log_likelihood = False ,
157156 predictions = None ,
158157 coords : Optional [CoordSpec ] = None ,
159158 dims : Optional [DimSpec ] = None ,
@@ -246,68 +245,6 @@ def split_trace(self) -> Tuple[Union[None, "MultiTrace"], Union[None, "MultiTrac
246245 trace_posterior = self .trace [self .ntune :]
247246 return trace_posterior , trace_warmup
248247
249- def log_likelihood_vals_point (self , point , var , log_like_fun ):
250- """Compute log likelihood for each observed point."""
251- # TODO: This is a cheap hack; we should filter-out the correct
252- # variables some other way
253- point = {i .name : point [i .name ] for i in log_like_fun .f .maker .inputs if i .name in point }
254- log_like_val = np .atleast_1d (log_like_fun (point ))
255-
256- if isinstance (var .owner .op , (AdvancedIncSubtensor , AdvancedIncSubtensor1 )):
257- try :
258- obs_data = extract_obs_data (self .model .rvs_to_values [var ])
259- except TypeError :
260- warnings .warn (f"Could not extract data from symbolic observation { var } " )
261-
262- mask = obs_data .mask
263- if np .ndim (mask ) > np .ndim (log_like_val ):
264- mask = np .any (mask , axis = - 1 )
265- log_like_val = np .where (mask , np .nan , log_like_val )
266- return log_like_val
267-
268- def _extract_log_likelihood (self , trace ):
269- """Compute log likelihood of each observation."""
270- if self .trace is None :
271- return None
272- if self .model is None :
273- return None
274-
275- # TODO: We no longer need one function per observed variable
276- if self .log_likelihood is True :
277- cached = [
278- (
279- var ,
280- self .model .compile_fn (
281- self .model .logp (var , sum = False )[0 ],
282- inputs = self .model .value_vars ,
283- on_unused_input = "ignore" ,
284- ),
285- )
286- for var in self .model .observed_RVs
287- ]
288- else :
289- cached = [
290- (
291- var ,
292- self .model .compile_fn (
293- self .model .logp (var , sum = False )[0 ],
294- inputs = self .model .value_vars ,
295- on_unused_input = "ignore" ,
296- ),
297- )
298- for var in self .model .observed_RVs
299- if var .name in self .log_likelihood
300- ]
301- log_likelihood_dict = _DefaultTrace (len (trace .chains ))
302- for var , log_like_fun in cached :
303- for k , chain in enumerate (trace .chains ):
304- log_like_chain = [
305- self .log_likelihood_vals_point (point , var , log_like_fun )
306- for point in trace .points ([chain ])
307- ]
308- log_likelihood_dict .insert (var .name , np .stack (log_like_chain ), k )
309- return log_likelihood_dict .trace_dict
310-
311248 @requires ("trace" )
312249 def posterior_to_xarray (self ):
313250 """Convert the posterior to an xarray dataset."""
@@ -382,49 +319,6 @@ def sample_stats_to_xarray(self):
382319 ),
383320 )
384321
385- @requires ("trace" )
386- @requires ("model" )
387- def log_likelihood_to_xarray (self ):
388- """Extract log likelihood and log_p data from PyMC trace."""
389- if self .predictions or not self .log_likelihood :
390- return None
391- data_warmup = {}
392- data = {}
393- warn_msg = (
394- "Could not compute log_likelihood, it will be omitted. "
395- "Check your model object or set log_likelihood=False"
396- )
397- if self .posterior_trace :
398- try :
399- data = self ._extract_log_likelihood (self .posterior_trace )
400- except TypeError :
401- warnings .warn (warn_msg )
402- if self .warmup_trace :
403- try :
404- data_warmup = self ._extract_log_likelihood (self .warmup_trace )
405- except TypeError :
406- warnings .warn (warn_msg )
407- return (
408- dict_to_dataset (
409- data ,
410- library = pymc ,
411- dims = self .dims ,
412- coords = self .coords ,
413- skip_event_dims = True ,
414- ),
415- dict_to_dataset (
416- data_warmup ,
417- library = pymc ,
418- dims = self .dims ,
419- coords = self .coords ,
420- skip_event_dims = True ,
421- ),
422- )
423-
424- return dict_to_dataset (
425- data , library = pymc , coords = self .coords , dims = self .dims , default_dims = self .sample_dims
426- )
427-
428322 @requires (["posterior_predictive" ])
429323 def posterior_predictive_to_xarray (self ):
430324 """Convert posterior_predictive samples to xarray."""
@@ -509,7 +403,6 @@ def to_inference_data(self):
509403 id_dict = {
510404 "posterior" : self .posterior_to_xarray (),
511405 "sample_stats" : self .sample_stats_to_xarray (),
512- "log_likelihood" : self .log_likelihood_to_xarray (),
513406 "posterior_predictive" : self .posterior_predictive_to_xarray (),
514407 "predictions" : self .predictions_to_xarray (),
515408 ** self .priors_to_xarray (),
@@ -519,15 +412,27 @@ def to_inference_data(self):
519412 id_dict ["predictions_constant_data" ] = self .constant_data_to_xarray ()
520413 else :
521414 id_dict ["constant_data" ] = self .constant_data_to_xarray ()
522- return InferenceData (save_warmup = self .save_warmup , ** id_dict )
415+ idata = InferenceData (save_warmup = self .save_warmup , ** id_dict )
416+ if self .log_likelihood :
417+ from pymc .stats .log_likelihood import compute_log_likelihood
418+
419+ idata = compute_log_likelihood (
420+ idata ,
421+ var_names = None if self .log_likelihood is True else self .log_likelihood ,
422+ extend_inferencedata = True ,
423+ model = self .model ,
424+ sample_dims = self .sample_dims ,
425+ progressbar = False ,
426+ )
427+ return idata
523428
524429
525430def to_inference_data (
526431 trace : Optional ["MultiTrace" ] = None ,
527432 * ,
528433 prior : Optional [Mapping [str , Any ]] = None ,
529434 posterior_predictive : Optional [Mapping [str , Any ]] = None ,
530- log_likelihood : Union [bool , Iterable [str ]] = True ,
435+ log_likelihood : Union [bool , Iterable [str ]] = False ,
531436 coords : Optional [CoordSpec ] = None ,
532437 dims : Optional [DimSpec ] = None ,
533438 sample_dims : Optional [List ] = None ,
0 commit comments