@@ -232,13 +232,13 @@ def remove_client(
232232 entry for `var` in `self.clients`.
233233
234234 """
235-
235+ clients = self . clients
236236 removal_stack = [(var , client_to_remove )]
237237 while removal_stack :
238238 var , client_to_remove = removal_stack .pop ()
239239
240240 try :
241- var_clients = self . clients [var ]
241+ var_clients = clients [var ]
242242 var_clients .remove (client_to_remove )
243243 except ValueError :
244244 # In this case, the original `var` could've been removed from
@@ -256,9 +256,7 @@ def remove_client(
256256 self .variables .remove (var )
257257 else :
258258 apply_node = var .owner
259- if not any (
260- output for output in apply_node .outputs if self .clients [output ]
261- ):
259+ if not any (clients [output ] for output in apply_node .outputs ):
262260 # The `Apply` node is not used and is not an output, so we
263261 # remove it and its outputs
264262 if not hasattr (apply_node .tag , "removed_by" ):
@@ -276,7 +274,7 @@ def remove_client(
276274 removal_stack .append ((in_var , (apply_node , i )))
277275
278276 if remove_if_empty :
279- del self . clients [var ]
277+ del clients [var ]
280278
281279 def import_var (
282280 self , var : Variable , reason : str | None = None , import_missing : bool = False
@@ -563,10 +561,11 @@ def remove_node(self, node: Apply, reason: str | None = None):
563561 node .tag .removed_by .append (str (reason ))
564562
565563 # Remove the outputs of the node (i.e. everything "below" it)
564+ clients = self .clients
566565 for out in node .outputs :
567566 self .variables .remove (out )
568567
569- out_clients = self . clients .get (out , ())
568+ out_clients = clients .get (out , ())
570569 while out_clients :
571570 out_client , out_idx = out_clients .pop ()
572571
@@ -590,13 +589,12 @@ def remove_node(self, node: Apply, reason: str | None = None):
590589 assert isinstance (out_client , Apply )
591590 self .remove_node (out_client , reason = reason )
592591
593- if out in self .clients :
594- del self .clients [out ]
592+ clients .pop (out , None )
595593
596594 # Remove all the arrows pointing to this `node`, and any orphaned
597595 # variables created by removing those arrows
598596 for inp_idx , inp in enumerate (node .inputs ):
599- inp_clients : list [ClientType ] = self . clients .get (inp , [])
597+ inp_clients : list [ClientType ] = clients .get (inp , [])
600598
601599 arrow = (node , inp_idx )
602600
@@ -810,12 +808,13 @@ def check_integrity(self) -> None:
810808 raise Exception (
811809 f"The following nodes are inappropriately cached:\n missing: { nodes_missing } \n in excess: { nodes_excess } "
812810 )
811+ clients = self .clients
813812 for node in nodes :
814813 for i , variable in enumerate (node .inputs ):
815- clients = self . clients [variable ]
816- if (node , i ) not in clients :
814+ var_clients = clients [variable ]
815+ if (node , i ) not in var_clients :
817816 raise Exception (
818- f"Inconsistent clients list { (node , i )} in { clients } "
817+ f"Inconsistent clients list { (node , i )} in { var_clients } "
819818 )
820819 variables = set (vars_between (self .inputs , self .outputs ))
821820 if set (self .variables ) != variables :
@@ -831,7 +830,7 @@ def check_integrity(self) -> None:
831830 and not isinstance (variable , AtomicVariable )
832831 ):
833832 raise Exception (f"Undeclared input: { variable } " )
834- for cl_node , i in self . clients [variable ]:
833+ for cl_node , i in clients [variable ]:
835834 if cl_node == "output" :
836835 if self .outputs [i ] is not variable :
837836 raise Exception (
0 commit comments