@@ -417,7 +417,10 @@ def __init__(
                 FutureWarning,
             )
             self._lop_op_interface = False
-        self._lop_op_cache: Callable | None = None
+        # Dictionary where we cache the OpFromGraph instances that represent the L_op
+        # A distinct OpFromGraph is needed to represent each pattern of output_grads connection
+        # It also returns a tuple that indicates which input_gradients are disconnected
+        self._lop_op_cache: dict[tuple[bool, ...], Callable] = {}
         self._rop_op_cache: Callable | None = None

         self._connection_pattern = connection_pattern
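
A minimal standalone sketch of the caching scheme introduced above (not part of the diff; all names are illustrative): the cache is a dict keyed by the tuple of booleans recording which output gradients are disconnected, and a specialized callable is built at most once per pattern.

from collections.abc import Callable


class PatternCache:
    """Toy stand-in for the per-pattern cache; names are hypothetical."""

    def __init__(self) -> None:
        # One entry per pattern of disconnected output gradients
        self._lop_op_cache: dict[tuple[bool, ...], Callable] = {}

    def get_or_build(self, disconnected_output_grads: tuple[bool, ...]) -> Callable:
        try:
            # Reuse the callable already built for this exact pattern
            return self._lop_op_cache[disconnected_output_grads]
        except KeyError:
            pass

        def specialized(*args):
            # Stand-in for the real specialized L_op graph
            return disconnected_output_grads

        self._lop_op_cache[disconnected_output_grads] = specialized
        return specialized


cache = PatternCache()
f1 = cache.get_or_build((False, True))
f2 = cache.get_or_build((False, True))
assert f1 is f2  # same pattern reuses the cached callable
assert cache.get_or_build((True, False)) is not f1  # new pattern builds a new one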
@@ -480,24 +483,30 @@ def _call_custom_override(self, op_overrides, callable_args, nout):
             return outputs

     @config.change_flags(compute_test_value="off")
-    def _build_and_cache_lop_op(self) -> Callable:
-        """converts lop_overrides (or grad_overrides) from user supplied form to type(self) instance.
+    def _build_and_cache_lop_op(
+        self, disconnected_output_grads: tuple[bool, ...]
+    ) -> Callable:
+        """Converts lop_overrides (or grad_overrides) from user-supplied form to a type(self) instance,
+        specialized for the given pattern of disconnected_output_grads.

         Results are cached in self._lop_op_cache
         """
-        if self._lop_op_cache is not None:
-            return self._lop_op_cache
+        try:
+            return self._lop_op_cache[disconnected_output_grads]
+        except KeyError:
+            pass

         inner_inputs = self.inner_inputs
         inner_outputs = self.inner_outputs
         nin = len(inner_inputs)
+        nout = len(inner_outputs)
         lop_overrides = (
             self.lop_overrides if self._lop_op_interface else self.grad_overrides
         )

         if isinstance(lop_overrides, OpFromGraph):
             if self._lop_op_interface:
-                self._lop_op_cache = lop_overrides
+                self._lop_op_cache[disconnected_output_grads] = lop_overrides
                 lop_overrides.kwargs["on_unused_input"] = "ignore"
                 return lop_overrides

@@ -507,20 +516,42 @@ def _build_and_cache_lop_op(self) -> Callable:
             def lop_overrides(inps, grads):
                 return self.grad_overrides(*inps, *grads)

-        output_grads = [out_t() for out_t in self.output_types]
+        # We try to compute the gradient with respect to connected outputs only
+        connected_inner_outputs = [
+            # We add an identity operation (copy) so that we don't override indirect
+            # gradient contributions to an inner output coming from other inner outputs
+            inner_out.copy()
+            for inner_out, disconnected in zip(
+                inner_outputs, disconnected_output_grads, strict=True
+            )
+            if not disconnected
+        ]
+        connected_output_grads = [
+            out_t()
+            for out_t, disconnected in zip(
+                self.output_types, disconnected_output_grads, strict=True
+            )
+            if not disconnected
+        ]
         fn_grad = partial(
             grad,
             cost=None,
             disconnected_inputs="ignore",
             return_disconnected="disconnected",
             null_gradients="return",
-            known_grads=dict(zip(inner_outputs, output_grads)),
+            known_grads=dict(
+                zip(connected_inner_outputs, connected_output_grads, strict=True)
+            ),
         )

         if self._lop_op_interface:
-            callable_args = (inner_inputs, inner_outputs, output_grads)
+            callable_args = (
+                inner_inputs,
+                connected_inner_outputs,
+                connected_output_grads,
+            )
         else:
-            callable_args = (inner_inputs, output_grads)
+            callable_args = (inner_inputs, connected_output_grads)

         # we need to convert _lop_op into an OfG instance
         if lop_overrides is None:
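
A small self-contained sketch of the filtering step above (toy data, illustrative names): given the boolean mask of disconnected output gradients, keep only the connected outputs and their gradients, pairing them with zip(..., strict=True) so any length mismatch raises immediately.

# Strings play the role of inner outputs / output gradients
inner_outputs = ["out0", "out1", "out2"]
output_grads = ["g0", "g1", "g2"]
disconnected_output_grads = (False, True, False)  # out1's gradient is disconnected

connected_inner_outputs = [
    out
    for out, disconnected in zip(inner_outputs, disconnected_output_grads, strict=True)
    if not disconnected
]
connected_output_grads = [
    g
    for g, disconnected in zip(output_grads, disconnected_output_grads, strict=True)
    if not disconnected
]

# A known_grads-style pairing only mentions the connected outputs
known = dict(zip(connected_inner_outputs, connected_output_grads, strict=True))
assert known == {"out0": "g0", "out2": "g2"}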
@@ -544,32 +575,51 @@ def lop_overrides(inps, grads):
         else:
             input_grads = self._call_custom_override(lop_overrides, callable_args, nin)

-        # Filter out disconnected input and output gradients
+        # Filter out disconnected/null input gradients generated by the inner graph grad
+        # We append them back in the outer wrapper function below
         connected_input_grads = [
             inp_grad
             for inp_grad in input_grads
             if not isinstance(inp_grad.type, DisconnectedType | NullType)
         ]
         lop_op = type(self)(
-            inputs=inner_inputs + inner_outputs + output_grads,
+            inputs=inner_inputs + connected_inner_outputs + connected_output_grads,
             outputs=connected_input_grads,
             inline=self.is_inline,
             name=(None if self.name is None else f"{self.name}_LOp"),
             # TODO: We can be eager here and exclude unused inputs in the OFG
             on_unused_input="ignore",
         )

-        # Return a wrapper that combines connected and disconnected input gradients
+        # Return a wrapper that combines connected and disconnected/null input gradients
+        # and also filters out disconnected/null output gradients
         def wrapper(*inputs: Variable, **kwargs) -> list[Variable]:
-            connected_input_grads = iter(lop_op(*inputs, **kwargs))
+            inputs, outputs, output_grads = (
+                inputs[: -nout * 2],
+                inputs[-nout * 2 : -nout],
+                inputs[-nout:],
+            )
+            connected_outputs = [
+                output
+                for output, output_grad in zip(outputs, output_grads, strict=True)
+                if not isinstance(output_grad.type, DisconnectedType | NullType)
+            ]
+            connected_output_grads = [
+                output_grad
+                for output_grad in output_grads
+                if not isinstance(output_grad.type, DisconnectedType)
+            ]
+            connected_input_grads = iter(
+                lop_op(*inputs, *connected_outputs, *connected_output_grads, **kwargs)
+            )
             return [
                 input_grad
                 if isinstance(input_grad.type, DisconnectedType | NullType)
                 else next(connected_input_grads)
                 for input_grad in input_grads
             ]

-        self._lop_op_cache = wrapper
+        self._lop_op_cache[disconnected_output_grads] = wrapper
         return wrapper

     @config.change_flags(compute_test_value="off")
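
A standalone sketch of the two list tricks the wrapper above relies on (toy values, not PyTensor types): splitting the flat argument tuple back into inputs / outputs / output_grads by counting nout from the end, and re-inserting freshly computed connected gradients into their original positions with an iterator.

# The flat call signature is (*inputs, *outputs, *output_grads); recover the three groups
nout = 2
args = ("x0", "x1", "x2", "out0", "out1", "g_out0", "g_out1")
inputs, outputs, output_grads = (
    args[: -nout * 2],
    args[-nout * 2 : -nout],
    args[-nout:],
)
assert inputs == ("x0", "x1", "x2")
assert outputs == ("out0", "out1")
assert output_grads == ("g_out0", "g_out1")

# Merge step: a placeholder marks disconnected input gradients; all other slots
# consume, in order, the freshly computed connected gradients from the iterator
DISCONNECTED = object()
input_grads = ["computed", DISCONNECTED, "computed"]  # template from the inner grad
connected_input_grads = iter(["g_x0", "g_x2"])
merged = [
    g if g is DISCONNECTED else next(connected_input_grads)
    for g in input_grads
]
assert merged == ["g_x0", DISCONNECTED, "g_x2"]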
@@ -652,7 +702,10 @@ def wrapper(*inputs: Variable, **kwargs) -> list[Variable | None]:
         return wrapper

     def L_op(self, inputs, outputs, output_grads):
-        lop_op = self._build_and_cache_lop_op()
+        disconnected_output_grads = tuple(
+            isinstance(og.type, DisconnectedType) for og in output_grads
+        )
+        lop_op = self._build_and_cache_lop_op(disconnected_output_grads)
         return lop_op(*inputs, *outputs, *output_grads, return_list=True)

     def R_op(self, inputs, eval_points):