6161
6262from pytensor .graph .basic import Variable
6363from pytensor .graph .replace import graph_replace
64+ from pytensor .scalar .basic import identity as scalar_identity
65+ from pytensor .tensor .elemwise import Elemwise
6466from pytensor .tensor .shape import unbroadcast
6567
6668import pymc as pm
7476 SeedSequenceSeed ,
7577 compile ,
7678 find_rng_nodes ,
77- identity ,
7879 reseed_rngs ,
7980)
8081from pymc .util import (
@@ -332,6 +333,7 @@ def step_function(
332333 more_replacements = None ,
333334 total_grad_norm_constraint = None ,
334335 score = False ,
336+ compile_kwargs = None ,
335337 fn_kwargs = None ,
336338 ):
337339 R"""Step function that should be called on each optimization step.
@@ -362,17 +364,30 @@ def step_function(
362364 Bounds gradient norm, prevents exploding gradient problem
363365 score: `bool`
364366 calculate loss on each step? Defaults to False for speed
365- fn_kwargs : `dict`
367+ compile_kwargs : `dict`
366368 Add kwargs to pytensor.function (e.g. `{'profile': True}`)
369+ fn_kwargs: dict
370+ arbitrary kwargs passed to `pytensor.function`
371+
372+ .. warning:: `fn_kwargs` is deprecated and will be removed in future versions
373+
367374 more_replacements: `dict`
368375 Apply custom replacements before calculating gradients
369376
370377 Returns
371378 -------
372379 `pytensor.function`
373380 """
374- if fn_kwargs is None :
375- fn_kwargs = {}
381+ if fn_kwargs is not None :
382+ warnings .warn (
383+ "`fn_kwargs` is deprecated and will be removed in future versions. Use "
384+ "`compile_kwargs` instead." ,
385+ DeprecationWarning ,
386+ )
387+ compile_kwargs = fn_kwargs
388+
389+ if compile_kwargs is None :
390+ compile_kwargs = {}
376391 if score and not self .op .returns_loss :
377392 raise NotImplementedError (f"{ self .op } does not have loss" )
378393 updates = self .updates (
@@ -388,14 +403,14 @@ def step_function(
388403 )
389404 seed = self .approx .rng .randint (2 ** 30 , dtype = np .int64 )
390405 if score :
391- step_fn = compile ([], updates .loss , updates = updates , random_seed = seed , ** fn_kwargs )
406+ step_fn = compile ([], updates .loss , updates = updates , random_seed = seed , ** compile_kwargs )
392407 else :
393- step_fn = compile ([], [], updates = updates , random_seed = seed , ** fn_kwargs )
408+ step_fn = compile ([], [], updates = updates , random_seed = seed , ** compile_kwargs )
394409 return step_fn
395410
396411 @pytensor .config .change_flags (compute_test_value = "off" )
397412 def score_function (
398- self , sc_n_mc = None , more_replacements = None , fn_kwargs = None
413+ self , sc_n_mc = None , more_replacements = None , compile_kwargs = None , fn_kwargs = None
399414 ): # pragma: no cover
400415 R"""Compile scoring function that operates which takes no inputs and returns Loss.
401416
@@ -405,22 +420,34 @@ def score_function(
405420 number of scoring MC samples
406421 more_replacements:
407422 Apply custom replacements before compiling a function
423+ compile_kwargs: `dict`
424+ arbitrary kwargs passed to `pytensor.function`
408425 fn_kwargs: `dict`
409426 arbitrary kwargs passed to `pytensor.function`
410427
428+ .. warning:: `fn_kwargs` is deprecated and will be removed in future versions
429+
411430 Returns
412431 -------
413432 pytensor.function
414433 """
415- if fn_kwargs is None :
416- fn_kwargs = {}
434+ if fn_kwargs is not None :
435+ warnings .warn (
436+ "`fn_kwargs` is deprecated and will be removed in future versions. Use "
437+ "`compile_kwargs` instead" ,
438+ DeprecationWarning ,
439+ )
440+ compile_kwargs = fn_kwargs
441+
442+ if compile_kwargs is None :
443+ compile_kwargs = {}
417444 if not self .op .returns_loss :
418445 raise NotImplementedError (f"{ self .op } does not have loss" )
419446 if more_replacements is None :
420447 more_replacements = {}
421448 loss = self (sc_n_mc , more_replacements = more_replacements )
422449 seed = self .approx .rng .randint (2 ** 30 , dtype = np .int64 )
423- return compile ([], loss , random_seed = seed , ** fn_kwargs )
450+ return compile ([], loss , random_seed = seed , ** compile_kwargs )
424451
425452 @pytensor .config .change_flags (compute_test_value = "off" )
426453 def __call__ (self , nmc , ** kwargs ):
@@ -451,7 +478,7 @@ class Operator:
451478 require_logq = True
452479 objective_class = ObjectiveFunction
453480 supports_aevb = property (lambda self : not self .approx .any_histograms )
454- T = identity
481+ T = Elemwise ( scalar_identity )
455482
456483 def __init__ (self , approx ):
457484 self .approx = approx
0 commit comments