1010
1111from pytensor .configdefaults import config
1212from pytensor .graph .basic import Apply , Variable
13+ from pytensor .graph .fg import FunctionGraph , Output
1314from pytensor .graph .op import ComputeMapType , Op , StorageMapType , ThunkType
1415from pytensor .graph .type import HasDataType
1516from pytensor .graph .utils import MethodNotDefined
@@ -32,6 +33,30 @@ def is_cthunk_wrapper_type(thunk: Callable[[], None]) -> CThunkWrapperType:
3233 return res
3334
3435
36+ class SingleOpFunctionGraph (FunctionGraph ):
37+ """A `FunctionGraph` with a single `Apply` node.
38+
39+ This is used to compile a single `Apply` node with the C linker.
40+
41+ """
42+
43+ def __init__ (self , node : Apply , clone : bool = True ):
44+ if clone :
45+ node = node .clone_with_new_inputs ([i .clone () for i in node .inputs ])
46+ self .node = node
47+ self .apply_nodes = {node }
48+ self .inputs = inputs = node .inputs
49+ self .outputs = outputs = node .outputs
50+ self .variables = set (inputs ) | set (outputs )
51+ self .clients = {inp : [(node , idx )] for idx , inp in enumerate (inputs )}
52+ self .clients |= {
53+ out : [(Output (idx ).make_node (out ), 0 )] for idx , out in enumerate (outputs )
54+ }
55+
56+ def toposort (self ):
57+ return [self .node ]
58+
59+
3560class COp (Op , CLinkerOp ):
3661 """An `Op` with a C implementation."""
3762
@@ -51,12 +76,11 @@ def make_c_thunk(
5176 # The conclusion should be that the antire "make_c_thunk" method should be defined
5277 # in pytensor.link.c and dispatched onto the Op!
5378 import pytensor .link .c .basic
54- from pytensor .graph .fg import FunctionGraph
5579
5680 node_input_storage = [storage_map [r ] for r in node .inputs ]
5781 node_output_storage = [storage_map [r ] for r in node .outputs ]
5882
59- e = FunctionGraph (node . inputs , node . outputs )
83+ e = SingleOpFunctionGraph (node )
6084 e_no_recycling = [
6185 new_o
6286 for (new_o , old_o ) in zip (e .outputs , node .outputs , strict = True )
0 commit comments