22
33import warnings
44from collections import OrderedDict
5- from collections .abc import Sequence
5+ from collections .abc import Callable , Sequence
66from copy import copy
77from functools import partial
8- from typing import cast
8+ from typing import Union , cast
99
1010import pytensor .tensor as pt
1111from pytensor .compile .function import function
@@ -225,7 +225,7 @@ class OpFromGraph(Op, HasInnerGraph):
225225 e2 = op(x, y, z) + op(z, y, x)
226226 fn = function([x, y, z], [e2])
227227
228- Example 3 override L_op
228+ Example 3 override second output of L_op
229229
230230 .. code-block:: python
231231
@@ -241,7 +241,7 @@ def rescale_dy(inps, outputs, out_grads):
241241 op = OpFromGraph(
242242 [x, y, z],
243243 [e],
244- lop_overrides=['default' , rescale_dy, 'default' ],
244+ lop_overrides=[None , rescale_dy, None ],
245245 )
246246 e2 = op(x, y, z)
247247 dx, dy, dz = grad(e2, [x, y, z])
@@ -253,7 +253,7 @@ def rescale_dy(inps, outputs, out_grads):
253253
254254 TYPE_ERR_MSG = (
255255 "L_op/gradient override should be (single or list of)"
256- "'default' | OpFromGraph | callable | Variable "
256+ "None | OpFromGraph | callable | Variable "
257257 "with NullType or DisconnectedType, got %s"
258258 )
259259 STYPE_ERR_MSG = (
@@ -308,9 +308,9 @@ def __init__(
308308 outputs : list [Variable ],
309309 * ,
310310 inline : bool = False ,
311- lop_overrides : str = "default" ,
312- grad_overrides : str = "default" ,
313- rop_overrides : str = "default" ,
311+ lop_overrides : Union [ Callable , "OpFromGraph" , None ] = None ,
312+ grad_overrides : Union [ Callable , "OpFromGraph" , None ] = None ,
313+ rop_overrides : Union [ Callable , "OpFromGraph" , None ] = None ,
314314 connection_pattern : list [list [bool ]] | None = None ,
315315 strict : bool = False ,
316316 name : str | None = None ,
@@ -333,10 +333,10 @@ def __init__(
333333
334334 ``False`` : will use a pre-compiled function inside.
335335 grad_overrides
336- Defaults to ``'default' ``.
336+ Defaults to ``None ``.
337337 This argument is mutually exclusive with ``lop_overrides``.
338338
339- ``'default' `` : Do not override, use default grad() result
339+ ``None `` : Do not override, use default grad() result
340340
341341 `OpFromGraph`: Override with another `OpFromGraph`, which should
342342 accept inputs in the same order and of the same types as ``inputs`` and ``output_grads``
@@ -346,14 +346,14 @@ def __init__(
346346 Each argument is expected to be a list of :class:`Variable `.
347347 Must return list of :class:`Variable `.
348348 lop_overrides
349- Defaults to ``'default' ``.
349+ Defaults to ``None ``.
350350
351351 This argument is mutually exclusive with ``grad_overrides``.
352352
353353 These options are similar to the ``grad_overrides`` above, but for
354354 the :meth:`Op.L_op` method.
355355
356- ``'default' ``: Do not override, use the default :meth:`Op.L_op` result
356+ ``None ``: Do not override, use the default :meth:`Op.L_op` result
357357
358358 `OpFromGraph`: Override with another `OpFromGraph`, which should
359359 accept inputs in the same order and of the same types as ``inputs``,
@@ -373,11 +373,11 @@ def __init__(
373373 a specific input, length of list must be equal to number of inputs.
374374
375375 rop_overrides
376- One of ``{'default' , OpFromGraph, callable, Variable}``.
376+ One of ``{None , OpFromGraph, callable, Variable}``.
377377
378- Defaults to ``'default' ``.
378+ Defaults to ``None ``.
379379
380- ``'default' ``: Do not override, use the default :meth:`Op.R_op` result
380+ ``None ``: Do not override, use the default :meth:`Op.R_op` result
381381
382382 `OpFromGraph`: Override with another `OpFromGraph`, which should
383383 accept inputs in the same order and of the same types as ``inputs`` and ``eval_points``
@@ -446,27 +446,37 @@ def __init__(
446446 self .input_types = [inp .type for inp in inputs ]
447447 self .output_types = [out .type for out in outputs ]
448448
449+ for override in (lop_overrides , grad_overrides , rop_overrides ):
450+ if override == "default" :
451+ raise ValueError (
452+ "'default' is no longer a valid value for overrides. Use None instead."
453+ )
454+ if isinstance (override , Variable ):
455+ raise TypeError (
456+ "Variables are no longer valid types for overrides. Return them in a list for each output instead"
457+ )
458+
449459 self .lop_overrides = lop_overrides
450460 self .grad_overrides = grad_overrides
451461 self .rop_overrides = rop_overrides
452462
453- if lop_overrides != "default" :
454- if grad_overrides != "default" :
463+ if lop_overrides is not None :
464+ if grad_overrides is not None :
455465 raise ValueError (
456466 "lop_overrides and grad_overrides are mutually exclusive"
457467 )
458468 else :
459469 self .set_lop_overrides (lop_overrides )
460470 self ._lop_type = "lop"
461- elif grad_overrides != "default" :
471+ elif grad_overrides is not None :
462472 warnings .warn (
463473 "grad_overrides is deprecated in favor of lop_overrides. Using it will lead to an error in the future." ,
464474 FutureWarning ,
465475 )
466476 self .set_lop_overrides (grad_overrides )
467477 self ._lop_type = "grad"
468478 else :
469- self .set_lop_overrides ("default" )
479+ self .set_lop_overrides (None )
470480 self ._lop_type = "lop"
471481
472482 self .set_rop_overrides (rop_overrides )
@@ -546,7 +556,7 @@ def lop_op(inps, grads):
546556 callable_args = (local_inputs , output_grads )
547557
548558 # we need to convert _lop_op into an OfG instance
549- if lop_op == "default" :
559+ if lop_op is None :
550560 gdefaults_l = fn_grad (wrt = local_inputs )
551561 all_grads_l , all_grads_ov_l = zip (
552562 * [
@@ -556,12 +566,6 @@ def lop_op(inps, grads):
556566 )
557567 all_grads_l = list (all_grads_l )
558568 all_grads_ov_l = list (all_grads_ov_l )
559- elif isinstance (lop_op , Variable ):
560- if isinstance (lop_op .type , DisconnectedType | NullType ):
561- all_grads_l = [inp .zeros_like () for inp in local_inputs ]
562- all_grads_ov_l = [lop_op .type () for _ in range (inp_len )]
563- else :
564- raise ValueError (self .STYPE_ERR_MSG % lop_op .type )
565569 elif isinstance (lop_op , list ):
566570 goverrides_l = lop_op
567571 if len (goverrides_l ) != inp_len :
@@ -571,15 +575,13 @@ def lop_op(inps, grads):
571575 )
572576 # compute non-overriding downstream grads from upstream grads
573577 # it's normal that some inputs may be disconnected, hence the 'ignore'
574- wrt_l = [
575- lin for lin , gov in zip (local_inputs , goverrides_l ) if gov == "default"
576- ]
578+ wrt_l = [lin for lin , gov in zip (local_inputs , goverrides_l ) if gov is None ]
577579 gdefaults = iter (fn_grad (wrt = wrt_l ) if wrt_l else [])
578580 # combine overriding gradients
579581 all_grads_l = []
580582 all_grads_ov_l = []
581583 for inp , fn_gov in zip (local_inputs , goverrides_l ):
582- if fn_gov == "default" :
584+ if fn_gov is None :
583585 gnext , gnext_ov = OpFromGraph ._filter_grad_var (next (gdefaults ), inp )
584586 all_grads_l .append (gnext )
585587 all_grads_ov_l .append (gnext_ov )
@@ -652,13 +654,13 @@ def _recompute_rop_op(self):
652654 fn_rop = partial (Rop , wrt = local_inputs , eval_points = eval_points )
653655 TYPE_ERR_MSG = (
654656 "R_op overrides should be (single or list of)"
655- "OpFromGraph | 'default' | None | 0 | callable, got %s"
657+ "OpFromGraph, None, a list or a callable, got %s"
656658 )
657659 STYPE_ERR_MSG = (
658660 "Overriding Variable instance can only have type"
659661 " of DisconnectedType or NullType, got %s"
660662 )
661- if rop_op == "default" :
663+ if rop_op is None :
662664 rdefaults_l = fn_rop (f = local_outputs )
663665 all_rops_l , all_rops_ov_l = zip (
664666 * [
@@ -668,15 +670,6 @@ def _recompute_rop_op(self):
668670 )
669671 all_rops_l = list (all_rops_l )
670672 all_rops_ov_l = list (all_rops_ov_l )
671- elif isinstance (rop_op , Variable ):
672- if isinstance (rop_op .type , NullType ):
673- all_rops_l = [inp .zeros_like () for inp in local_inputs ]
674- all_rops_ov_l = [rop_op .type () for _ in range (out_len )]
675- elif isinstance (rop_op .type , DisconnectedType ):
676- all_rops_l = [inp .zeros_like () for inp in local_inputs ]
677- all_rops_ov_l = [None ] * out_len
678- else :
679- raise ValueError (STYPE_ERR_MSG % rop_op .type )
680673 elif isinstance (rop_op , list ):
681674 roverrides_l = rop_op
682675 if len (roverrides_l ) != out_len :
@@ -686,15 +679,15 @@ def _recompute_rop_op(self):
686679 )
687680 # get outputs that do not have an Rop override
688681 odefaults_l = [
689- lo for lo , rov in zip (local_outputs , roverrides_l ) if rov == "default"
682+ lo for lo , rov in zip (local_outputs , roverrides_l ) if rov is None
690683 ]
691684 rdefaults_l = fn_rop (f = odefaults_l )
692685 rdefaults = iter (rdefaults_l if odefaults_l else [])
693686 # combine overriding Rops
694687 all_rops_l = []
695688 all_rops_ov_l = []
696689 for out , fn_rov in zip (local_outputs , roverrides_l ):
697- if fn_rov == "default" :
690+ if fn_rov is None :
698691 rnext , rnext_ov = OpFromGraph ._filter_rop_var (next (rdefaults ), out )
699692 all_rops_l .append (rnext )
700693 all_rops_ov_l .append (rnext_ov )
@@ -769,7 +762,6 @@ def set_grad_overrides(self, grad_overrides):
769762 self ._lop_op = grad_overrides
770763 self ._lop_op_is_cached = False
771764 self ._lop_type = "grad"
772- self ._lop_is_default = grad_overrides == "default"
773765
774766 def set_lop_overrides (self , lop_overrides ):
775767 """
@@ -780,7 +772,6 @@ def set_lop_overrides(self, lop_overrides):
780772 self ._lop_op = lop_overrides
781773 self ._lop_op_is_cached = False
782774 self ._lop_type = "lop"
783- self ._lop_is_default = lop_overrides == "default"
784775
785776 def set_rop_overrides (self , rop_overrides ):
786777 """
@@ -790,7 +781,6 @@ def set_rop_overrides(self, rop_overrides):
790781 """
791782 self ._rop_op = rop_overrides
792783 self ._rop_op_is_cached = False
793- self ._rop_is_default = rop_overrides == "default"
794784
795785 def L_op (self , inputs , outputs , output_grads ):
796786 if not self ._lop_op_is_cached :
0 commit comments