1919 clone_replace ,
2020 graph_inputs ,
2121 io_connection_pattern ,
22- replace_nominals_with_dummies ,
2322)
2423from pytensor .graph .fg import FunctionGraph
2524from pytensor .graph .null_type import NullType
@@ -333,52 +332,51 @@ def __init__(
333332 if not (isinstance (inputs , list ) and isinstance (outputs , list )):
334333 raise TypeError ("Inputs and outputs must be lists" )
335334
336- for i in inputs + outputs :
337- if not isinstance (i , Variable ):
335+ for out in outputs :
336+ if not isinstance (out , Variable ):
338337 raise TypeError (
339- f"Inputs and outputs must be Variable instances; got { i } "
338+ f"Inputs and outputs must be Variable instances; got { out } "
340339 )
341- if i in inputs :
342- if isinstance (i , Constant ):
343- raise TypeError (f"Constants not allowed as inputs; { i } " )
344- if isinstance (i , SharedVariable ):
345- raise TypeError (f"SharedVariables not allowed as inputs; { i } " )
340+
341+ dummy_inputs = []
342+ for n , inp in enumerate (inputs ):
343+ if (
344+ not isinstance (inp , Variable )
345+ or isinstance (inp , Constant )
346+ or isinstance (inp , SharedVariable )
347+ ):
348+ raise TypeError (
349+ f"Inputs and outputs must be non-Constant/shared Variable instances; got { inp } "
350+ )
351+
352+ dummy_inputs .append (inp .type ())
346353
347354 if "updates" in kwargs or "givens" in kwargs :
348355 raise NotImplementedError ("Updates and givens are not supported" )
349356
350357 self .is_inline = inline
351358
359+ dummy_shared_inputs = []
352360 self .shared_inputs = []
353- inner_graph_inputs = graph_inputs (outputs , inputs )
354- for var in inner_graph_inputs :
361+ for var in graph_inputs (outputs , inputs ):
355362 if isinstance (var , SharedVariable ):
356363 # To correctly support shared variables the inner-graph should
357364 # not see them; otherwise, there will be problems with
358365 # gradients.
359366 # That's why we collect the shared variables and replace them
360367 # with dummies.
361368 self .shared_inputs .append (var )
369+ dummy_shared_inputs .append (var .type ())
362370 elif var not in inputs and not isinstance (var , Constant ):
363371 raise MissingInputError (f"OpFromGraph is missing an input: { var } " )
364372
365- inputs , outputs = replace_nominals_with_dummies (inputs , outputs )
366-
367- # The inputs should be `NominalVariable`s, so that graphs can be merged
368- replacements = {}
369- for n , v in enumerate (inputs ):
370- replacements [v ] = NominalVariable (n , v .type )
371-
372- shared_vars = [
373- NominalVariable (n , var .type )
374- for n , var in enumerate (self .shared_inputs , start = len (inputs ) + 1 )
375- ]
376-
377- replacements .update (dict (zip (self .shared_inputs , shared_vars )))
373+ replacements = dict (
374+ zip (inputs + self .shared_inputs , dummy_inputs + dummy_shared_inputs )
375+ )
378376
379377 new = rebuild_collect_shared (
380378 cast (Sequence [Variable ], outputs ),
381- inputs = inputs + shared_vars ,
379+ inputs = inputs + self . shared_inputs ,
382380 replace = replacements ,
383381 copy_inputs_over = False ,
384382 )
@@ -395,6 +393,21 @@ def __init__(
395393 assert not shared_inputs
396394
397395 self .fgraph = FunctionGraph (local_inputs , local_outputs , clone = False )
396+
397+ # The inputs need to be `NominalVariable`s so that we can merge
398+ # inner-graphs
399+ nominal_local_inputs = tuple (
400+ NominalVariable (n , var .type ) for n , var in enumerate (local_inputs )
401+ )
402+
403+ self .fgraph .replace_all (zip (local_inputs , nominal_local_inputs ))
404+
405+ for i , inp in enumerate (self .fgraph .inputs ):
406+ nom_inp = nominal_local_inputs [i ]
407+ self .fgraph .inputs [i ] = nom_inp
408+ self .fgraph .clients .pop (inp , None )
409+ self .fgraph .add_input (nom_inp )
410+
398411 self .kwargs = kwargs
399412 self .input_types = [inp .type for inp in inputs ]
400413 self .output_types = [out .type for out in outputs ]
@@ -417,6 +430,7 @@ def __init__(
417430 else :
418431 self .set_lop_overrides ("default" )
419432 self ._lop_type = "lop"
433+
420434 self .set_rop_overrides (rop_overrides )
421435
422436 self ._connection_pattern = connection_pattern
0 commit comments