@@ -162,6 +162,7 @@ def __init__(
         mode=None,
         rng=None,
         initial_point: PointType | None = None,
+        compile_kwargs: dict | None = None,
         blocked: bool = False,
     ):
         """Create an instance of a Metropolis stepper.
@@ -254,7 +255,7 @@ def __init__(
         self.mode = mode

         shared = pm.make_shared_replacements(initial_point, vars, model)
-        self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared)
+        self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared, compile_kwargs)
         super().__init__(vars, shared, blocked=blocked, rng=rng)

     def reset_tuning(self):
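Note: `compile_kwargs` is now threaded through to the compiled log-probability difference. A minimal usage sketch, assuming this diff is merged; the toy model and the `FAST_COMPILE` mode choice are illustrative, not part of the change:

```python
import pymc as pm

with pm.Model():
    x = pm.Normal("x", 0.0, 1.0)
    # compile_kwargs is forwarded to the PyTensor compile call inside delta_logp
    step = pm.Metropolis(compile_kwargs={"mode": "FAST_COMPILE"})
    idata = pm.sample(step=step, tune=200, draws=200, chains=1)
```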
@@ -432,6 +433,7 @@ def __init__(
         model=None,
         rng=None,
         initial_point: PointType | None = None,
+        compile_kwargs: dict | None = None,
         blocked: bool = True,
     ):
         model = pm.modelcontext(model)
@@ -447,7 +449,9 @@ def __init__(
         if not all(v.dtype in pm.discrete_types for v in vars):
             raise ValueError("All variables must be Bernoulli for BinaryMetropolis")

-        super().__init__(vars, [model.compile_logp()], blocked=blocked, rng=rng)
+        if compile_kwargs is None:
+            compile_kwargs = {}
+        super().__init__(vars, [model.compile_logp(**compile_kwargs)], blocked=blocked, rng=rng)

     def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]:
         logp = args[0]
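Each constructor normalizes `compile_kwargs` with an explicit `None` check instead of declaring a mutable `{}` default. A short sketch of the pitfall that pattern avoids (hypothetical functions, for illustration only):

```python
def bad(opts={}):  # one dict is created at definition time and shared by all calls
    opts.setdefault("calls", 0)
    opts["calls"] += 1
    return opts

def good(opts=None):  # a fresh dict per call when the caller passes nothing
    if opts is None:
        opts = {}
    opts.setdefault("calls", 0)
    opts["calls"] += 1
    return opts

assert bad()["calls"] == 1 and bad()["calls"] == 2    # state leaks across calls
assert good()["calls"] == 1 and good()["calls"] == 1  # no shared state
```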
@@ -554,6 +558,7 @@ def __init__(
         model=None,
         rng=None,
         initial_point: PointType | None = None,
+        compile_kwargs: dict | None = None,
         blocked: bool = True,
     ):
         model = pm.modelcontext(model)
@@ -582,7 +587,10 @@ def __init__(
         if not all(v.dtype in pm.discrete_types for v in vars):
             raise ValueError("All variables must be binary for BinaryGibbsMetropolis")

-        super().__init__(vars, [model.compile_logp()], blocked=blocked, rng=rng)
+        if compile_kwargs is None:
+            compile_kwargs = {}
+
+        super().__init__(vars, [model.compile_logp(**compile_kwargs)], blocked=blocked, rng=rng)

     def reset_tuning(self):
         # There are no tuning parameters in this step method.
@@ -672,6 +680,7 @@ def __init__(
         model=None,
         rng: RandomGenerator = None,
         initial_point: PointType | None = None,
+        compile_kwargs: dict | None = None,
         blocked: bool = True,
     ):
         model = pm.modelcontext(model)
@@ -728,7 +737,9 @@ def __init__(
         # that indicates whether a draw was done in a tuning phase.
         self.tune = True

-        super().__init__(vars, [model.compile_logp()], blocked=blocked, rng=rng)
+        if compile_kwargs is None:
+            compile_kwargs = {}
+        super().__init__(vars, [model.compile_logp(**compile_kwargs)], blocked=blocked, rng=rng)

     def reset_tuning(self):
         # There are no tuning parameters in this step method.
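The discrete samplers compile the full model log-probability up front, so the keyword reaches `model.compile_logp(**compile_kwargs)` directly. A usage sketch assuming this diff is merged; the model and kwargs are illustrative:

```python
import pymc as pm

with pm.Model():
    z = pm.Bernoulli("z", 0.5, shape=3)
    # forwarded to model.compile_logp(**compile_kwargs)
    step = pm.BinaryGibbsMetropolis([z], compile_kwargs={"mode": "FAST_COMPILE"})
    idata = pm.sample(step=step, tune=200, draws=200, chains=1)
```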
@@ -904,6 +915,7 @@ def __init__(
         mode=None,
         rng=None,
         initial_point: PointType | None = None,
+        compile_kwargs: dict | None = None,
         blocked: bool = True,
     ):
         model = pm.modelcontext(model)
@@ -939,7 +951,7 @@ def __init__(
         self.mode = mode

         shared = pm.make_shared_replacements(initial_point, vars, model)
-        self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared)
+        self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared, compile_kwargs)
         super().__init__(vars, shared, blocked=blocked, rng=rng)

     def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
@@ -1073,6 +1085,7 @@ def __init__(
         tune_drop_fraction: float = 0.9,
         model=None,
         initial_point: PointType | None = None,
+        compile_kwargs: dict | None = None,
         mode=None,
         rng=None,
         blocked: bool = True,
@@ -1122,7 +1135,7 @@ def __init__(
         self.mode = mode

         shared = pm.make_shared_replacements(initial_point, vars, model)
-        self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared)
+        self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared, compile_kwargs)
         super().__init__(vars, shared, blocked=blocked, rng=rng)

     def reset_tuning(self):
@@ -1213,6 +1226,7 @@ def delta_logp(
     logp: pt.TensorVariable,
     vars: list[pt.TensorVariable],
     shared: dict[pt.TensorVariable, pt.sharedvar.TensorSharedVariable],
+    compile_kwargs: dict | None,
 ) -> pytensor.compile.Function:
     [logp0], inarray0 = join_nonshared_inputs(
         point=point, outputs=[logp], inputs=vars, shared_inputs=shared
@@ -1225,6 +1239,8 @@ def delta_logp(
     # Replace any potential duplicated RNG nodes
     (logp1,) = replace_rng_nodes((logp1,))

-    f = compile_pymc([inarray1, inarray0], logp1 - logp0)
+    if compile_kwargs is None:
+        compile_kwargs = {}
+    f = compile_pymc([inarray1, inarray0], logp1 - logp0, **compile_kwargs)
     f.trust_input = True
     return f
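For context, `delta_logp` compiles a function of the proposed and current raveled points that returns their log-probability difference, and `compile_kwargs` is now splatted into that compile call. A standalone analogue with a stand-in standard-normal log-density (illustrative only; it uses plain `pytensor.function` rather than PyMC's `compile_pymc`):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

inarray0 = pt.dvector("inarray0")  # current point, raveled
inarray1 = pt.dvector("inarray1")  # proposed point, raveled

# Stand-in log-density: iid standard normal over the raveled vector
logp0 = -0.5 * (inarray0**2).sum()
logp1 = -0.5 * (inarray1**2).sum()

compile_kwargs = {}  # e.g. {"mode": "FAST_COMPILE"}, forwarded as in the diff
f = pytensor.function([inarray1, inarray0], logp1 - logp0, **compile_kwargs)
f.trust_input = True  # skip input validation for speed, as in delta_logp

print(f(np.array([0.5]), np.array([0.0])))  # logp(proposed) - logp(current)
```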