Skip to content

Commit 820d99d

Browse files
committed
Speedup supports c_code
Not using `__call__` avoids the test_value computation
1 parent d3c5133 commit 820d99d

File tree

1 file changed

+13
-19
lines changed

1 file changed

+13
-19
lines changed

pytensor/scalar/basic.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,32 +1332,26 @@ def supports_c_code(self, inputs, outputs):
13321332
the given Elemwise inputs, outputs.
13331333
13341334
"""
1335-
try:
1336-
tmp_s_input = []
1337-
# To keep the same aliasing between inputs
1338-
mapping = dict()
1339-
for ii in inputs:
1340-
if ii in mapping:
1341-
tmp_s_input.append(mapping[ii])
1342-
else:
1343-
tmp = get_scalar_type(ii.dtype).make_variable()
1344-
tmp_s_input.append(tmp)
1345-
mapping[ii] = tmp_s_input[-1]
1346-
1347-
with config.change_flags(compute_test_value="ignore"):
1348-
s_op = self(*tmp_s_input, return_list=True)
1335+
tmp_s_input = []
1336+
# To keep the same aliasing between inputs
1337+
mapping = {}
1338+
for ii in inputs:
1339+
if ii in mapping:
1340+
tmp_s_input.append(mapping[ii])
1341+
else:
1342+
tmp = mapping[ii] = get_scalar_type(ii.dtype).make_variable()
1343+
tmp_s_input.append(tmp)
13491344

1350-
# if the scalar_op don't have a c implementation,
1351-
# we skip its fusion to allow the fusion of the
1352-
# other ops.
1345+
try:
13531346
self.c_code(
1354-
s_op[0].owner,
1347+
self.make_node(*tmp_s_input),
13551348
"test_presence_of_c_code",
1349+
# FIXME: Shouldn't this be a unique name per unique variable?
13561350
["x" for x in inputs],
13571351
["z" for z in outputs],
13581352
{"fail": "%(fail)s"},
13591353
)
1360-
except (MethodNotDefined, NotImplementedError):
1354+
except (NotImplementedError, MethodNotDefined):
13611355
return False
13621356
return True
13631357

0 commit comments

Comments
 (0)