1919
2020import numpy as np
2121
22- from pymc . math import logbern
23- from pymc . pytensorf import floatX
22+ from pytensor import config
23+
2424from pymc .stats .convergence import SamplerWarning
2525from pymc .step_methods .compound import Competence
2626from pymc .step_methods .hmc import integration
@@ -205,11 +205,12 @@ def _hamiltonian_step(self, start, p0, step_size):
205205 else :
206206 max_treedepth = self .max_treedepth
207207
208- tree = _Tree (len (p0 ), self .integrator , start , step_size , self .Emax , rng = self .rng )
208+ rng = self .rng
209+ tree = _Tree (len (p0 ), self .integrator , start , step_size , self .Emax , rng = rng )
209210
210211 reached_max_treedepth = False
211212 for _ in range (max_treedepth ):
212- direction = logbern ( np . log ( 0.5 ), rng = self . rng ) * 2 - 1
213+ direction = ( rng . random () < 0.5 ) * 2 - 1
213214 divergence_info , turning = tree .extend (direction )
214215
215216 if divergence_info or turning :
@@ -218,9 +219,8 @@ def _hamiltonian_step(self, start, p0, step_size):
218219 reached_max_treedepth = not self .tune
219220
220221 stats = tree .stats ()
221- accept_stat = stats ["mean_tree_accept" ]
222222 stats ["reached_max_treedepth" ] = reached_max_treedepth
223- return HMCStepData (tree .proposal , accept_stat , divergence_info , stats )
223+ return HMCStepData (tree .proposal , stats [ "mean_tree_accept" ] , divergence_info , stats )
224224
225225 @staticmethod
226226 def competence (var , has_grad ):
@@ -241,6 +241,27 @@ def competence(var, has_grad):
241241
242242
243243class _Tree :
244+ __slots__ = (
245+ "ndim" ,
246+ "integrator" ,
247+ "start" ,
248+ "step_size" ,
249+ "Emax" ,
250+ "start_energy" ,
251+ "rng" ,
252+ "left" ,
253+ "right" ,
254+ "proposal" ,
255+ "depth" ,
256+ "log_size" ,
257+ "log_accept_sum" ,
258+ "mean_tree_accept" ,
259+ "n_proposals" ,
260+ "p_sum" ,
261+ "max_energy_change" ,
262+ "floatX" ,
263+ )
264+
244265 def __init__ (
245266 self ,
246267 ndim : int ,
@@ -273,14 +294,15 @@ def __init__(
273294 self .rng = rng
274295
275296 self .left = self .right = start
276- self .proposal = Proposal (start .q . data , start .q_grad , start .energy , start .model_logp , 0 )
297+ self .proposal = Proposal (start .q , start .q_grad , start .energy , start .model_logp , 0 )
277298 self .depth = 0
278299 self .log_size = 0.0
279300 self .log_accept_sum = - np .inf
280301 self .mean_tree_accept = 0.0
281302 self .n_proposals = 0
282303 self .p_sum = start .p .copy ()
283304 self .max_energy_change = 0.0
305+ self .floatX = config .floatX
284306
285307 def extend (self , direction ):
286308 """Double the treesize by extending the tree in the given direction.
@@ -296,7 +318,7 @@ def extend(self, direction):
296318 """
297319 if direction > 0 :
298320 tree , diverging , turning = self ._build_subtree (
299- self .right , self .depth , floatX ( np .asarray (self .step_size ) )
321+ self .right , self .depth , np .asarray (self .step_size , dtype = self . floatX )
300322 )
301323 leftmost_begin , leftmost_end = self .left , self .right
302324 rightmost_begin , rightmost_end = tree .left , tree .right
@@ -305,7 +327,7 @@ def extend(self, direction):
305327 self .right = tree .right
306328 else :
307329 tree , diverging , turning = self ._build_subtree (
308- self .left , self .depth , floatX ( np .asarray (- self .step_size ) )
330+ self .left , self .depth , np .asarray (- self .step_size , dtype = self . floatX )
309331 )
310332 leftmost_begin , leftmost_end = tree .right , tree .left
311333 rightmost_begin , rightmost_end = self .left , self .right
@@ -318,23 +340,27 @@ def extend(self, direction):
318340 if diverging or turning :
319341 return diverging , turning
320342
321- size1 , size2 = self .log_size , tree .log_size
322- if logbern ( size2 - size1 , rng = self .rng ):
343+ self_log_size , tree_log_size = self .log_size , tree .log_size
344+ if np . log ( self .rng . random ()) < ( tree_log_size - self_log_size ):
323345 self .proposal = tree .proposal
324346
325- self .log_size = np .logaddexp (self .log_size , tree .log_size )
326- self .p_sum [:] += tree .p_sum
347+ self .log_size = np .logaddexp (tree_log_size , self_log_size )
348+
349+ p_sum = self .p_sum
350+ p_sum [:] += tree .p_sum
327351
328352 # Additional turning check only when tree depth > 0 to avoid redundant work
329353 if self .depth > 0 :
330354 left , right = self .left , self .right
331- p_sum = self .p_sum
332355 turning = (p_sum .dot (left .v ) <= 0 ) or (p_sum .dot (right .v ) <= 0 )
333- p_sum1 = leftmost_p_sum + rightmost_begin .p
334- turning1 = (p_sum1 .dot (leftmost_begin .v ) <= 0 ) or (p_sum1 .dot (rightmost_begin .v ) <= 0 )
335- p_sum2 = leftmost_end .p + rightmost_p_sum
336- turning2 = (p_sum2 .dot (leftmost_end .v ) <= 0 ) or (p_sum2 .dot (rightmost_end .v ) <= 0 )
337- turning = turning | turning1 | turning2
356+ if not turning :
357+ p_sum1 = leftmost_p_sum + rightmost_begin .p
358+ turning = (p_sum1 .dot (leftmost_begin .v ) <= 0 ) or (
359+ p_sum1 .dot (rightmost_begin .v ) <= 0
360+ )
361+ if not turning :
362+ p_sum2 = leftmost_end .p + rightmost_p_sum
363+ turning = (p_sum2 .dot (leftmost_end .v ) <= 0 ) or (p_sum2 .dot (rightmost_end .v ) <= 0 )
338364
339365 return diverging , turning
340366
@@ -356,7 +382,10 @@ def _single_step(self, left: State, epsilon: float):
356382 if np .isnan (energy_change ):
357383 energy_change = np .inf
358384
359- self .log_accept_sum = np .logaddexp (self .log_accept_sum , min (0 , - energy_change ))
385+ self .log_accept_sum = np .logaddexp (
386+ self .log_accept_sum , (- energy_change if energy_change > 0 else 0 )
387+ )
388+ # self.log_accept_sum = np.logaddexp(self.log_accept_sum, min(0, -energy_change))
360389
361390 if np .abs (energy_change ) > np .abs (self .max_energy_change ):
362391 self .max_energy_change = energy_change
@@ -366,7 +395,7 @@ def _single_step(self, left: State, epsilon: float):
366395 # Saturated Metropolis accept probability with Boltzmann weight
367396 log_size = - energy_change
368397 proposal = Proposal (
369- right .q . data ,
398+ right .q ,
370399 right .q_grad ,
371400 right .energy ,
372401 right .model_logp ,
@@ -399,15 +428,15 @@ def _build_subtree(self, left, depth, epsilon):
399428 p_sum = tree1 .p_sum + tree2 .p_sum
400429 turning = (p_sum .dot (left .v ) <= 0 ) or (p_sum .dot (right .v ) <= 0 )
401430 # Additional U turn check only when depth > 1 to avoid redundant work.
402- if depth - 1 > 0 :
431+ if ( not turning ) and ( depth - 1 > 0 ) :
403432 p_sum1 = tree1 .p_sum + tree2 .left .p
404- turning1 = (p_sum1 .dot (tree1 .left .v ) <= 0 ) or (p_sum1 .dot (tree2 .left .v ) <= 0 )
405- p_sum2 = tree1 . right . p + tree2 . p_sum
406- turning2 = ( p_sum2 . dot ( tree1 . right . v ) <= 0 ) or ( p_sum2 . dot ( tree2 . right .v ) <= 0 )
407- turning = turning | turning1 | turning2
433+ turning = (p_sum1 .dot (tree1 .left .v ) <= 0 ) or (p_sum1 .dot (tree2 .left .v ) <= 0 )
434+ if not turning :
435+ p_sum2 = tree1 . right .p + tree2 . p_sum
436+ turning = ( p_sum2 . dot ( tree1 . right . v ) <= 0 ) or ( p_sum2 . dot ( tree2 . right . v ) <= 0 )
408437
409438 log_size = np .logaddexp (tree1 .log_size , tree2 .log_size )
410- if logbern ( tree2 .log_size - log_size , rng = self . rng ):
439+ if np . log ( self . rng . random ()) < ( tree2 .log_size - log_size ):
411440 proposal = tree2 .proposal
412441 else :
413442 proposal = tree1 .proposal
0 commit comments