Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 44 additions & 40 deletions pytensor/scalar/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import builtins
import math
from collections.abc import Callable
from copy import copy
from itertools import chain
from textwrap import dedent
from typing import Any, TypeAlias
Expand Down Expand Up @@ -779,9 +778,11 @@ def get_scalar_type(dtype, cache: dict[str, ScalarType] = {}) -> ScalarType:
This caches objects to save allocation and run time.

"""
if dtype not in cache:
cache[dtype] = ScalarType(dtype=dtype)
return cache[dtype]
try:
return cache[dtype]
except KeyError:
cache[dtype] = res = ScalarType(dtype=dtype)
return res


# Register C code for ViewOp on Scalars.
Expand Down Expand Up @@ -987,25 +988,28 @@ def constant(x, name=None, dtype=None) -> ScalarConstant:


def as_scalar(x: Any, name: str | None = None) -> ScalarVariable:
from pytensor.tensor.basic import scalar_from_tensor
from pytensor.tensor.type import TensorType
if isinstance(x, ScalarVariable):
return x

if isinstance(x, Variable):
from pytensor.tensor.basic import scalar_from_tensor
from pytensor.tensor.type import TensorType

if isinstance(x.type, TensorType) and x.type.ndim == 0:
return scalar_from_tensor(x)
else:
raise TypeError(f"Cannot convert {x} to a scalar type")

if isinstance(x, Apply):
# FIXME: Why do we support calling this with Apply?
# Also, if we do, why can't we support multiple outputs?
if len(x.outputs) != 1:
raise ValueError(
"It is ambiguous which output of a multi-output"
" Op has to be fetched.",
x,
)
else:
x = x.outputs[0]
if isinstance(x, Variable):
if isinstance(x, ScalarVariable):
return x
elif isinstance(x.type, TensorType) and x.type.ndim == 0:
return scalar_from_tensor(x)
else:
raise TypeError(f"Cannot convert {x} to a scalar type")
return as_scalar(x.outputs[0])

return constant(x)

Expand Down Expand Up @@ -1329,32 +1333,26 @@ def supports_c_code(self, inputs, outputs):
the given Elemwise inputs, outputs.

"""
try:
tmp_s_input = []
# To keep the same aliasing between inputs
mapping = dict()
for ii in inputs:
if ii in mapping:
tmp_s_input.append(mapping[ii])
else:
tmp = get_scalar_type(ii.dtype).make_variable()
tmp_s_input.append(tmp)
mapping[ii] = tmp_s_input[-1]

with config.change_flags(compute_test_value="ignore"):
s_op = self(*tmp_s_input, return_list=True)
tmp_s_input = []
# To keep the same aliasing between inputs
mapping = {}
for ii in inputs:
if ii in mapping:
tmp_s_input.append(mapping[ii])
else:
tmp = mapping[ii] = get_scalar_type(ii.dtype).make_variable()
tmp_s_input.append(tmp)

# if the scalar_op don't have a c implementation,
# we skip its fusion to allow the fusion of the
# other ops.
try:
self.c_code(
s_op[0].owner,
self.make_node(*tmp_s_input),
"test_presence_of_c_code",
# FIXME: Shouldn't this be a unique name per unique variable?
["x" for x in inputs],
["z" for z in outputs],
{"fail": "%(fail)s"},
)
except (MethodNotDefined, NotImplementedError):
except (NotImplementedError, MethodNotDefined):
return False
return True

Expand Down Expand Up @@ -4094,12 +4092,12 @@ def __init__(self, *args, **kwargs):
self.prepare_node_called = set()
super().__init__(*args, **kwargs)

def _cleanup_graph(self, inputs, outputs):
def _cleanup_graph(self, inputs, outputs, clone: builtins.bool = True):
# TODO: We could convert to TensorVariable, optimize graph,
# and then convert back to ScalarVariable.
# This would introduce rewrites like `log(1 + x) -> log1p`.

fgraph = FunctionGraph(copy(inputs), copy(outputs))
fgraph = FunctionGraph(inputs, outputs, clone=clone)

# Validate node types
for node in fgraph.apply_nodes:
Expand Down Expand Up @@ -4282,7 +4280,9 @@ class Composite(ScalarInnerGraphOp):

init_param: tuple[str, ...] = ("inputs", "outputs")

def __init__(self, inputs, outputs, name="Composite"):
def __init__(
self, inputs, outputs, name="Composite", clone_graph: builtins.bool = True
):
self.name = name
self._name = None
# We need to clone the graph as sometimes its nodes already
Expand All @@ -4300,10 +4300,13 @@ def __init__(self, inputs, outputs, name="Composite"):
if len(outputs) > 1 or not any(
isinstance(var.owner.op, Composite) for var in outputs
):
# No inner Composite
inputs, outputs = clone(inputs, outputs)
if clone_graph:
inputs, outputs = clone(inputs, outputs)

else:
# Inner Composite that we need to flatten
# FIXME: There could be a composite in the middle of the graph, why is this here?
# If anything it should be an optimization, but I suspect lower-level compilation can handle this anyway.
assert len(outputs) == 1
# 1. Create a new graph from inputs up to the
# Composite
Expand All @@ -4322,7 +4325,8 @@ def __init__(self, inputs, outputs, name="Composite"):
assert res[0] != inputs
inputs, outputs = res[0], res2[1]

self.inputs, self.outputs = self._cleanup_graph(inputs, outputs)
# We already cloned the graph, or the user told us there was no need for it
self.inputs, self.outputs = self._cleanup_graph(inputs, outputs, clone=False)
self.inputs_type = tuple(input.type for input in self.inputs)
self.outputs_type = tuple(output.type for output in self.outputs)
self.nin = len(inputs)
Expand Down
Loading
Loading