2424 vars_between ,
2525)
2626from pytensor .graph .utils import MetaObject , MissingInputError , TestValueError
27- from pytensor .misc .ordered_set import OrderedSet
2827
2928
3029ClientType = tuple [Apply , int ]
@@ -133,7 +132,6 @@ def __init__(
133132 features = []
134133
135134 self ._features : list [Feature ] = []
136-
137135 # All apply nodes in the subgraph defined by inputs and
138136 # outputs are cached in this field
139137 self .apply_nodes : set [Apply ] = set ()
@@ -161,7 +159,8 @@ def __init__(
161159 "input's owner or use graph.clone."
162160 )
163161
164- self .add_input (in_var , check = False )
162+ self .inputs .append (in_var )
163+ self .clients .setdefault (in_var , [])
165164
166165 for output in outputs :
167166 self .add_output (output , reason = "init" )
@@ -189,16 +188,6 @@ def add_input(self, var: Variable, check: bool = True) -> None:
189188 return
190189
191190 self .inputs .append (var )
192- self .setup_var (var )
193-
194- def setup_var (self , var : Variable ) -> None :
195- """Set up a variable so it belongs to this `FunctionGraph`.
196-
197- Parameters
198- ----------
199- var : pytensor.graph.basic.Variable
200-
201- """
202191 self .clients .setdefault (var , [])
203192
204193 def get_clients (self , var : Variable ) -> list [ClientType ]:
@@ -322,10 +311,11 @@ def import_var(
322311
323312 """
324313 # Imports the owners of the variables
325- if var .owner and var .owner not in self .apply_nodes :
326- self .import_node (var .owner , reason = reason , import_missing = import_missing )
314+ apply = var .owner
315+ if apply is not None and apply not in self .apply_nodes :
316+ self .import_node (apply , reason = reason , import_missing = import_missing )
327317 elif (
328- var . owner is None
318+ apply is None
329319 and not isinstance (var , AtomicVariable )
330320 and var not in self .inputs
331321 ):
@@ -336,10 +326,11 @@ def import_var(
336326 f"Computation graph contains a NaN. { var .type .why_null } "
337327 )
338328 if import_missing :
339- self .add_input (var )
329+ self .inputs .append (var )
330+ self .clients .setdefault (var , [])
340331 else :
341332 raise MissingInputError (f"Undeclared input: { var } " , variable = var )
342- self .setup_var (var )
333+ self .clients . setdefault (var , [] )
343334 self .variables .add (var )
344335
345336 def import_node (
@@ -356,29 +347,29 @@ def import_node(
356347 apply_node : Apply
357348 The node to be imported.
358349 check : bool
359- Check that the inputs for the imported nodes are also present in
360- the `FunctionGraph`.
350+ Check that the inputs for the imported nodes are also present in the `FunctionGraph`.
361351 reason : str
362352 The name of the optimization or operation in progress.
363353 import_missing : bool
364354 Add missing inputs instead of raising an exception.
365355 """
366356 # We import the nodes in topological order. We only are interested in
367- # new nodes, so we use all variables we know of as if they were the
368- # input set. (The functions in the graph module only use the input set
369- # to know where to stop going down.)
370- new_nodes = tuple ( toposort ( apply_node . outputs , blockers = self .variables ))
371-
372- if check :
373- for node in new_nodes :
357+ # new nodes, so we use all nodes we know of as inputs to interrupt the toposort
358+ self_variables = self . variables
359+ self_clients = self . clients
360+ self_apply_nodes = self .apply_nodes
361+ self_inputs = self . inputs
362+ for node in toposort ( apply_node . outputs , blockers = self_variables ) :
363+ if check :
374364 for var in node .inputs :
375365 if (
376366 var .owner is None
377367 and not isinstance (var , AtomicVariable )
378- and var not in self . inputs
368+ and var not in self_inputs
379369 ):
380370 if import_missing :
381- self .add_input (var )
371+ self_inputs .append (var )
372+ self_clients .setdefault (var , [])
382373 else :
383374 error_msg = (
384375 f"Input { node .inputs .index (var )} ({ var } )"
@@ -390,20 +381,20 @@ def import_node(
390381 )
391382 raise MissingInputError (error_msg , variable = var )
392383
393- for node in new_nodes :
394- assert node not in self . apply_nodes
395- self . apply_nodes . add ( node )
396- if not hasattr ( node . tag , " imported_by" ):
397- node . tag . imported_by = []
398- node . tag .imported_by .append (str (reason ))
384+ self_apply_nodes . add ( node )
385+ tag = node . tag
386+ if not hasattr ( tag , "imported_by" ):
387+ tag . imported_by = [ str ( reason )]
388+ else :
389+ tag .imported_by .append (str (reason ))
399390 for output in node .outputs :
400- self . setup_var (output )
401- self . variables .add (output )
402- for i , input in enumerate (node .inputs ):
403- if input not in self . variables :
404- self . setup_var ( input )
405- self . variables . add (input )
406- self . add_client ( input , (node , i ))
391+ self_clients . setdefault (output , [] )
392+ self_variables .add (output )
393+ for i , inp in enumerate (node .inputs ):
394+ if inp not in self_variables :
395+ self_clients . setdefault ( inp , [] )
396+ self_variables . add (inp )
397+ self_clients [ inp ]. append ( (node , i ))
407398 self .execute_callbacks ("on_import" , node , reason )
408399
409400 def change_node_input (
@@ -457,7 +448,7 @@ def change_node_input(
457448 self .outputs [node .op .idx ] = new_var
458449
459450 self .import_var (new_var , reason = reason , import_missing = import_missing )
460- self .add_client ( new_var , (node , i ))
451+ self .clients [ new_var ]. append ( (node , i ))
461452 self .remove_client (r , (node , i ), reason = reason )
462453 # Precondition: the substitution is semantically valid However it may
463454 # introduce cycles to the graph, in which case the transaction will be
@@ -756,10 +747,6 @@ def toposort(self) -> list[Apply]:
756747 :meth:`FunctionGraph.orderings`.
757748
758749 """
759- if len (self .apply_nodes ) < 2 :
760- # No sorting is necessary
761- return list (self .apply_nodes )
762-
763750 return list (toposort_with_orderings (self .outputs , orderings = self .orderings ()))
764751
765752 def orderings (self ) -> dict [Apply , list [Apply ]]:
@@ -779,29 +766,17 @@ def orderings(self) -> dict[Apply, list[Apply]]:
779766 take care of computing the dependencies by itself.
780767
781768 """
782- assert isinstance (self ._features , list )
783- all_orderings : list [dict ] = []
769+ all_orderings : list [dict ] = [
770+ orderings
771+ for feature in self ._features
772+ if (
773+ hasattr (feature , "orderings" ) and (orderings := feature .orderings (self ))
774+ )
775+ ]
784776
785- for feature in self ._features :
786- if hasattr (feature , "orderings" ):
787- orderings = feature .orderings (self )
788- if not isinstance (orderings , dict ):
789- raise TypeError (
790- "Non-deterministic return value from "
791- + str (feature .orderings )
792- + ". Nondeterministic object is "
793- + str (orderings )
794- )
795- if len (orderings ) > 0 :
796- all_orderings .append (orderings )
797- for node , prereqs in orderings .items ():
798- if not isinstance (prereqs , list | OrderedSet ):
799- raise TypeError (
800- "prereqs must be a type with a "
801- "deterministic iteration order, or toposort "
802- " will be non-deterministic."
803- )
804- if len (all_orderings ) == 1 :
777+ if not all_orderings :
778+ return {}
779+ elif len (all_orderings ) == 1 :
805780 # If there is only 1 ordering, we reuse it directly.
806781 return all_orderings [0 ].copy ()
807782 else :
0 commit comments