Skip to content

Commit ed83911

Browse files
committed
Speedup supports c_code
Not using `__call__` avoids the test_value computation
1 parent 3048820 commit ed83911

File tree

1 file changed

+13
-19
lines changed

1 file changed

+13
-19
lines changed

pytensor/scalar/basic.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1333,32 +1333,26 @@ def supports_c_code(self, inputs, outputs):
13331333
the given Elemwise inputs, outputs.
13341334
13351335
"""
1336-
try:
1337-
tmp_s_input = []
1338-
# To keep the same aliasing between inputs
1339-
mapping = dict()
1340-
for ii in inputs:
1341-
if ii in mapping:
1342-
tmp_s_input.append(mapping[ii])
1343-
else:
1344-
tmp = get_scalar_type(ii.dtype).make_variable()
1345-
tmp_s_input.append(tmp)
1346-
mapping[ii] = tmp_s_input[-1]
1347-
1348-
with config.change_flags(compute_test_value="ignore"):
1349-
s_op = self(*tmp_s_input, return_list=True)
1336+
tmp_s_input = []
1337+
# To keep the same aliasing between inputs
1338+
mapping = {}
1339+
for ii in inputs:
1340+
if ii in mapping:
1341+
tmp_s_input.append(mapping[ii])
1342+
else:
1343+
tmp = mapping[ii] = get_scalar_type(ii.dtype).make_variable()
1344+
tmp_s_input.append(tmp)
13501345

1351-
# if the scalar_op don't have a c implementation,
1352-
# we skip its fusion to allow the fusion of the
1353-
# other ops.
1346+
try:
13541347
self.c_code(
1355-
s_op[0].owner,
1348+
self.make_node(*tmp_s_input),
13561349
"test_presence_of_c_code",
1350+
# FIXME: Shouldn't this be a unique name per unique variable?
13571351
["x" for x in inputs],
13581352
["z" for z in outputs],
13591353
{"fail": "%(fail)s"},
13601354
)
1361-
except (MethodNotDefined, NotImplementedError):
1355+
except (NotImplementedError, MethodNotDefined):
13621356
return False
13631357
return True
13641358

0 commit comments

Comments
 (0)