diff --git a/onnxscript/converter.py b/onnxscript/_converter.py similarity index 65% rename from onnxscript/converter.py rename to onnxscript/_converter.py index dfcddefbd3..dbce5a1c01 100644 --- a/onnxscript/converter.py +++ b/onnxscript/_converter.py @@ -1,41 +1,45 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +"""Python-to-IR converter""" + from __future__ import annotations import ast +from collections import defaultdict +import dataclasses import logging from typing import ( TYPE_CHECKING, Any, Dict, List, + Mapping, NoReturn, Optional, Sequence, Tuple, Union, + _GenericAlias ) -import onnx +import onnx_ir as ir +from onnxscript.ir import _schemas import onnxscript -from onnxscript import irbuilder, onnx_types, sourceinfo, values +from onnxscript import onnx_types, sourceinfo, values from onnxscript import type_annotation as ta -from onnxscript._internal import analysis, ast_utils, autocast, param_manipulation - -logger = logging.getLogger("onnxscript") +from onnxscript._internal import _analysis, ast_utils, autocast -# Python-to-IR converter: +logger = logging.getLogger(__name__) def not_allowed(construct): return f"{construct}not supported." -class TranslationError(Exception): - def __init__(self, *args: object) -> None: - super().__init__(*args) +class TranslationError(RuntimeError): + pass def warn(msg): @@ -57,7 +61,7 @@ def ignore(cond, msg): # map from python operators to ONNX ops -primop_map = { +_PRIMOP_MAP = { ast.Add: "Add", ast.And: "And", ast.BitAnd: "And", @@ -80,160 +84,205 @@ def ignore(cond, msg): } -class Variable: - """Represents an ONNX variable. +_CASTABLE_FIELD = "pkg.onnxscript.converter.castable" +_SOURCEINFO_FIELD = "pkg.onnxscript.sourceinfo" - TODO(rama): Consider merging this with IRVar. However, "castable" is specific to this - converter. - """ - def __init__(self, name: str, castable: bool = False): - """Initialize the instance. - Args: - name: Name of the ONNX variable - castable: Whether this variable is castable to a desired target type. - Used for ONNX variables representing constants created from python values - like 0 or 1 or 0.5 which are treated as polymorphic values castable to other - types as needed. - """ - self.name = name - self.is_castable = castable +class DynamicKind(IntFlag): + Unknown = 0 + Input = 1 + Output = 2 + Intermediate = 4 + Loop = 8 + +# The type-alias LocalSymValue represents the types of values that local names in a +# script-function may be bound to during translation, (ONNX IR values). +# TODO(rama): Rationalize this and values.SymbolValue + +LocalSymValue = Union[ir.Value, ir.Attr, ir.Function] - def __str__(self) -> str: - return self.name +# The type-alias PyValue is used to represent the types of python values that may be used +# in an ONNX Script function. +# TODO(rama): Flesh out the set of valid types here. These include values such as +# 1 (int), 1.0 (float), [2, 4], [1.0], etc. which will be converted to ONNX, for +# use as value-parameters or attribute-parameters in an ONNX call (Node). +PyValue = Union[int, float, str, bool, Sequence[int], Sequence[float], Sequence[str], Sequence[bool]] -if TYPE_CHECKING: - # The type-alias LocalSymValue represents the types of values that local names in a - # script-function may be bound to during translation, (ONNX IR values). - # TODO(rama): Rationalize this and values.SymbolValue +# The type-alias SymValue denotes values that an identifier may be bound to during +# translation. 
A local name will be bound to a LocalSymValue, while a global name +# will be bound to a PyValue. - LocalSymValue = Union[values.SymbolValue, irbuilder.IRFunction] +SymValue = Union[LocalSymValue, PyValue] - # The type-alias PyValue is used to represent the types of python values that may be used - # in an ONNX Script function. - # TODO(rama): Flesh out the set of valid types here. These include values such as - # 1 (int), 1.0 (float), [2, 4], [1.0], etc. which will be converted to ONNX, for - # use as value-parameters or attribute-parameters in an ONNX call (Node). +# PreferredName is a type-alias used to represent the preferred name used in the generated +# ONNX for a value returned by an expression. There is no guarantee that the specified +# name will be used exactly. The converter will modify the name (with a suffix), +# if necesssary, to ensure that it is unique (to ensure ONNX's SSA requirement). - PyValue = Any +PreferredName = str - # The type-alias SymValue denotes values that an identifier may be bound to during - # translation. A local name will be bound to a LocalSymValue, while a global name - # will be bound to a PyValue. +# The type-alias OnnxVar indicates variable names used in the generated ONNX. +OnnxVarName = str - SymValue = Union[LocalSymValue, PyValue] - # PreferredName is a type-alias used to represent the preferred name used in the generated - # ONNX for a value returned by an expression. There is no guarantee that the specified - # name will be used exactly. The converter will modify the name (with a suffix), - # if necesssary, to ensure that it is unique (to ensure ONNX's SSA requirement). +def mark_castable(value: ir.Value): + """Mark an ONNX value as auto-castable.""" + value.meta[_CASTABLE_FIELD] = True - PreferredName = str +def set_sourceinfo(value: ir.Value, info: sourceinfo.SourceInfo): + """Set the source information for an ONNX value.""" + value.meta[_SOURCEINFO_FIELD] = info - # The type-alias OnnxVar indicates variable names used in the generated ONNX. - OnnxVarName = str + +def is_base_type_bool(attr: ir.Attr) -> bool: + """Check if the attribute is a boolean type.""" + # FIXME: Add meta to attributes + attr.meta[_SOURCEINFO_FIELD] + +@dataclasses.dataclass +class ASTMeta: + """Metadata for an AST node. + + This class is used to store metadata about an AST node. + """ + + # For liveness analysis, + live_out: set[str] | None = None + live_in: set[str] | None = None + + +class _ValueEnvironment: + def __init__(self, converter: Converter): + self._py_var_name_to_ir_values: dict[str, ir.Value] = {} + self._py_var_name_to_ir_attr_refs: dict[str, ir.Attr] = {} + self._py_var_name_to_py_values: dict[str, PyValue] = {} + self._converter = converter + + def get_or_create_value( + self, var: str, info: sourceinfo.SourceInfo + ) -> ir.Value: + """Get or create an IR value from Python variable name.""" + if var in self._py_var_name_to_ir_values: + return self._py_var_name_to_ir_values[var] + if var in self._py_var_name_to_ir_attr_refs: + # promote attribute to value + attr = self._py_var_name_to_ir_attr_refs[var] + result = self._converter.op( + "Constant", [], attrs=[attr] + ) + if is_base_type_bool(attr): + # ONNX attributes use an int-encoding for bools, but ONNX tensor types + # distinguish between int and bool. So we cast the int tensor to a bool tensor, + # to promote a (python) bool attribute to a ONNX bool tensor. 
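# A rough sketch, not code from this converter: expressed directly with the onnx_ir
# helpers this file already imports (`import onnx_ir as ir`), the Constant-then-Cast
# promotion amounts to the following two nodes, where `int_attr` stands in for the
# promoted attribute:
#
#     int_attr = ir.AttrInt64("value_int", 1)
#     const_node = ir.Node("", "Constant", inputs=[],
#                          attributes=[int_attr], num_outputs=1)
#     cast_node = ir.Node("", "Cast", inputs=[const_node.outputs[0]],
#                         attributes=[ir.AttrInt64("to", ir.DataType.BOOL)],
#                         num_outputs=1)
#     bool_value = cast_node.outputs[0]  # INT64 scalar re-typed as BOOL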
+ result = self._converter.op( + "Cast", + [result], + attrs=[ir.AttrInt64("to", ir.DataType.BOOL)], + ) + + self._py_var_name_to_ir_values[var] = result + return result + if var in self._py_var_name_to_py_values: + # Assume value is a python-value convertible to a tensor + result = self._converter.op( + "Constant", [], attrs=[ir.AttrTensor("value", ir.tensor(var, name=var))] + ) + mark_castable(result) + self._py_var_name_to_ir_values[var] = result + + # TODO(justinchuby): Update error message + raise ValueError(f"Variable '{var}' is unbound.") class Converter: """Main class to translate python code into ONNX operators. - Args: - ir_builder: convert AST node into ONNX structures, if None, - class :class:`onnxscript.irbuilder.IRBuilder` is used + The converter translates a Python function into an ONNX function by + traversing the Python AST of the function and generating ONNX nodes + that represent the operations in the Python code. + + ..tip:: - The class uses logger `onnxscript`. Logging can be enabled with the following code: + The class uses logger `onnxscript`. Logging can be enabled with the following code: - :: + :: - import logging - logging.basicConfig(level=logging.DEBUG) + import logging + logging.basicConfig(level=logging.DEBUG) - Or if you need to enable only the logger used by this module: + Or if you need to enable only the logger used by this module: - :: + :: - import logging - logger = logging.getLogger('onnxscript') - logger.setLevel(logging.DEBUG) - console = logging.StreamHandler() - logger.addHandler(console) + import logging + logger = logging.getLogger('onnxscript') + logger.setLevel(logging.DEBUG) + console = logging.StreamHandler() + logger.addHandler(console) """ def __init__( self, - ir_builder: Optional[irbuilder.IRBuilder] = None, + root: ast.FunctionDef, + *, opset: Optional[values.Opset] = None, global_names: Optional[dict[str, Any]] = None, source: Optional[str] = None, default_opset: Optional[values.Opset] = None, ): - self.ir_builder = ir_builder or irbuilder.IRBuilder() - self.source = source + """Initialize the converter. + + Args: + root: The root AST node of the function to be converted. + opset: The ONNX opset to use for the conversion. If None, the default opset is used. + global_names: A dictionary of global names available in the script. + source: Optional source code string for error reporting. + default_opset: The default ONNX opset to use if no ONNX opset is specified in the script. + """ + if not isinstance(root, ast.FunctionDef): + raise TypeError(f"Converter expects an AST FunctionDef node, got {type(root)}.") + self._ast_root = root + self._opset = opset + if global_names is not None: # We make a copy in case function eval modifies it. - self.globals = global_names.copy() - self.this_module = opset - self.default_opset_ = default_opset - - # States initialized by `_init_function_translation` - self._outer: List[irbuilder.IRFunction] = [] - self._current_fn: irbuilder.IRFunction = None - self._nextvar: int = 0 - self._used_vars: set[str] = set() - self._locals: List[Dict[str, LocalSymValue]] = [{}] + self._globals = global_names.copy() + else: + self._globals = {} - @property - def default_opset(self) -> values.Opset: - if self.default_opset_ is None: - raise RuntimeError( + self._source = source + self._default_opset = default_opset or _find_onnx_opset(root, self._globals) + if self._default_opset is None: + raise ValueError( "default_opset must be specified in script for functions " "that do not contain any use of an ONNX opset." 
) - return self.default_opset_ - - def _set_default_opset(self, opset: values.Opset, node: ast.AST) -> None: - if opset.domain != "": - return - if self.default_opset_ is not None: - if ( - opset.domain != self.default_opset_.domain - or opset.version != self.default_opset_.version - ): - self.fail( - node, f"Two distincts opset were used ({opset} != {self.default_opset_})." - ) - else: - self.default_opset_ = opset - - def _find_onnx_opset(self, node: ast.AST) -> Optional[values.Opset]: - """Find the (first) ONNX opset used in the function, if any.""" - # Search for a Call expression of form "op.OpName(...)" - if isinstance(node, ast.Call): - if isinstance(node.func, ast.Attribute): - opset_expr = node.func.value - if isinstance(opset_expr, ast.Name): - if opset_expr.id in self.globals: - opset = self.globals[opset_expr.id] - if isinstance(opset, values.Opset) and opset.domain == "": - return opset - for child in ast.iter_child_nodes(node): - res = self._find_onnx_opset(child) - if res is not None: - return res - return None - def _init_function_translation(self) -> None: - """Initialize self for translating a new (top-level) function.""" - self._outer = [] - self._current_fn: Optional[irbuilder.IRFunction] = None - self._nextvar = 0 - self._used_vars = set() - self._locals: List[Dict[str, LocalSymValue]] = [{}] + # TODO(justinchuby): Update ir version to be user defined + # TODO(justinchuby): Maybe just store a list of functions + self._model = ir.Model(ir.Graph((), (), nodes=()), ir_version=10) + + # A stack of functions in the outer scope + self._outer: list[ir.Function] = [] + self._current_fn: ir.Function = ir.Function( + domain=self._opset.domain, + name="", + graph=ir.Graph((), (), nodes=[]), + attributes={}, + ) + # A mapping from value names to the values for each function + # self._scoped_values: dict[ir.Function, dict[str, ir.Value]] = {} + self._nextvar: int = 0 + self._used_vars: set[str] = set() + self._locals: list[dict[str, LocalSymValue]] = [{}] + self._finalized = False + self._value_env = _ValueEnvironment(self) + self.meta: defaultdict[ast.AST, ASTMeta] = defaultdict(ASTMeta) def _source_of(self, node: ast.AST) -> sourceinfo.SourceInfo: - return sourceinfo.SourceInfo(node, self.source, self._current_fn.name) + return sourceinfo.SourceInfo(node, self._source, self._current_fn.name) def _message(self, node: ast.AST, error_msg: str) -> str: """Constructs an error _message containing source information about an ast node.""" @@ -256,39 +305,46 @@ def _enter_scope(self, name: str, parent_node: ast.AST): """Enter a control-flow block (a loop body or if-then-else branch). The block is translated into a nested-scope in ONNX. 
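For example, the body of an `if` statement becomes the `then_branch`/`else_branch` graph attribute of an ONNX `If` node, and a loop body becomes the `body` graph of a `Loop` node.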
""" - self._outer.insert(0, self._current_fn) - self._current_fn = self.ir_builder.new_function(name) - self._locals.insert(0, {}) + self._outer.append(self._current_fn) + assert self._opset is not None + self._current_fn = ir.Function( + domain=self._opset.domain, + name=name, + graph=ir.Graph((), (), nodes=[]), + attributes={}, + ) + self._locals.append({}) logger.debug("Converter:_enter_scope:%d:node:%s", len(self._locals), type(parent_node)) - def _exit_scope(self) -> irbuilder.IRFunction: + def _exit_scope(self) -> ir.Function: """Exit from a control-flow block (a loop body or if-then-else branch).""" logger.debug("Converter:_exit_scope:%d", len(self._locals)) graph = self._current_fn - self._current_fn = self._outer.pop(0) - self._locals.pop(0) + self._current_fn = self._outer.pop() + self._locals.pop() + assert graph is not None return graph def _current_scope(self) -> Dict[str, LocalSymValue]: - return self._locals[0] + return self._locals[-1] def _bind(self, name: str, val: LocalSymValue) -> None: logger.debug("Converter:_bind:%s", name) - self._locals[0][name] = val + self._locals[-1][name] = val def _lookup( self, name: str, info: sourceinfo.SourceInfo, raise_exception: bool = True ) -> SymValue: - for scope in self._locals: + for scope in reversed(self._locals): if name in scope: return scope[name] - if name in self.globals: - return self.globals[name] + if name in self._globals: + return self._globals[name] if raise_exception: raise ValueError(info.msg(f"Unbound name: {name}.")) return None - def generate_unique_name(self, candidate: str = "tmp") -> str: + def _generate_unique_name(self, candidate: str = "tmp") -> str: # TODO(justinchuby): Can we reduce the O complexity of this function? r = candidate while r in self._used_vars: @@ -297,149 +353,63 @@ def generate_unique_name(self, candidate: str = "tmp") -> str: self._used_vars.add(r) return r - def _make_onnx_attr( - self, attrname: str, attrval: Any, attrtype: int | None = None - ) -> irbuilder.IRAttributeValue: - def tensor_name_generator() -> str: - """Return name to be used for tensor, if we need to create one.""" - return self.generate_unique_name(f"attr_{attrname}") - - proto = autocast.pyvalue_to_onnx_attribute( - attrname, attrval, tensor_name_generator, attrtype - ) - return self.ir_builder.make_attr(proto) - - def _to_onnx_attr_ref( - self, val: values.AttrRef, info: Optional[sourceinfo.SourceInfo] - ) -> irbuilder.IRAttributeValue: - pytype = val.typeinfo - attrtype = ta.pytype_to_attrtype(pytype) - attrname = None - if attrtype is onnx.AttributeProto.FLOAT: - attrname = "value_float" - elif attrtype is onnx.AttributeProto.INT: - attrname = "value_int" - elif attrtype is onnx.AttributeProto.STRING: - attrname = "value_string" - elif attrtype is onnx.AttributeProto.INTS: - attrname = "value_ints" - else: - msg = f"Unsupported attribute type {pytype!r}." 
- fail(info.msg(msg) if info else msg) - return self.ir_builder.make_attr_ref(attrname, val.value, pytype) - - def _to_onnx_var( - self, - val: values.SymbolValue | PyValue, - target: Optional[PreferredName] = None, - info: Optional[sourceinfo.SourceInfo] = None, - ) -> Variable: - if isinstance(val, values.AttrRef): - # promote attribute to value - result = self.generate_unique_name(target or "tmp") - attr = self._to_onnx_attr_ref(val, info) - self.emit([result], values.Op(self.default_opset, "Constant"), [], [attr]) - if ta.base_type_is_bool(val.typeinfo): - # ONNX attributes use an int-encoding for bools, but ONNX tensor types - # distinguish between int and bool. So we cast the int tensor to a bool tensor, - # to promote a (python) bool attribute to a ONNX bool tensor. - result_as_bool = self.generate_unique_name(result + "_as_bool") - cast_attr = self._make_onnx_attr("to", onnx_types.BOOL.dtype) - self.emit( - [result_as_bool], - values.Op(self.default_opset, "Cast"), - [result], - [cast_attr], - ) - return Variable(result_as_bool, True) - return Variable(result, True) - if isinstance(val, values.Dynamic): - return Variable(val.value) - # Assume value is a python-value convertible to a tensor - # TODO: check if value is convertible to a TensorProto, so that we can - # produce a better error _message otherwise - return self._emit_const(val, target or "tmp", info) - def _py_var_to_onnx_var(self, py_var: str, info: sourceinfo.SourceInfo) -> Variable: + """Convert a python variable to an ONNX variable.""" return self._to_onnx_var(self._lookup(py_var, info), target=py_var, info=info) def emit( self, outputs: Sequence[str], - callee: values.Op | str, - inputs: Sequence[Optional[str]], - attrs: Optional[Sequence[irbuilder.IRAttributeValue]] = None, - sub_functions: Optional[dict[str, onnx.FunctionProto]] = None, - ): - if not isinstance(callee, values.Op): - callee = values.Op(self.default_opset, callee) - if attrs is None: - attrs = [] - if sub_functions is None: - sub_functions = {} - self.ir_builder.add_stmt( - self._current_fn, - outputs, - callee, - inputs, - attrs, - sub_functions, + op_type: str, + inputs: Sequence[str], + *, + attrs: Sequence[ir.Attr] = (), + domain: str = "", + ) -> Sequence[ir.Value]: + """Emit an ONNX operator with the given inputs, outputs, and attributes.""" + node = ir.Node( + domain=domain, + op_type=op_type, + inputs=[self._lookup(inp, self._source_of(inp)) for inp in inputs], + attributes=attrs, + outputs=[self._lookup(out, self._source_of(out)) for out in outputs], ) + self._current_fn.append(node) + return node.outputs - def _emit_const( + def emit_const( self, pyvalue: PyValue, - suggested_name: Optional[PreferredName], + suggested_name: PreferredName | None, info: sourceinfo.SourceInfo, - ) -> Variable: + ) -> ir.Value: + """Emit a constant value as an ONNX Constant node.""" + # Obtain a name for the constant if suggested_name is None: if isinstance(pyvalue, int): - if pyvalue >= 0: - suggested_name = f"int64_{pyvalue}" - else: - suggested_name = f"int64_m{abs(pyvalue)}" + suggested_name = f"int64_{pyvalue}" elif ( isinstance(pyvalue, list) and len(pyvalue) == 1 and isinstance(pyvalue[0], int) ): - if pyvalue[0] >= 0: - suggested_name = f"int64_{pyvalue[0]}_1d" - else: - suggested_name = f"int64_m{abs(pyvalue[0])}_1d" + suggested_name = f"int64_{pyvalue[0]}_1d" else: suggested_name = "const" - ovar = self.generate_unique_name(suggested_name) + var_name = self._generate_unique_name(suggested_name) + + # Create a tensor from the python value try: - tensor 
= autocast.pyvalue_to_onnx_tensor(ovar, pyvalue) - except ValueError as e: + tensor = ir.tensor(pyvalue, name=var_name) + except Exception as e: fail(info.msg(str(e))) - attr = self._make_onnx_attr("value", tensor) - self.emit([ovar], values.Op(self.default_opset, "Constant"), [], [attr]) - return Variable(ovar, True) - def _emit_copy(self, original_var: str, suggested_name: str) -> str: + const = self.emit([var_name], "Constant", [], attrs=[ir.AttrTensor("value", tensor)])[0] + mark_castable(const) + return const + + def _emit_copy(self, original_var: str, suggested_name: str) -> ir.Value: """Emits a copy statement, using the ONNX Identity operator.""" - new_var = self.generate_unique_name(suggested_name) - self.emit([new_var], "Identity", [original_var]) - return new_var - - def _is_constant_expr(self, node: ast.AST) -> None: - if isinstance(node, ast.UnaryOp): - return self._is_constant_expr(node.operand) - if isinstance( - node, - ( - ast.Call, - ast.BinOp, - ast.UnaryOp, - ast.Compare, - ast.Attribute, - ast.List, - ast.Load, - ast.Constant, - ), - ): - return all(self._is_constant_expr(c) for c in ast.iter_child_nodes(node)) - return False + new_var = self._generate_unique_name(suggested_name) + return self.emit([new_var], "Identity", [original_var])[0] def _eval_constant_expr(self, expr: ast.AST) -> PyValue: """Evaluates a sub-expression that is assumed to represent a constant value. @@ -451,18 +421,19 @@ def _eval_constant_expr(self, expr: ast.AST) -> PyValue: as divergence between eager-mode execution and evaluation of the ONNX function.) """ - # TODO: assert (self._is_constant_expr(expr)) - # TODO: Refine types + # TODO: assert (_is_constant_expr(expr)) + # TODO(justinchuby): Expand locals? locals: dict[Any, Any] = {} + # TODO(justinchuby): Find a better way to pass lineno and col_offset expr = ast.Expression(expr, lineno=expr.lineno, col_offset=expr.col_offset) cpl = compile(expr, filename="", mode="eval") try: - return eval(cpl, self.globals, locals) # pylint: disable=eval-used + return eval(cpl, self._globals, locals) # pylint: disable=eval-used except NameError as e: raise NameError( self._message( expr, - f"Missing names, globals contains {list(self.globals)!r}, " + f"Missing names, globals contains {list(self._globals)!r}, " f"locals {list(locals)!r}.", ) ) from e @@ -471,41 +442,46 @@ def _translate_attr( self, attr_name: str, expr: ast.AST, - attr_meta: Optional[onnx.defs.OpSchema.Attribute] = None, - ) -> Optional[irbuilder.IRAttributeValue]: - """Translate an attribute-value specification of the form `attr_name=` - in a call to an op. expr is an AST. The following cases are supported: + # TODO(justinchuby): Is attr_meta needed? + attr_meta: ir.Attr | None = None, + ) -> ir.Attr | None: + """Translate an attribute-value specification of the form `attr_name=` in a call to an op. expr is an AST. + + The following cases are supported: * Expr evaluates to a script-time constant (a python-value) that can be mapped into an ONNX attribute value, or * Expr evaluates to None, in which case None is returned, or * Expr must be an attribute-reference, that is a name representing an attribute-parameter of a containing function. 
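        For example (illustrative; ``Elu``'s ``alpha`` is a FLOAT attribute)::

            op.Elu(X, alpha=0.5)    # the constant 0.5 maps to a FLOAT attribute value
            op.Elu(X, alpha=alpha)  # inside a function declaring `alpha: float`, the
                                    # name maps to an attribute-reference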
""" - if isinstance(expr, ast.Name): val = self._lookup(expr.id, self._source_of(expr)) if isinstance(val, values.AttrRef): - attr_ref = self.ir_builder.make_attr_ref(attr_name, val.value, val.typeinfo) + attr_ref = _to_onnx_ref_attr(val, val.typeinfo) if attr_meta is not None and (attr_ref.type != attr_meta.type): self.fail( expr, f"Attribute type '{attr_ref.type}' does not match expected type '{attr_meta.type}'", ) return attr_ref - if isinstance(val, irbuilder.IRFunction): + if isinstance(val, ir.Graph): + # if isinstance(val, irbuilder.IRFunction): # Check that outer-scope variables referenced by function have same value # at function-definition site and use-as-attribute site, to avoid errors. - for pyvar, previous in val.outer_scope_variables: - current = self._lookup(pyvar, self._source_of(expr)) - if current.value != previous.value: - self.fail( - expr, - f"Outer scope variable '{pyvar}' referenced by function " - f"'{expr.id!r}' modified.", - ) - # Create GraphProto attribute - val = val.to_graph_proto() + # TODO(justinchuby): Capture outer_scope_variables? + # And implement the following + # for pyvar, previous in val.outer_scope_variables: + # current = self._lookup(pyvar, self._source_of(expr)) + # if current.value != previous.value: + # self.fail( + # expr, + # f"Outer scope variable '{pyvar}' referenced by function " + # f"'{expr.id!r}' modified.", + # ) + + # Create Graph attribute + pass else: val = self._eval_constant_expr(expr) @@ -515,25 +491,11 @@ def _translate_attr( # The caller is responsible for omitting such attribute-values from the list of attributes # in a NodeProto. if val is None: - if attr_meta and attr_meta.required: - self.fail(expr, f"Attribute '{attr_name}' is required.") return None - attr_type = int(attr_meta.type) if attr_meta else None - attr = self._make_onnx_attr(attr_name, val, attrtype=attr_type) - if attr_meta and (attr.type != attr_meta.type): - self.fail( - expr, - f"Attribute type '{attr.type}' does not match expected type '{attr_meta.type}'", - ) - return attr - - def _translate_docstring(self, node: ast.Expr) -> None: - if hasattr(node.value, "value"): - # python 3.8+ - return self.ir_builder.add_docstring(self._current_fn, node.value.value) - raise TypeError( - f"Unexpected type {type(node)!r} for node. Unsupoorted version of python." 
+ attr = ir.convenience.convert_attribute( + attr_name, val, attr_type=attr_meta.type if attr_meta else None ) + return attr def _translate_expr( self, node: ast.AST, target: Optional[PreferredName] = None @@ -554,8 +516,8 @@ def _translate_expr( r = self._translate_name_expr(node) elif isinstance(node, ast.Subscript): r = self._translate_subscript_expr(node, target) - elif self._is_constant_expr(node): - r = self._emit_const(self._eval_constant_expr(node), target, self._source_of(node)) + elif _is_constant_expr(node): + r = self.emit_const(self._eval_constant_expr(node), target, self._source_of(node)) else: raise ValueError( self._message(node, f"Unsupported expression type {type(node)!r}.") @@ -565,8 +527,8 @@ def _translate_expr( callee, args, attrs = r target = "tmp" if target is None else target assert isinstance(target, str) - result = self.generate_unique_name(target) - self.emit([result], callee, args, attrs) + result = self._generate_unique_name(target) + self.emit([result], callee, args, attrs=attrs) return Variable(result) def _translate_opt_expr(self, node: ast.expr) -> Optional[Variable]: @@ -620,7 +582,7 @@ def _translate_subscript_expr( var_name = var.name if target is None: target = f"{var_name}_subscripted" - target = self.generate_unique_name(target) + target = self._generate_unique_name(target) indices = ast_utils.normalize_subscript_expr(node) info = self._source_of(node.slice) @@ -631,7 +593,7 @@ def _translate_subscript_expr( def const_1d(value, name: Optional[str] = None): nonlocal cached_int_consts if value not in cached_int_consts: - cached_int_consts[value] = self._emit_const([value], name, info) + cached_int_consts[value] = self.emit_const([value], name, info) return cached_int_consts[value] def one_1d(): @@ -653,7 +615,7 @@ def translate_slice_component( ) return const_1d(default_value), default_value - if self._is_constant_expr(node_arg): + if _is_constant_expr(node_arg): cst = self._eval_constant_expr(node_arg) if isinstance(cst, int): return const_1d(cst), cst @@ -661,12 +623,11 @@ def translate_slice_component( raise RuntimeError(f"Slice component type must be int, not {type(cst)}") else: name = self._translate_expr(node_arg).name - reshaped = self.generate_unique_name(f"{name}_reshaped") + reshaped = self._generate_unique_name(f"{name}_reshaped") self.emit( [reshaped], - values.Op(self.default_opset, "Reshape"), + "Reshape", [name, one_1d().name], - [], ) return reshaped, None @@ -705,9 +666,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: # Add to sliced_indices, unless it is "::", which is a no-op. 
if not (elt.lower is None and elt.upper is None and elt.step is None): sliced_indices.append((axis, elt)) - elif self._is_constant_expr(elt) and isinstance( - self._eval_constant_expr(elt), int - ): + elif _is_constant_expr(elt) and isinstance(self._eval_constant_expr(elt), int): scalar_indices.append((axis, elt)) else: non_scalar_indices.append((axis, elt)) @@ -748,18 +707,18 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: steps.append(inputs[2]) if len(starts) > 1: - axis_0_attr = self._make_onnx_attr("axis", 0) - start_name = self.generate_unique_name(f"{var_name}_start") - self.emit([start_name], "Concat", starts, [axis_0_attr]) + axis_0_attr = ir.AttrInt64("axis", 0) + start_name = self._generate_unique_name(f"{var_name}_start") + self.emit([start_name], "Concat", starts, attrs=[axis_0_attr]) - end_name = self.generate_unique_name(f"{var_name}_end") - self.emit([end_name], "Concat", ends, [axis_0_attr]) + end_name = self._generate_unique_name(f"{var_name}_end") + self.emit([end_name], "Concat", ends, attrs=[axis_0_attr]) - axes_name = self.generate_unique_name(f"{var_name}_axis") - self.emit([axes_name], "Concat", axes, [axis_0_attr]) + axes_name = self._generate_unique_name(f"{var_name}_axis") + self.emit([axes_name], "Concat", axes, attrs=[axis_0_attr]) - steps_name = self.generate_unique_name(f"{var_name}_step") - self.emit([steps_name], "Concat", steps, [axis_0_attr]) + steps_name = self._generate_unique_name(f"{var_name}_step") + self.emit([steps_name], "Concat", steps, attrs=[axis_0_attr]) else: start_name = starts[0] end_name = ends[0] @@ -767,23 +726,23 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: steps_name = steps[0] if squeezed_axes: - sliced_name = self.generate_unique_name(f"{var_name}_sliced") + sliced_name = self._generate_unique_name(f"{var_name}_sliced") self.emit( [sliced_name], "Slice", [var_name, start_name, end_name, axes_name, steps_name], ) - squeezed_axes = self._emit_const(squeezed_axes, "squeezed_axes", info) + squeezed_axes = self.emit_const(squeezed_axes, "squeezed_axes", info) if non_scalar_indices: # use temporary to store result of squeeze - result = self.generate_unique_name(f"{var_name}_squeezed") + result = self._generate_unique_name(f"{var_name}_squeezed") else: # store squeezed result in final target result = target self.emit([result], "Squeeze", [sliced_name, squeezed_axes]) else: if non_scalar_indices: # use temporary to store result of Slice - result = self.generate_unique_name(f"{var_name}_sliced") + result = self._generate_unique_name(f"{var_name}_sliced") else: # store result of Slice in final target result = target slice_inputs = [var_name, start_name, end_name, axes_name, steps_name] @@ -798,14 +757,14 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: last_axis = None for axis, index_expr in non_scalar_indices: index_value = self._translate_expr(index_expr) - axis_attr = self._make_onnx_attr("axis", axis) + axis_attr = ir.AttrInt64("axis", axis) # use Gather to perform indexing # Assign gathered value to either temporary or final target if axis != last_axis: # use temporary to store result of Gather - gathered = self.generate_unique_name(f"{var_name}_axis_{axis}") + gathered = self._generate_unique_name(f"{var_name}_axis_{axis}") else: # store result of Gather in final target gathered = target - self.emit([gathered], "Gather", [str(result), index_value], [axis_attr]) + self.emit([gathered], "Gather", [str(result), index_value], attrs=[axis_attr]) result = gathered return 
Variable(result) @@ -813,17 +772,18 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: def _translate_call_expr(self, node: ast.Call): """Translates a call-expression.""" callee = self._translate_callee_expr(node.func) - param_schemas = callee.param_schemas() + op_signature = callee.op_signature # If the callee's schema is available, we use it to determine the inputs and attributes. # Otherwise, we map named arguments to attributes and positional arguments to inputs. - if param_schemas: - kwargs = {x.arg: x.value for x in node.keywords} - args, attrs = param_manipulation.separate_input_attributes_from_arguments( - param_schemas, node.args, kwargs, fill_defaults=False - ) - args = [self._translate_opt_expr(x) for x in args] + if op_signature is not None: + args = node.args + kwargs: dict[str, ast.expr] = {x.arg: x.value for x in node.keywords} + # First separate inputs from attributes. This is needed because in Python + # it is possible to pass onnx inputs as kwargs + inputs, attrs = _separate_inputs_and_attrs(op_signature, args, kwargs) + onnx_inputs = [self._translate_opt_expr(x) for x in inputs] attrs = [ - self._translate_attr(x, y, callee.op_schema.attributes[x]) + self._translate_attr(x, y, op_signature.params_map[x]) for x, y in attrs.items() ] else: @@ -843,28 +803,28 @@ def _cast_like_binary_expression(self, op, left, right): def _translate_binary_op_expr(self, node: ast.BinOp): op = type(node.op) - if op not in primop_map: + if op not in _PRIMOP_MAP: raise ValueError(self._message(node, f"Unsupported operator {op!r}.")) attr = [] - if isinstance(node.op, ast.Mod) and self._is_constant_expr(node.right): + if isinstance(node.op, ast.Mod) and _is_constant_expr(node.right): # specific case X % f where f is a float. # attribute fmod=1 is added in that case. cst = self._eval_constant_expr(node.right) if isinstance(cst, float): - attr = [self._make_onnx_attr("fmod", 1)] + attr = [ir.AttrInt64("fmod", 1)] - op = values.Op(self.default_opset, primop_map[op]) + onnx_op = _PRIMOP_MAP[op] left, right = self._cast_like_binary_expression( - op, self._translate_expr(node.left), self._translate_expr(node.right) + onnx_op, self._translate_expr(node.left), self._translate_expr(node.right) ) - return op, [left, right], attr + return onnx_op, [left, right], attr def _translate_unary_op_expr(self, node): op = type(node.op) - if op not in primop_map: + if op not in _PRIMOP_MAP: raise ValueError(self._message(node, self).msg(f"Unsupported operator {op!r}.")) - if self._is_constant_expr(node.operand): + if _is_constant_expr(node.operand): # This function changed the constant node.operand # and returns it. 
The function calling this one # should intercept this call and replace node @@ -883,29 +843,29 @@ def _translate_unary_op_expr(self, node): return self._translate_expr(cst) if op == ast.UAdd: return self._translate_expr(node.operand) - opname = primop_map[op] + opname = _PRIMOP_MAP[op] operand = self._translate_expr(node.operand) - return values.Op(self.default_opset, opname), [operand], [] + return values.Op(self._default_opset, opname), [operand], [] def _translate_compare_expr(self, node): # TODO: handle multiple comparisons in one expression assert len(node.ops) == 1 assert len(node.comparators) == 1 op = type(node.ops[0]) - if op not in primop_map: + if op not in _PRIMOP_MAP: raise ValueError(self._message(node, f"Unsupported operator {op!r}.")) - opname = primop_map[op] + opname = _PRIMOP_MAP[op] left = self._translate_expr(node.left) right = self._translate_expr(node.comparators[0]) # NotEqual is not a standard ONNX op, and needs to be translated into # an Equal op/node followed by a Not op/node. - op = values.Op(self.default_opset, opname if opname != "NotEqual" else "Equal") + op = values.Op(self._default_opset, opname if opname != "NotEqual" else "Equal") left, right = self._cast_like_binary_expression(op, left, right) if opname == "NotEqual": - tmp = self.generate_unique_name() + tmp = self._generate_unique_name() self.emit([tmp], op, [left, right]) - not_op = values.Op(self.default_opset, "Not") + not_op = values.Op(self._default_opset, "Not") return not_op, [tmp], [] return op, [left, right], [] @@ -945,12 +905,12 @@ def _translate_callee_expr(self, node: ast.AST) -> values.Op: # pylint: disable if isinstance(found, values.Op): return found if not found: - if function_name not in self.default_opset: + if function_name not in self._default_opset: warn( f"Unknown function name {function_name!r}. " f"The ONNX graph may not work." ) - return values.Op(self.default_opset, function_name) + return values.Op(self._default_opset, function_name) self.fail(node, "Invalid callee") def _translate_stmt(self, node: ast.stmt, index_of_stmt=None) -> None: @@ -974,8 +934,6 @@ def _translate_stmt(self, node: ast.stmt, index_of_stmt=None) -> None: if isinstance(node, (ast.For, ast.While)): return self._translate_loop_stmt(node) if ast_utils.is_doc_string(node): - if index_of_stmt == 0: - return self._translate_docstring(node) return None if isinstance(node, ast.FunctionDef): return self._translate_nested_function_def(node) @@ -1008,7 +966,7 @@ def assign(lhs: ast.AST, rhs: ast.AST) -> None: def generate_onnx_name(x: ast.AST): if not isinstance(x, ast.Name): self.fail(x, f"LHS must be a Name for unpacking, found: '{type(x)!r}'") - onnx_name = self.generate_unique_name(x.id) + onnx_name = self._generate_unique_name(x.id) self._bind( x.id, values.Dynamic( @@ -1018,7 +976,7 @@ def generate_onnx_name(x: ast.AST): return onnx_name outputs = [generate_onnx_name(x) for x in lhs.elts] - self.emit(outputs, callee, inputs, attrs) + self.emit(outputs, callee, inputs, attrs=attrs) else: self.fail(lhs, f"Unsupported construct in LHS of assignment: '{type(lhs)!r}'") @@ -1060,10 +1018,11 @@ def check_num_outputs(n): ) ) - def ret(exp, i, suffix): + def ret(exp: ast.AST, i: int, suffix: str) -> str: preferred_name = f"return_val{suffix}" return_var = self._translate_expr(exp, preferred_name).name val = self._lookup(return_var, self._source_of(exp), False) + assert type(val) is values.Dynamic if val and val.kind == values.DynamicKind.Input: # In ONNX, a graph-input cannot be an output of the graph. 
# We need to insert a copy. @@ -1071,13 +1030,17 @@ def ret(exp, i, suffix): for prev_output in self._current_fn.outputs: if prev_output.name == return_var: # ONNX does not allow duplicate output names. + # TODO(justinchuby): Maybe pass in ir.Value in _emit_copy return_var = self._emit_copy(return_var, f"{return_var}_copy") break if self.returntype is None: t = None else: t = self.returntype[i] - self.ir_builder.add_output(self._current_fn, return_var, t, self._source_of(stmt)) + self._current_fn.outputs.append(return_var) + # TODO(justinchuby): Set type for return var from t + # TODO(justinchuby): Get self._source_of(stmt) + # self.ir_builder.add_output(self._current_fn, return_var, t, self._source_of(stmt)) return return_var val = stmt.value @@ -1089,25 +1052,27 @@ def ret(exp, i, suffix): return ret(val, 0, "") def _translate_if_stmt(self, stmt: ast.If) -> None: - if hasattr(stmt, "live_out"): + if (live_out := self.meta[stmt].live_out) is not None: live_defs = list( - stmt.live_out.intersection(analysis.assigned_vars(stmt, self._message)) + live_out.intersection(_analysis.assigned_vars(stmt, self._message)) ) else: - live_defs = list(analysis.assigned_vars(stmt, self._message)) + live_defs = list(_analysis.assigned_vars(stmt, self._message)) test = self._translate_expr(stmt.test, "cond").name lineno = self._source_of(stmt).lineno - thenGraph, sub_fct_then = self._translate_block( - stmt.body, f"thenGraph_{lineno}", live_defs, parent_stmt=stmt + + # TODO(justinchuby): Ensure the values are obtained from the live_defs + then_graph, sub_fct_then = self._translate_block( + stmt.body, f"then_graph_{lineno}", live_defs, parent_stmt=stmt ) - thenAttr = self._make_onnx_attr("then_branch", thenGraph) - elseGraph, sub_fct_else = self._translate_block( - stmt.orelse, f"elseGraph_{lineno}", live_defs, parent_stmt=stmt + then_attr = ir.AttrGraph("then_branch", then_graph) + else_graph, sub_fct_else = self._translate_block( + stmt.orelse, f"else_graph_{lineno}", live_defs, parent_stmt=stmt ) - elseAttr = self._make_onnx_attr("else_branch", elseGraph) + else_attr = ir.AttrGraph("else_branch", else_graph) def rename(x): - r = self.generate_unique_name(x) + r = self._generate_unique_name(x) self._bind( x, values.Dynamic(r, values.DynamicKind.Intermediate, self._source_of(stmt)), @@ -1117,19 +1082,20 @@ def rename(x): # no break condition renamed = [rename(x) for x in live_defs] if not renamed: - self.fail(stmt, "A subgraph for a test do not have any output variable.") + # TODO(justinchuby): This needs comments. What is it doing? + self.fail(stmt, "A subgraph for an if condition has no outputs.") + # TODO(justinchuby): Collect the subfunctions to self sub_functions = {} sub_functions.update(sub_fct_then) sub_functions.update(sub_fct_else) if renamed == [test]: self.fail(stmt, f"Input and output cannot be the same {renamed!r}.") self.emit( - renamed, - values.Op(self.default_opset, "If"), [test], - [thenAttr, elseAttr], - sub_functions=sub_functions, + "If", + renamed, + attrs=[then_attr, else_attr], ) def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: @@ -1151,7 +1117,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: self.fail(loop_stmt, "Unsupported loop bound, it should be 'range(?)'.") assert not iter.keywords, "Unsupported loop bound." 
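        # For `for i in range(n)`, `iter.args[0]` (i.e. `n`) supplies the trip-count
        # input of the generated ONNX Loop node and the loop-condition input is left
        # empty; a `while` loop instead drives termination through its boolean
        # condition (see `cond_while` below).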
o_loop_bound = self._translate_expr(iter.args[0], "loop_bound").name - o_cond_var = self.generate_unique_name("cond_in") + o_cond_var = self._generate_unique_name("cond_in") i_cond_var = o_cond_var cond_while = None o_loop_condition = "" # No condition for a for loop. @@ -1174,18 +1140,19 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: else: self.fail(loop_stmt, f"Unexpected loop type {type(loop_stmt)!r}.") # analyze loop body - exposed_uses = analysis.exposed_uses(loop_stmt.body, self._message) - vars_def_in_loop = analysis.assigned_vars(loop_stmt.body, self._message) - loop_state_vars = vars_def_in_loop.intersection(exposed_uses | loop_stmt.live_out) + exposed_uses = _analysis.exposed_uses(loop_stmt.body, self._message) + vars_def_in_loop = _analysis.assigned_vars(loop_stmt.body, self._message) + live_out = self.meta[loop_stmt].live_out or set() + loop_state_vars = vars_def_in_loop.intersection(exposed_uses | live_out) scan_outputs = set() # TODO outputs = list(loop_state_vars | scan_outputs) # loop-condition: - # o_loop_condition = self._emit_const(True, "true", self._source_of(loop_stmt)) + # o_loop_condition = self.emit_const(True, "true", self._source_of(loop_stmt)) # build loop_body self._enter_scope("loop_body", loop_stmt) - o_loop_var = self.generate_unique_name(p_loop_var) + o_loop_var = self._generate_unique_name(p_loop_var) self.ir_builder.add_input( self._current_fn, o_loop_var, @@ -1205,7 +1172,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: ) for pv in loop_state_vars: - ov = self.generate_unique_name(pv) + ov = self._generate_unique_name(pv) # TODO: retrieve the annotation for variable pv is any is specified. # typeinfo = self._eval_constant_expr(pv.annotation) typeinfo = None @@ -1246,7 +1213,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: continue self._translate_stmt(s) - o_cond_out = self.generate_unique_name("cond_out") + o_cond_out = self._generate_unique_name("cond_out") if cond_while is not None: # Loop while @@ -1261,9 +1228,8 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: self.emit( [o_cond_out], - values.Op(self.default_opset, operator_name), + operator_name, [condition_name or o_cond_var], - [], ) self.ir_builder.add_output( @@ -1296,7 +1262,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: info = self._source_of(loop_stmt) def rename(x): - r = self.generate_unique_name(x) + r = self._generate_unique_name(x) self._bind(x, values.Dynamic(r, values.DynamicKind.Output, info)) return r @@ -1305,8 +1271,8 @@ def rename(x): onnx_outputs, "Loop", inputs, - attrs, - sub_functions=sub_functions, + attrs=attrs, + # sub_functions=sub_functions, ) def _translate_block( @@ -1338,7 +1304,7 @@ def _translate_block( ) else: pv_val = None - for scope in self._locals: # TODO: skip _current_scope + for scope in reversed(self._locals): # TODO: skip _current_scope if pvar in scope: pv_val = scope[pvar] break @@ -1360,9 +1326,9 @@ def _translate_block( def _translate_nested_function_def(self, fn: ast.FunctionDef) -> None: """Translate a nested function definition.""" self._enter_scope(fn.name, fn) - self._translate_function_def_common(fn) + self._translate_function_def(fn) function_ir = self._exit_scope() - outer_scope_vars = analysis.outer_scope_variables(fn, self._message) + outer_scope_vars = _analysis.outer_scope_variables(fn, self._message) function_ir.outer_scope_variables = [ (var, self._lookup(var, 
self._source_of(fn))) for var in outer_scope_vars ] @@ -1370,9 +1336,7 @@ def _translate_nested_function_def(self, fn: ast.FunctionDef) -> None: # TODO: Does not yet handle nested functions within nested functions. self._current_fn.add_nested_function(function_ir) - def _translate_function_signature_common( - self, fn: ast.FunctionDef - ) -> irbuilder.IRFunction: + def _translate_function_signature_common(self, fn: ast.FunctionDef) -> ir.Function: """Translate a function signature (top-level or nested).""" args = fn.args if args.vararg or args.kwonlyargs or args.kw_defaults or args.kwarg: @@ -1431,32 +1395,171 @@ def _translate_function_signature_common( return self._current_fn - def _translate_function_def_common(self, fn: ast.FunctionDef) -> irbuilder.IRFunction: + def _translate_function_def(self, node: ast.FunctionDef) -> ir.Function: """Translate a function definition, including the signature and its body.""" - logger.debug("Converter:_translate_function_def_common:%s", fn.name) - _ = self._translate_function_signature_common(fn) - for i, s in enumerate(fn.body): + logger.debug("Converter:_translate_function_def:%s", node.name) + _ = self._translate_function_signature_common(node) + for i, s in enumerate(node.body): self._translate_stmt(s, index_of_stmt=i) + + # Update docstring if available + if docstring := ast.get_docstring(node): + self._current_fn.doc_string = docstring return self._current_fn - def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction: - if isinstance(stmt, ast.FunctionDef): - self._init_function_translation() - if self.default_opset_ is None: - opset = self._find_onnx_opset(stmt) - if opset: - self._set_default_opset(opset, stmt) - domain = self.this_module.domain - self._current_fn = self.ir_builder.new_function(stmt.name, domain, True) - analysis.do_liveness_analysis(stmt, self._message) - fn_ir = self._translate_function_def_common(stmt) - fn_ir.debug_print() - self.this_module.add_function_def(fn_ir) - return fn_ir - raise ValueError(f"Unsupported top-level statement type {type(stmt)!r}.") - - def translate_function_signature(self, fn: ast.FunctionDef) -> irbuilder.IRFunction: - """Translate a (top-level) function signature.""" - domain = self.this_module.domain - self._current_fn = self.ir_builder.new_function(fn.name, domain, True) - return self._translate_function_signature_common(fn) + def _finalize(self) -> None: + self._finalized = True + + def convert(self) -> ir.Function: + """Convert the Python AST to an ONNX IR function.""" + if self._finalized: + return self._current_fn + + func_def = self._ast_root + _analysis.do_liveness_analysis(func_def, self._message, self.meta) + return self._translate_function_def(func_def) + # TODO(justinchuby): Handle function registration to the opset + # self._opset.add_function_def(fn_ir) + + +def _is_constant_expr(node: ast.AST) -> bool: + """Check if the AST node is a constant expression.""" + if isinstance(node, ast.UnaryOp): + return _is_constant_expr(node.operand) + if isinstance( + node, + ( + ast.Call, + ast.BinOp, + ast.UnaryOp, + ast.Compare, + ast.Attribute, + ast.List, + ast.Load, + ast.Constant, + ), + ): + return all(_is_constant_expr(c) for c in ast.iter_child_nodes(node)) + return False + + +def _separate_inputs_and_attrs( + signature: _schemas.OpSignature, + args: Sequence[ast.expr], + kwargs: Mapping[str, ast.expr], +) -> tuple[Sequence[ast.expr], dict[str, ast.expr]]: + """Construct two mappings: name to inputs and named to attributes based on the signature and args/kwargs. 
+ + This function uses the OpSignature to determine which argument in args and kwargs corresponds to + which parameter in the signature. ONNX node inputs are stored in named_inputs, and attributes are + stored in named_attrs. If an _optional input_ is not provided, it is filled with None. + + Args: + signature: The OpSignature for the node. + args: The positional arguments for the node. + kwargs: The keyword arguments for the node. + + Returns: + A tuple of two mappings: named_inputs and named_attrs. + + Raises: + ValueError: If a required parameter is not provided. + """ + # 1. Construct inputs, attrs based on (args, kwargs) and the signature. + # a. Loop over all parameters in the signature and args together + # b. Depending on param.is_input, Record inputs or named_attrs[param.name] = arg + # c. Handle kwargs as well + inputs_reversed: Sequence[Any] = [] + named_attrs: dict[str, Any] = {} + reversed_args_stack = list(reversed(args)) + for param in signature.params: + if isinstance(param, _schemas.Parameter): + # Handle inputs + if reversed_args_stack: + # First exhaust the positional arguments + if param.variadic: + # Handle variadic arguments + inputs_reversed = [*reversed(args)] + reversed_args_stack.clear() + else: + inputs_reversed.append(reversed_args_stack.pop()) + elif param.name in kwargs: + inputs_reversed.append(kwargs[param.name]) + elif param.required: + raise ValueError( + f"Required parameter '{param.name}' is not provided. " + f"Signature: {signature}. Args: {args}. Kwargs: {kwargs}." + ) + else: + logger.debug( + "Optional parameter '%s' is not provided. Added as None. Signature: %s", + param.name, + signature, + ) + inputs_reversed.append(None) + else: + # Handle attributes + attribute: ir.Attr | None + assert isinstance(param, _schemas.AttributeParameter), ( + f"Expected AttributeParameter, got {type(param)}" + ) + if reversed_args_stack: + # First exhaust the positional arguments + attribute = reversed_args_stack.pop() # type: ignore[assignment] + elif kwargs.get(param.name) is not None: + attribute = kwargs[param.name] # type: ignore[assignment] + else: + if param.required: + raise ValueError( + f"Required attribute '{param.name}' is not provided. " + f"Signature: {signature}. Args: {args}. Kwargs: {kwargs}." + ) + else: + logger.debug( + "Optional attribute '%s' is None. Dropped. Signature: %s", + param.name, + signature, + ) + continue + named_attrs[param.name] = attribute + return tuple(reversed(inputs_reversed)), named_attrs + + +def _to_onnx_ref_attr(val: values.AttrRef, info: sourceinfo.SourceInfo | None) -> ir.Attr: + """Convert an attribute reference to an ONNX ref attribute.""" + + # TODO(justinchuby): Consider using a convenience function + pytype = val.typeinfo + attrtype = _schemas.get_attr_type(pytype) + attrname = None + if attrtype is ir.AttributeType.FLOAT: + attrname = "value_float" + elif attrtype is ir.AttributeType.INT: + attrname = "value_int" + elif attrtype is ir.AttributeType.STRING: + attrname = "value_string" + elif attrtype is ir.AttributeType.INTS: + attrname = "value_ints" + else: + msg = f"Unsupported attribute type {pytype!r}." + fail(info.msg(msg) if info else msg) + # TODO(justinchuby): What is the ref attr name? 
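    # Here `attrname` (e.g. "value_int") is the name of the attribute on the emitted
    # Constant node, while `val.value` names the attribute-parameter of the enclosing
    # function that the reference points to.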
+ return ir.RefAttr(attrname, val.value, attrtype) + + +def _find_onnx_opset(node: ast.AST, globals: dict[str, Any]) -> values.Opset | None: + """Find the (first) ONNX opset used in the function, if any.""" + # Search for a Call expression of form "op.OpName(...)" + if isinstance(node, ast.Call): + if isinstance(node.func, ast.Attribute): + opset_expr = node.func.value + if isinstance(opset_expr, ast.Name): + if opset_expr.id in globals: + opset = globals[opset_expr.id] + if isinstance(opset, values.Opset) and opset.domain == "": + return opset + for child in ast.iter_child_nodes(node): + res = _find_onnx_opset(child, globals) + if res is not None: + return res + return None diff --git a/onnxscript/converter_test.py b/onnxscript/_converter_test.py similarity index 99% rename from onnxscript/converter_test.py rename to onnxscript/_converter_test.py index 9a7ca504a7..ff8aaca591 100644 --- a/onnxscript/converter_test.py +++ b/onnxscript/_converter_test.py @@ -23,7 +23,7 @@ import onnxscript import onnxscript.testing -from onnxscript import BOOL, FLOAT, INT64, converter, graph, script, tensor +from onnxscript import BOOL, FLOAT, INT64, _converter, graph, script, tensor from onnxscript.onnx_opset import opset11 as op11 from onnxscript.onnx_opset import opset15 as op from tests.common import onnx_script_test_case, testutils @@ -437,12 +437,12 @@ def check_failure(self, f, msg): global_names = globals().copy() top_level_ast = ast.parse(source) f_ast = top_level_ast.body[0] - cvt = converter.Converter( + cvt = _converter.Converter( opset=op, global_names=global_names, source=source, default_opset=op ) try: cvt.translate_function_def(f_ast) - except converter.TranslationError as e: + except _converter.TranslationError as e: if msg not in str(e): raise AssertionError(f"Unable to find {msg!r} in {e!r} in\n{source}") from e return diff --git a/onnxscript/_internal/analysis.py b/onnxscript/_internal/_analysis.py similarity index 87% rename from onnxscript/_internal/analysis.py rename to onnxscript/_internal/_analysis.py index 0403f60c91..50e1cba1bd 100644 --- a/onnxscript/_internal/analysis.py +++ b/onnxscript/_internal/_analysis.py @@ -1,13 +1,18 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
+"""Analysis utilities for Python AST.""" from __future__ import annotations import ast -from typing import Any, Optional, Sequence, Set +from typing import Any, Optional, Sequence, TYPE_CHECKING +from collections import defaultdict from onnxscript import sourceinfo from onnxscript._internal import ast_utils +if TYPE_CHECKING: + from onnxscript import _converter + def _get_loop_var(for_stmt: ast.For, formatter: sourceinfo.Formatter) -> str: if not isinstance(for_stmt.target, ast.Name): @@ -15,7 +20,7 @@ def _get_loop_var(for_stmt: ast.For, formatter: sourceinfo.Formatter) -> str: return for_stmt.target.id -def _used_vars(expr: Optional[ast.expr]) -> Set[str]: +def _used_vars(expr: Optional[ast.expr]) -> set[str]: """Return set of all variables used, including function names, in an expression.""" if expr is None: return set() @@ -35,7 +40,7 @@ def _used_vars(expr: Optional[ast.expr]) -> Set[str]: return result -def _lhs_vars(lhs: ast.expr) -> Set[str]: +def _lhs_vars(lhs: ast.expr) -> set[str]: """Return set of assigned variables in the lhs of an assignment statement.""" def get_id(e): @@ -49,12 +54,12 @@ def get_id(e): def assigned_vars( stmt: ast.stmt | list[ast.stmt], formatter: sourceinfo.Formatter -) -> Set[str]: +) -> set[str]: """Return the set of all variables that may be assigned to in an execution of input stmt or sequence of statements. """ - def assigned_in_block(block: Sequence[ast.stmt]) -> Set[str]: + def assigned_in_block(block: Sequence[ast.stmt]) -> set[str]: result: set[Any] = set() for s in block: result = result | assigned_vars(s, formatter) @@ -84,20 +89,26 @@ def assigned_in_block(block: Sequence[ast.stmt]) -> Set[str]: raise ValueError(error_message) -def do_liveness_analysis(fun: ast.FunctionDef, formatter: sourceinfo.Formatter): - """Perform liveness analysis of the given function-ast. The results of the - analysis are stored directly with each statement-ast `s` as attributes `s.live_in` - and `s.live_out`. +def do_liveness_analysis( + fun: ast.FunctionDef, + formatter: sourceinfo.Formatter, + meta: defaultdict[ast.AST, _converter.ASTMeta], +): + """Perform liveness analysis of the given function-ast. + + The results of the analysis are stored in the `meta` dictionary, which maps + each AST node to its metadata. The metadata includes the set of live variables + at the entry and exit of each node. """ - def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: - stmt.live_out = live_out # type: ignore[attr-defined] + def visit(stmt: ast.stmt, live_out: set[str]) -> set[str]: + meta[stmt].live_out = live_out live = do_visit(stmt, live_out) - stmt.live_in = live # type: ignore[attr-defined] + meta[stmt].live_in = live return live - def do_visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: - def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]: + def do_visit(stmt: ast.stmt, live_out: set[str]) -> set[str]: + def visitBlock(block: Sequence[ast.stmt], live_out: set[str]) -> set[str]: for s in reversed(block): live_out = visit(s, live_out) return live_out @@ -165,12 +176,12 @@ def exposed_uses(stmts: Sequence[ast.stmt], formatter: sourceinfo.Formatter): (in the first statement). Hence x is included in the exposed_uses. 
""" - def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]: + def visitBlock(block: Sequence[ast.stmt], live_out: set[str]) -> set[str]: for stmt in reversed(block): live_out = visit(stmt, live_out) return live_out - def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: + def visit(stmt: ast.stmt, live_out: set[str]) -> set[str]: if isinstance(stmt, ast.Assign): return live_out.difference(_lhs_vars(stmt.targets[0])) | _used_vars(stmt.value) if isinstance(stmt, ast.AnnAssign): diff --git a/onnxscript/_internal/analysis_test.py b/onnxscript/_internal/_analysis_test.py similarity index 95% rename from onnxscript/_internal/analysis_test.py rename to onnxscript/_internal/_analysis_test.py index 74e7ca4c18..64e9b5b110 100644 --- a/onnxscript/_internal/analysis_test.py +++ b/onnxscript/_internal/_analysis_test.py @@ -6,7 +6,7 @@ import unittest from typing import Any -from onnxscript._internal import analysis, ast_utils +from onnxscript._internal import _analysis, ast_utils from onnxscript.onnx_opset import opset15 as op from onnxscript.sourceinfo import formatter @@ -30,7 +30,7 @@ def generic_visit(self, node): class TestLivenessAnalysis(unittest.TestCase): def analyze(self, fun): source, parse_tree = ast_utils.get_src_and_ast(fun) - analysis.do_liveness_analysis(parse_tree, formatter(source)) + _analysis.do_liveness_analysis(parse_tree, formatter(source)) visitor = AnalysisResultsVisitor() visitor.visit(parse_tree) return visitor.results @@ -113,7 +113,7 @@ def while_eg(x): class TestExposedUses(unittest.TestCase): def assertUses(self, f, expected): source, parse_tree = ast_utils.get_src_and_ast(f) - result = analysis.exposed_uses(parse_tree.body, formatter(source)) + result = _analysis.exposed_uses(parse_tree.body, formatter(source)) self.assertEqual(result, set(expected)) def test_basic(self): @@ -190,7 +190,7 @@ def f(x): class TestAssignedVarAnalysis(unittest.TestCase): def assert_assigned_vars(self, f, expected: set[str]): source, parse_tree = ast_utils.get_src_and_ast(f) - result = analysis.assigned_vars(parse_tree.body, formatter(source)) + result = _analysis.assigned_vars(parse_tree.body, formatter(source)) self.assertEqual(result, expected) def test_basic_defs(self): diff --git a/onnxscript/_internal/autocast.py b/onnxscript/_internal/autocast.py index 1defac3e53..d9ad48af35 100644 --- a/onnxscript/_internal/autocast.py +++ b/onnxscript/_internal/autocast.py @@ -13,7 +13,7 @@ from onnxscript import ir, tensor if TYPE_CHECKING: - from onnxscript import converter + from onnxscript import _converter # Conversions from python values to ONNX are used by both the script converter as well # as the eager-mode runtime and both need to be consistent. The script converter converts @@ -187,16 +187,16 @@ def get_type_info(x): def static_cast_inputs( - converter_: converter.Converter, + converter_: _converter.Converter, op_schema: Optional[OpSchema], - args: Sequence[Optional[converter.Variable]], + args: Sequence[Optional[_converter.Variable]], ) -> tuple[str, ...]: """Used for autocast during script-translation. This is meant to transform expressions like "Add(X, 1)" to "Add(X, CastLike(1, X))" Polymorphic constants (like 0 and 1) are cast to the type of other operands as needed. """ - def get_type_info(x: Optional[converter.Variable]) -> Optional[converter.Variable]: + def get_type_info(x: Optional[_converter.Variable]) -> Optional[_converter.Variable]: """Returns x back if x can serve as the target-type for a cast (as the second argument of CastLike) and None otherwise. 
In the expression "Add(X, 1), 1 is castable, while X can serve as the target-type. @@ -204,7 +204,7 @@ def get_type_info(x: Optional[converter.Variable]) -> Optional[converter.Variabl return None if x is None or x.is_castable else x def cast_like( - x: Optional[converter.Variable], y: Optional[converter.Variable] + x: Optional[_converter.Variable], y: Optional[_converter.Variable] ) -> Optional[str]: if x is None: return None diff --git a/onnxscript/ir/_schemas.py b/onnxscript/ir/_schemas.py index d4d88ab5bb..2a2527e31b 100644 --- a/onnxscript/ir/_schemas.py +++ b/onnxscript/ir/_schemas.py @@ -3,6 +3,7 @@ from __future__ import annotations import collections.abc +import copy import dataclasses import inspect import logging @@ -210,7 +211,7 @@ def _is_optional(type_: type) -> bool: return False -def _get_attr_type(type_: type) -> ir.AttributeType: +def get_attr_type(type_: type) -> ir.AttributeType: """Obtain the type of the attribute from a Python class.""" try: if type_ in _PY_TYPE_TO_ATTR_TYPE: @@ -455,7 +456,7 @@ def from_function( ) else: type_ = type_hints[param.name] - if (attr_type := _get_attr_type(type_)) != ir.AttributeType.UNDEFINED: + if (attr_type := get_attr_type(type_)) != ir.AttributeType.UNDEFINED: # Construct the default attribute if param.default is not inspect.Parameter.empty: # TODO: Use ir_convenience instead to handle int as float diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 4274bf2062..8674a6331f 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -182,6 +182,18 @@ def attr_proto(self) -> onnx.AttributeProto: class IRStmt: + """An IR statement (representing an operation). + + Details: + - `result`: A sequence of variable names that this statement assigns to. + - `callee`: The operation being called, represented as an instance of `values.Op`. + - `args`: A sequence of arguments to the operation, which can be variable names or + `None` for optional arguments. + - `attrs`: A sequence of attributes for the operation, represented as `IRAttributeValue` + instances. + - `sub_functions`: A dictionary of sub-functions that this statement may call, mapping + function names to `onnx.FunctionProto` instances. + """ def __init__( self, result: Sequence[str], diff --git a/onnxscript/main.py b/onnxscript/main.py index 3ea3e50f90..15d8247530 100644 --- a/onnxscript/main.py +++ b/onnxscript/main.py @@ -11,7 +11,7 @@ from typing_extensions import ParamSpec import onnxscript -from onnxscript import converter, ir, irbuilder, values +from onnxscript import _converter, ir, irbuilder, values from onnxscript._internal import ast_utils _R = TypeVar("_R") @@ -29,7 +29,7 @@ def script_check( # See if conversion succeeds. 
diff --git a/onnxscript/ir/_schemas.py b/onnxscript/ir/_schemas.py
index d4d88ab5bb..2a2527e31b 100644
--- a/onnxscript/ir/_schemas.py
+++ b/onnxscript/ir/_schemas.py
@@ -3,6 +3,7 @@
 from __future__ import annotations

 import collections.abc
+import copy
 import dataclasses
 import inspect
 import logging
@@ -210,7 +211,7 @@ def _is_optional(type_: type) -> bool:
     return False


-def _get_attr_type(type_: type) -> ir.AttributeType:
+def get_attr_type(type_: type) -> ir.AttributeType:
     """Obtain the type of the attribute from a Python class."""
     try:
         if type_ in _PY_TYPE_TO_ATTR_TYPE:
@@ -455,7 +456,7 @@ def from_function(
             )
         else:
             type_ = type_hints[param.name]
-            if (attr_type := _get_attr_type(type_)) != ir.AttributeType.UNDEFINED:
+            if (attr_type := get_attr_type(type_)) != ir.AttributeType.UNDEFINED:
                 # Construct the default attribute
                 if param.default is not inspect.Parameter.empty:
                     # TODO: Use ir_convenience instead to handle int as float
diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py
index 4274bf2062..8674a6331f 100644
--- a/onnxscript/irbuilder.py
+++ b/onnxscript/irbuilder.py
@@ -182,6 +182,18 @@ def attr_proto(self) -> onnx.AttributeProto:


 class IRStmt:
+    """An IR statement (representing an operation).
+
+    Details:
+    - `result`: A sequence of variable names that this statement assigns to.
+    - `callee`: The operation being called, represented as an instance of `values.Op`.
+    - `args`: A sequence of arguments to the operation, which can be variable names or
+      `None` for optional arguments.
+    - `attrs`: A sequence of attributes for the operation, represented as `IRAttributeValue`
+      instances.
+    - `sub_functions`: A dictionary of sub-functions that this statement may call, mapping
+      function names to `onnx.FunctionProto` instances.
+    """
     def __init__(
         self,
         result: Sequence[str],
diff --git a/onnxscript/main.py b/onnxscript/main.py
index 3ea3e50f90..15d8247530 100644
--- a/onnxscript/main.py
+++ b/onnxscript/main.py
@@ -11,7 +11,7 @@
 from typing_extensions import ParamSpec

 import onnxscript
-from onnxscript import converter, ir, irbuilder, values
+from onnxscript import _converter, ir, irbuilder, values
 from onnxscript._internal import ast_utils

 _R = TypeVar("_R")
@@ -29,7 +29,7 @@ def script_check(
     # See if conversion succeeds.
     # TODO: cleanup Converter interface/API, separating checker from
     # converter
-    convert = converter.Converter(
+    convert = _converter.Converter(
         opset=opset,
         global_names=global_names,
         source=source,
diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py
index fb7b8a370d..7db9d11528 100644
--- a/onnxscript/type_annotation.py
+++ b/onnxscript/type_annotation.py
@@ -8,6 +8,7 @@
 from typing import Optional, Sequence, Union

 import onnx
+import onnx_ir as ir

 from onnxscript import onnx_types

@@ -24,35 +25,35 @@
 # Map from python type to corresponding ONNX AttributeProto type
 _PYTYPE_TO_ATTRTYPE_MAP = {
-    float: onnx.AttributeProto.FLOAT,
-    int: onnx.AttributeProto.INT,
-    str: onnx.AttributeProto.STRING,
-    bool: onnx.AttributeProto.INT,  # experimental
+    float: ir.AttributeType.FLOAT,
+    int: ir.AttributeType.INT,
+    str: ir.AttributeType.STRING,
+    bool: ir.AttributeType.INT,  # experimental
 }

 # Map from python type to corresponding ONNX AttributeProto type,
 # for repeated (i.e., list of) values
 _LISTTYPE_TO_ATTRTYPE_MAP = {
-    float: onnx.AttributeProto.FLOATS,
-    int: onnx.AttributeProto.INTS,
-    str: onnx.AttributeProto.STRINGS,
-    bool: onnx.AttributeProto.INTS,  # experimental
+    float: ir.AttributeType.FLOATS,
+    int: ir.AttributeType.INTS,
+    str: ir.AttributeType.STRINGS,
+    bool: ir.AttributeType.INTS,  # experimental
 }

 _LIST_CONSTRUCTORS = frozenset([list, typing.List, typing.Sequence, collections.abc.Sequence])

 # Map from ONNX AttributeProto type to its representation (in ONNX Script).
 _ATTRTYPE_TO_REPR = {
-    onnx.AttributeProto.FLOAT: "float",
-    onnx.AttributeProto.INT: "int",
-    onnx.AttributeProto.STRING: "str",
-    onnx.AttributeProto.FLOATS: "Sequence[float]",
-    onnx.AttributeProto.INTS: "Sequence[int]",
-    onnx.AttributeProto.STRINGS: "Sequence[str]",
+    ir.AttributeType.FLOAT: "float",
+    ir.AttributeType.INT: "int",
+    ir.AttributeType.STRING: "str",
+    ir.AttributeType.FLOATS: "Sequence[float]",
+    ir.AttributeType.INTS: "Sequence[int]",
+    ir.AttributeType.STRINGS: "Sequence[str]",
 }


-def onnx_attr_type_to_onnxscript_repr(attr_type: onnx.AttributeProto.AttributeType) -> str:
+def onnx_attr_type_to_onnxscript_repr(attr_type: ir.AttributeType) -> str:
     if attr_type not in _ATTRTYPE_TO_REPR:
         supported = ", ".join(
             f"'{onnx.AttributeProto.AttributeType.Name(v)}'" for v in _ATTRTYPE_TO_REPR
@@ -95,26 +96,6 @@ def _is_primitive_attr_type(typeinfo: TypeAnnotationValue) -> bool:
     return typeinfo in _PYTYPE_TO_ATTRTYPE_MAP


-def pytype_to_attrtype(
-    pytype: TypeAnnotationValue,
-) -> Optional[onnx.AttributeProto.AttributeType]:
-    pytype = _remove_annotation(pytype)
-    if pytype in _PYTYPE_TO_ATTRTYPE_MAP:
-        return _PYTYPE_TO_ATTRTYPE_MAP[pytype]
-    type_constructor = typing.get_origin(pytype)
-    # Remove Optional wrapper if present, which is represented as an Union[..., type(None)]
-    if type_constructor is typing.Union:
-        # Filter out type(None), since typing.Optional[X] evaluates to Union[X, type(None)]
-        args = [x for x in typing.get_args(pytype) if x is not type(None)]
-        if len(args) == 1:
-            return pytype_to_attrtype(args[0])
-    if type_constructor in _LIST_CONSTRUCTORS:
-        elt_type = typing.get_args(pytype)[0]
-        if elt_type in _LISTTYPE_TO_ATTRTYPE_MAP:
-            return _LISTTYPE_TO_ATTRTYPE_MAP[elt_type]
-    return None
-
-
 def base_type_is_bool(pytype: TypeAnnotationValue) -> bool:
     """Returns True if base type of pytype is bool, False otherwise."""
     pytype = _remove_annotation(pytype)
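
Aside (illustrative only, not part of the patch): both `type_annotation` and `ir._schemas` now speak `ir.AttributeType`. A small sketch of the mapping in each direction, assuming `onnx_ir` exposes `AttributeType` as used in the hunks above and that `get_attr_type(int)` resolves to `ir.AttributeType.INT` via its `_PY_TYPE_TO_ATTR_TYPE` table; expected outputs follow the `_ATTRTYPE_TO_REPR` table shown above:

import onnx_ir as ir
from onnxscript import type_annotation as ta
from onnxscript.ir import _schemas

# Python annotation -> attribute type (bool is encoded as INT, marked "experimental" above).
print(_schemas.get_attr_type(int))  # expected: AttributeType.INT

# Attribute type -> ONNX Script annotation string, per _ATTRTYPE_TO_REPR.
print(ta.onnx_attr_type_to_onnxscript_repr(ir.AttributeType.INT))     # "int"
print(ta.onnx_attr_type_to_onnxscript_repr(ir.AttributeType.FLOATS))  # "Sequence[float]"
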
diff --git a/onnxscript/values.py b/onnxscript/values.py
index 1897ae14d5..80c7326c66 100644
--- a/onnxscript/values.py
+++ b/onnxscript/values.py
@@ -25,8 +25,7 @@
 import onnx.defs
 from typing_extensions import ParamSpec

-from onnxscript import converter as converter_module
-from onnxscript import irbuilder, sourceinfo, type_annotation
+from onnxscript import _converter, irbuilder, sourceinfo, type_annotation
 from onnxscript._internal import ast_utils, deprecation
 from onnxscript.ir import _schemas

@@ -638,7 +637,7 @@ def function_ir(self) -> irbuilder.IRFunction:
         closure = inspect.getclosurevars(self.func)
         global_names = module.__dict__.copy()
         global_names.update(closure.nonlocals)
-        converter = converter_module.Converter(
+        converter = _converter.Converter(
             opset=self._opset,
             global_names=global_names,
             source=src,
@@ -686,90 +685,3 @@ def param_schemas(self) -> tuple[ParamSchema, ...]:
         # argument order from the Python function definition, which is lost in OpSchema.
         self._param_schemas = _param_schemas_from_function_ir(self.function_ir)
         return self._param_schemas
-
-
-class SymbolValue:
-    """Represents script-time value information about named variables used in a script.
-
-    At translation-time, the (local) variables of a script, including its parameters,
-    are bound to a SymbolValue.
-
-    SymbolValues fall into the following categories:
-
-    AttrRef: Function parameters of attribute-kind, also mapped to ONNX attributes
-
-    Dynamic: values computed at runtime (of tensor type, for now) mapped to NodeArgs.
-    Dynamic values include input-parameters of the script, as well intermediate
-    values computed in the script.
-
-    For example, consider the following script definition:
-    ::
-
-        @script()
-        def ThresholdedRelu(X, alpha: float):
-            zero = op.CastLike(0, X)
-            return op.Where(X > alpha, X, zero)
-
-    Here, `X` has a Dynamic value, `alpha` has an AttrRef value, and `zero`
-    has a Dynamic value.
-
-    Scripts may also contain references to global variables, but the translator
-    does not associate a SymbolValue with them. The python value of global variables
-    is used directly in the translation, and such global variables are intended
-    to be used for limited purposes, namely:
-    * To identify an opset
-    * To represent constant-values, translated into ONNX constants.
-    """
-
-    def __init__(self, info: sourceinfo.SourceInfo) -> None:
-        if not isinstance(info, sourceinfo.SourceInfo):
-            raise TypeError(f"info must be of type sourceinfo.SourceInfo not {type(info)!r}.")
-        self.info = info
-
-
-class AttrRef(SymbolValue):
-    def __init__(
-        self, attr_name: str, typeinfo: _GenericAlias, info: sourceinfo.SourceInfo
-    ) -> None:
-        """Initializes AttrRef.
-
-        Arguments:
-            attr_name: name of the attribute-parameter
-            typeinfo: type annotation of the attribute.
-                op's attributes in ONNX are usually single type or list of single type.
-            info: for debugging use.
-        """
-        super().__init__(info)
-        self.value = attr_name
-        self.typeinfo = typeinfo
-        if not isinstance(typeinfo, (type, _GenericAlias)):
-            # typing._GenericAlias for List[int] and List[str], etc.
-            raise TypeError(f"Expecting a type not f{type(typeinfo)} for typeinfo.")
-        self.typeinfo = typeinfo
-
-
-class DynamicKind(IntFlag):
-    Unknown = 0
-    Input = 1
-    Output = 2
-    Intermediate = 4
-    Loop = 8
-
-
-class Dynamic(SymbolValue):
-    def __init__(
-        self, onnx_var: str, kind: DynamicKind, info: sourceinfo.SourceInfo, typeinfo=None
-    ) -> None:
-        """Initializes Dynamic.
-
-        Arguments:
-            onnx_var: the name of the ONNX variable used to represent this value
-            kind: the DynamicKind of this variable
-            info: source-location information for error-messages/debugging
-            typeinfo: type-information for the value
-        """
-        super().__init__(info)
-        assert isinstance(kind, DynamicKind)
-        self.value = onnx_var
-        self.kind = kind
-        self.typeinfo = typeinfo
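
Aside (standalone stdlib sketch, not onnxscript code): the `function_ir` hunk above builds the converter's `global_names` by copying the defining module's namespace and then overlaying the function's closure, so closure bindings win. The `make_scale` example below is hypothetical and only illustrates the `inspect.getclosurevars` behavior that code relies on:

import inspect

def make_scale(factor):
    def scale(x):
        return x * factor  # `factor` is a closure (nonlocal) binding, not a module global
    return scale

fn = make_scale(3)
closure = inspect.getclosurevars(fn)

# Mirrors: global_names = module.__dict__.copy(); global_names.update(closure.nonlocals)
global_names = dict(vars(inspect.getmodule(fn)))
global_names.update(closure.nonlocals)
print(global_names["factor"])  # 3
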