Skip to content

Commit 538a5e7

Browse files
committed
Avoid FunctionGraph overhead when compiling single Ops to C
1 parent 6a99dca commit 538a5e7

File tree

1 file changed

+26
-2
lines changed

1 file changed

+26
-2
lines changed

pytensor/link/c/op.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from pytensor.configdefaults import config
1212
from pytensor.graph.basic import Apply, Variable
13+
from pytensor.graph.fg import FunctionGraph, Output
1314
from pytensor.graph.op import ComputeMapType, Op, StorageMapType, ThunkType
1415
from pytensor.graph.type import HasDataType
1516
from pytensor.graph.utils import MethodNotDefined
@@ -32,6 +33,30 @@ def is_cthunk_wrapper_type(thunk: Callable[[], None]) -> CThunkWrapperType:
3233
return res
3334

3435

36+
class SingleOpFunctionGraph(FunctionGraph):
37+
"""A `FunctionGraph` with a single `Apply` node.
38+
39+
This is used to compile a single `Apply` node with the C linker.
40+
41+
"""
42+
43+
def __init__(self, node: Apply, clone: bool = True):
44+
if clone:
45+
node = node.clone_with_new_inputs([i.clone() for i in node.inputs])
46+
self.node = node
47+
self.apply_nodes = {node}
48+
self.inputs = inputs = node.inputs
49+
self.outputs = outputs = node.outputs
50+
self.variables = set(inputs) | set(outputs)
51+
self.clients = {inp: [(node, idx)] for idx, inp in enumerate(inputs)}
52+
self.clients |= {
53+
out: [(Output(idx).make_node(out), 0)] for idx, out in enumerate(outputs)
54+
}
55+
56+
def toposort(self):
57+
return [self.node]
58+
59+
3560
class COp(Op, CLinkerOp):
3661
"""An `Op` with a C implementation."""
3762

@@ -51,12 +76,11 @@ def make_c_thunk(
5176
# The conclusion should be that the antire "make_c_thunk" method should be defined
5277
# in pytensor.link.c and dispatched onto the Op!
5378
import pytensor.link.c.basic
54-
from pytensor.graph.fg import FunctionGraph
5579

5680
node_input_storage = [storage_map[r] for r in node.inputs]
5781
node_output_storage = [storage_map[r] for r in node.outputs]
5882

59-
e = FunctionGraph(node.inputs, node.outputs)
83+
e = SingleOpFunctionGraph(node)
6084
e_no_recycling = [
6185
new_o
6286
for (new_o, old_o) in zip(e.outputs, node.outputs, strict=True)

0 commit comments

Comments
 (0)