2121import warnings
2222
2323from abc import ABC
24- from typing import List , Sequence , Tuple , cast
24+ from typing import Dict , List , Optional , Sequence , Set , Tuple , Union , cast
2525
2626import numpy as np
27- import pytensor .tensor as at
2827
2928from pymc .backends .report import SamplerReport
3029from pymc .model import modelcontext
@@ -210,18 +209,18 @@ def _get_sampler_stats(self, stat_name, sampler_idx, burn, thin):
210209 """Get sampler statistics."""
211210 raise NotImplementedError ()
212211
213- def _slice (self , idx ):
212+ def _slice (self , idx : Union [ int , slice ] ):
214213 """Slice trace object."""
215214 raise NotImplementedError ()
216215
217- def point (self , idx ) :
216+ def point (self , idx : int ) -> Dict [ str , np . ndarray ] :
218217 """Return dictionary of point values at `idx` for current chain
219218 with variables names as keys.
220219 """
221220 raise NotImplementedError ()
222221
223222 @property
224- def stat_names (self ):
223+ def stat_names (self ) -> Set [ str ] :
225224 names = set ()
226225 for vars in self .sampler_vars or []:
227226 names .update (vars .keys ())
@@ -280,12 +279,10 @@ class MultiTrace:
280279 List of variable names in the trace(s)
281280 """
282281
283- def __init__ (self , straces ):
284- self ._straces = {}
285- for strace in straces :
286- if strace .chain in self ._straces :
287- raise ValueError ("Chains are not unique." )
288- self ._straces [strace .chain ] = strace
282+ def __init__ (self , straces : Sequence [BaseTrace ]):
283+ if len ({t .chain for t in straces }) != len (straces ):
284+ raise ValueError ("Chains are not unique." )
285+ self ._straces = {t .chain : t for t in straces }
289286
290287 self ._report = SamplerReport ()
291288
@@ -294,15 +291,15 @@ def __repr__(self):
294291 return template .format (self .__class__ .__name__ , self .nchains , len (self ), len (self .varnames ))
295292
296293 @property
297- def nchains (self ):
294+ def nchains (self ) -> int :
298295 return len (self ._straces )
299296
300297 @property
301- def chains (self ):
298+ def chains (self ) -> List [ int ] :
302299 return list (sorted (self ._straces .keys ()))
303300
304301 @property
305- def report (self ):
302+ def report (self ) -> SamplerReport :
306303 return self ._report
307304
308305 def __iter__ (self ):
@@ -367,12 +364,12 @@ def __len__(self):
367364 return len (self ._straces [chain ])
368365
369366 @property
370- def varnames (self ):
367+ def varnames (self ) -> List [ str ] :
371368 chain = self .chains [- 1 ]
372369 return self ._straces [chain ].varnames
373370
374371 @property
375- def stat_names (self ):
372+ def stat_names (self ) -> Set [ str ] :
376373 if not self ._straces :
377374 return set ()
378375 sampler_vars = [s .sampler_vars for s in self ._straces .values ()]
@@ -386,74 +383,15 @@ def stat_names(self):
386383 names .update (vars .keys ())
387384 return names
388385
389- def add_values (self , vals , overwrite = False ) -> None :
390- """Add variables to traces.
391-
392- Parameters
393- ----------
394- vals: dict (str: array-like)
395- The keys should be the names of the new variables. The values are expected to be
396- array-like objects. For traces with more than one chain the length of each value
397- should match the number of total samples already in the trace `(chains * iterations)`,
398- otherwise a warning is raised.
399- overwrite: bool
400- If `False` (default) a ValueError is raised if the variable already exists.
401- Change to `True` to overwrite the values of variables
402-
403- Returns
404- -------
405- None.
406- """
407- for k , v in vals .items ():
408- new_var = 1
409- if k in self .varnames :
410- if overwrite :
411- self .varnames .remove (k )
412- new_var = 0
413- else :
414- raise ValueError (f"Variable name { k } already exists." )
415-
416- self .varnames .append (k )
417-
418- chains = self ._straces
419- l_samples = len (self ) * len (self .chains )
420- l_v = len (v )
421- if l_v != l_samples :
422- warnings .warn (
423- "The length of the values you are trying to "
424- "add ({}) does not match the number ({}) of "
425- "total samples in the trace "
426- "(chains * iterations)" .format (l_v , l_samples )
427- )
428-
429- v = np .squeeze (v .reshape (len (chains ), len (self ), - 1 ))
430-
431- for idx , chain in enumerate (chains .values ()):
432- if new_var :
433- dummy = at .as_tensor_variable ([], k )
434- chain .vars .append (dummy )
435- chain .samples [k ] = v [idx ]
436-
437- def remove_values (self , name ):
438- """remove variables from traces.
439-
440- Parameters
441- ----------
442- name: str
443- Name of the variable to remove. Raises KeyError if the variable is not present
444- """
445- varnames = self .varnames
446- if name not in varnames :
447- raise KeyError (f"Unknown variable { name } " )
448- self .varnames .remove (name )
449- chains = self ._straces
450- for chain in chains .values ():
451- for va in chain .vars :
452- if va .name == name :
453- chain .vars .remove (va )
454- del chain .samples [name ]
455-
456- def get_values (self , varname , burn = 0 , thin = 1 , combine = True , chains = None , squeeze = True ):
386+ def get_values (
387+ self ,
388+ varname : str ,
389+ burn : int = 0 ,
390+ thin : int = 1 ,
391+ combine : bool = True ,
392+ chains : Optional [Union [int , Sequence [int ]]] = None ,
393+ squeeze : bool = True ,
394+ ) -> List [np .ndarray ]:
457395 """Get values from traces.
458396
459397 Parameters
@@ -479,13 +417,20 @@ def get_values(self, varname, burn=0, thin=1, combine=True, chains=None, squeeze
479417 if chains is None :
480418 chains = self .chains
481419 varname = get_var_name (varname )
482- try :
483- results = [self ._straces [chain ].get_values (varname , burn , thin ) for chain in chains ]
484- except TypeError : # Single chain passed.
485- results = [self ._straces [chains ].get_values (varname , burn , thin )]
420+ if isinstance (chains , int ):
421+ chains = [chains ]
422+ results = [self ._straces [chain ].get_values (varname , burn , thin ) for chain in chains ]
486423 return _squeeze_cat (results , combine , squeeze )
487424
488- def get_sampler_stats (self , stat_name , burn = 0 , thin = 1 , combine = True , chains = None , squeeze = True ):
425+ def get_sampler_stats (
426+ self ,
427+ stat_name : str ,
428+ burn : int = 0 ,
429+ thin : int = 1 ,
430+ combine : bool = True ,
431+ chains : Optional [Union [int , Sequence [int ]]] = None ,
432+ squeeze : bool = True ,
433+ ):
489434 """Get sampler statistics from the trace.
490435
491436 Parameters
@@ -508,9 +453,7 @@ def get_sampler_stats(self, stat_name, burn=0, thin=1, combine=True, chains=None
508453
509454 if chains is None :
510455 chains = self .chains
511- try :
512- chains = iter (chains )
513- except TypeError :
456+ if isinstance (chains , int ):
514457 chains = [chains ]
515458
516459 results = [
@@ -526,7 +469,7 @@ def _slice(self, slice):
526469 trace ._report = self ._report ._slice (* idxs )
527470 return trace
528471
529- def point (self , idx , chain = None ):
472+ def point (self , idx : int , chain : Optional [ int ] = None ) -> Dict [ str , np . ndarray ] :
530473 """Return a dictionary of point values at `idx`.
531474
532475 Parameters
0 commit comments