@@ -890,30 +890,9 @@ def connection_pattern(self, node):
890890 if self ._connection_pattern is not None :
891891 return self ._connection_pattern
892892
893- inp_len = len (self .inner_inputs )
894- out_len = len (self .inner_outputs )
895- cpmat_self = io_connection_pattern (self .inner_inputs , self .inner_outputs )
896-
897- lop_op = self .get_lop_op ()
898- cpmat_grad = io_connection_pattern (
899- lop_op .inner_inputs [inp_len :], lop_op .inner_outputs
900- )
901-
902- # cpmat_self |= cpmat_grad.T
903- # cpmat_self &= out_is_disconnected
904- for i , t in enumerate (self ._lop_op_stypes_l ):
905- if t is not None :
906- if isinstance (t .type , DisconnectedType ):
907- for o in range (out_len ):
908- cpmat_self [i ][o ] = False
909- for o in range (out_len ):
910- cpmat_self [i ][o ] |= cpmat_grad [o ][i ]
911-
912- # TODO in case DisconnectedType is implemented for R_op,
913- # self._rop_op_stypes_l self._rop_op should considered for
914- # connection_pattern
915-
916- return list (map (list , cpmat_self ))
893+ ret = io_connection_pattern (self .inner_inputs , self .inner_outputs )
894+ self ._connection_pattern = ret
895+ return ret
917896
918897 def infer_shape (self , fgraph , node , shapes ):
919898 # TODO: Use `fgraph.shape_feature` to do this instead.
0 commit comments