Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
925 changes: 514 additions & 411 deletions onnxscript/converter.py → onnxscript/_converter.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import onnxscript
import onnxscript.testing
from onnxscript import BOOL, FLOAT, INT64, converter, graph, script, tensor
from onnxscript import BOOL, FLOAT, INT64, _converter, graph, script, tensor
from onnxscript.onnx_opset import opset11 as op11
from onnxscript.onnx_opset import opset15 as op
from tests.common import onnx_script_test_case, testutils
Expand Down Expand Up @@ -437,12 +437,12 @@
global_names = globals().copy()
top_level_ast = ast.parse(source)
f_ast = top_level_ast.body[0]
cvt = converter.Converter(
cvt = _converter.Converter(

Check failure

Code scanning / lintrunner

PYLINT/E1120 Error

No value for argument 'root' in constructor call (no-value-for-parameter)
See no-value-for-parameter. To disable, use # pylint: disable=no-value-for-parameter
opset=op, global_names=global_names, source=source, default_opset=op
)
try:
cvt.translate_function_def(f_ast)
except converter.TranslationError as e:
except _converter.TranslationError as e:
if msg not in str(e):
raise AssertionError(f"Unable to find {msg!r} in {e!r} in\n{source}") from e
return
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Analysis utilities for Python AST."""
from __future__ import annotations

import ast
from typing import Any, Optional, Sequence, Set
from typing import Any, Optional, Sequence, TYPE_CHECKING
from collections import defaultdict

from onnxscript import sourceinfo
from onnxscript._internal import ast_utils

if TYPE_CHECKING:
from onnxscript import _converter


def _get_loop_var(for_stmt: ast.For, formatter: sourceinfo.Formatter) -> str:
if not isinstance(for_stmt.target, ast.Name):
raise TypeError(formatter(for_stmt, "For loop target must be a single variable."))
return for_stmt.target.id


def _used_vars(expr: Optional[ast.expr]) -> Set[str]:
def _used_vars(expr: Optional[ast.expr]) -> set[str]:
"""Return set of all variables used, including function names, in an expression."""
if expr is None:
return set()
Expand All @@ -35,7 +40,7 @@ def _used_vars(expr: Optional[ast.expr]) -> Set[str]:
return result


def _lhs_vars(lhs: ast.expr) -> Set[str]:
def _lhs_vars(lhs: ast.expr) -> set[str]:
"""Return set of assigned variables in the lhs of an assignment statement."""

def get_id(e):
Expand All @@ -49,12 +54,12 @@ def get_id(e):

def assigned_vars(
stmt: ast.stmt | list[ast.stmt], formatter: sourceinfo.Formatter
) -> Set[str]:
) -> set[str]:
"""Return the set of all variables that may be assigned to in an execution of input stmt
or sequence of statements.
"""

def assigned_in_block(block: Sequence[ast.stmt]) -> Set[str]:
def assigned_in_block(block: Sequence[ast.stmt]) -> set[str]:
result: set[Any] = set()
for s in block:
result = result | assigned_vars(s, formatter)
Expand Down Expand Up @@ -84,20 +89,26 @@ def assigned_in_block(block: Sequence[ast.stmt]) -> Set[str]:
raise ValueError(error_message)


def do_liveness_analysis(fun: ast.FunctionDef, formatter: sourceinfo.Formatter):
"""Perform liveness analysis of the given function-ast. The results of the
analysis are stored directly with each statement-ast `s` as attributes `s.live_in`
and `s.live_out`.
def do_liveness_analysis(
fun: ast.FunctionDef,
formatter: sourceinfo.Formatter,
meta: defaultdict[ast.AST, _converter.ASTMeta],
):
"""Perform liveness analysis of the given function-ast.

The results of the analysis are stored in the `meta` dictionary, which maps
each AST node to its metadata. The metadata includes the set of live variables
at the entry and exit of each node.
"""

def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
stmt.live_out = live_out # type: ignore[attr-defined]
def visit(stmt: ast.stmt, live_out: set[str]) -> set[str]:
meta[stmt].live_out = live_out
live = do_visit(stmt, live_out)
stmt.live_in = live # type: ignore[attr-defined]
meta[stmt].live_in = live
return live

def do_visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]:
def do_visit(stmt: ast.stmt, live_out: set[str]) -> set[str]:
def visitBlock(block: Sequence[ast.stmt], live_out: set[str]) -> set[str]:
for s in reversed(block):
live_out = visit(s, live_out)
return live_out
Expand Down Expand Up @@ -165,12 +176,12 @@ def exposed_uses(stmts: Sequence[ast.stmt], formatter: sourceinfo.Formatter):
(in the first statement). Hence x is included in the exposed_uses.
"""

def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]:
def visitBlock(block: Sequence[ast.stmt], live_out: set[str]) -> set[str]:
for stmt in reversed(block):
live_out = visit(stmt, live_out)
return live_out

def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
def visit(stmt: ast.stmt, live_out: set[str]) -> set[str]:
if isinstance(stmt, ast.Assign):
return live_out.difference(_lhs_vars(stmt.targets[0])) | _used_vars(stmt.value)
if isinstance(stmt, ast.AnnAssign):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import unittest
from typing import Any

from onnxscript._internal import analysis, ast_utils
from onnxscript._internal import _analysis, ast_utils
from onnxscript.onnx_opset import opset15 as op
from onnxscript.sourceinfo import formatter

Expand All @@ -30,7 +30,7 @@
class TestLivenessAnalysis(unittest.TestCase):
def analyze(self, fun):
source, parse_tree = ast_utils.get_src_and_ast(fun)
analysis.do_liveness_analysis(parse_tree, formatter(source))
_analysis.do_liveness_analysis(parse_tree, formatter(source))

Check failure

Code scanning / lintrunner

PYLINT/E1120 Error

No value for argument 'meta' in function call (no-value-for-parameter)
See no-value-for-parameter. To disable, use # pylint: disable=no-value-for-parameter
visitor = AnalysisResultsVisitor()
visitor.visit(parse_tree)
return visitor.results
Expand Down Expand Up @@ -113,7 +113,7 @@
class TestExposedUses(unittest.TestCase):
def assertUses(self, f, expected):
source, parse_tree = ast_utils.get_src_and_ast(f)
result = analysis.exposed_uses(parse_tree.body, formatter(source))
result = _analysis.exposed_uses(parse_tree.body, formatter(source))
self.assertEqual(result, set(expected))

def test_basic(self):
Expand Down Expand Up @@ -190,7 +190,7 @@
class TestAssignedVarAnalysis(unittest.TestCase):
def assert_assigned_vars(self, f, expected: set[str]):
source, parse_tree = ast_utils.get_src_and_ast(f)
result = analysis.assigned_vars(parse_tree.body, formatter(source))
result = _analysis.assigned_vars(parse_tree.body, formatter(source))
self.assertEqual(result, expected)

def test_basic_defs(self):
Expand Down
10 changes: 5 additions & 5 deletions onnxscript/_internal/autocast.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from onnxscript import ir, tensor

if TYPE_CHECKING:
from onnxscript import converter
from onnxscript import _converter

# Conversions from python values to ONNX are used by both the script converter as well
# as the eager-mode runtime and both need to be consistent. The script converter converts
Expand Down Expand Up @@ -187,24 +187,24 @@


def static_cast_inputs(
converter_: converter.Converter,
converter_: _converter.Converter,
op_schema: Optional[OpSchema],
args: Sequence[Optional[converter.Variable]],
args: Sequence[Optional[_converter.Variable]],

Check failure

Code scanning / lintrunner

MYPY/name-defined Error

Name "_converter.Variable" is not defined To disable, use # type: ignore[name-defined]
) -> tuple[str, ...]:
"""Used for autocast during script-translation.
This is meant to transform expressions like "Add(X, 1)" to "Add(X, CastLike(1, X))"
Polymorphic constants (like 0 and 1) are cast to the type of other operands as needed.
"""

def get_type_info(x: Optional[converter.Variable]) -> Optional[converter.Variable]:
def get_type_info(x: Optional[_converter.Variable]) -> Optional[_converter.Variable]:

Check failure

Code scanning / lintrunner

MYPY/name-defined Error

Name "_converter.Variable" is not defined To disable, use # type: ignore[name-defined]
"""Returns x back if x can serve as the target-type for a cast (as the second
argument of CastLike) and None otherwise. In the expression "Add(X, 1), 1 is
castable, while X can serve as the target-type.
"""
return None if x is None or x.is_castable else x

def cast_like(
x: Optional[converter.Variable], y: Optional[converter.Variable]
x: Optional[_converter.Variable], y: Optional[_converter.Variable]

Check failure

Code scanning / lintrunner

MYPY/name-defined Error

Name "_converter.Variable" is not defined To disable, use # type: ignore[name-defined]
) -> Optional[str]:
if x is None:
return None
Expand Down
5 changes: 3 additions & 2 deletions onnxscript/ir/_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import collections.abc
import copy

Check warning

Code scanning / lintrunner

PYLINT/W0611 Warning

Unused import copy (unused-import)
See unused-import. To disable, use # pylint: disable=unused-import

Check warning

Code scanning / lintrunner

RUFF/F401 Warning

copy imported but unused.
See https://docs.astral.sh/ruff/rules/unused-import
import dataclasses
import inspect
import logging
Expand Down Expand Up @@ -210,7 +211,7 @@
return False


def _get_attr_type(type_: type) -> ir.AttributeType:
def get_attr_type(type_: type) -> ir.AttributeType:
"""Obtain the type of the attribute from a Python class."""
try:
if type_ in _PY_TYPE_TO_ATTR_TYPE:
Expand Down Expand Up @@ -455,7 +456,7 @@
)
else:
type_ = type_hints[param.name]
if (attr_type := _get_attr_type(type_)) != ir.AttributeType.UNDEFINED:
if (attr_type := get_attr_type(type_)) != ir.AttributeType.UNDEFINED:
# Construct the default attribute
if param.default is not inspect.Parameter.empty:
# TODO: Use ir_convenience instead to handle int as float
Expand Down
12 changes: 12 additions & 0 deletions onnxscript/irbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,18 @@ def attr_proto(self) -> onnx.AttributeProto:


class IRStmt:
"""An IR statement (representing an operation).

Details:
- `result`: A sequence of variable names that this statement assigns to.
- `callee`: The operation being called, represented as an instance of `values.Op`.
- `args`: A sequence of arguments to the operation, which can be variable names or
`None` for optional arguments.
- `attrs`: A sequence of attributes for the operation, represented as `IRAttributeValue`
instances.
- `sub_functions`: A dictionary of sub-functions that this statement may call, mapping
function names to `onnx.FunctionProto` instances.
"""
def __init__(
self,
result: Sequence[str],
Expand Down
4 changes: 2 additions & 2 deletions onnxscript/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing_extensions import ParamSpec

import onnxscript
from onnxscript import converter, ir, irbuilder, values
from onnxscript import _converter, ir, irbuilder, values
from onnxscript._internal import ast_utils

_R = TypeVar("_R")
Expand All @@ -29,7 +29,7 @@
# See if conversion succeeds.
# TODO: cleanup Converter interface/API, separating checker from
# converter
convert = converter.Converter(
convert = _converter.Converter(

Check failure

Code scanning / lintrunner

PYLINT/E1120 Error

No value for argument 'root' in constructor call (no-value-for-parameter)
See no-value-for-parameter. To disable, use # pylint: disable=no-value-for-parameter
opset=opset,
global_names=global_names,
source=source,
Expand Down
51 changes: 16 additions & 35 deletions onnxscript/type_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Optional, Sequence, Union

import onnx
import onnx_ir as ir

from onnxscript import onnx_types

Expand All @@ -24,35 +25,35 @@

# Map from python type to corresponding ONNX AttributeProto type
_PYTYPE_TO_ATTRTYPE_MAP = {
float: onnx.AttributeProto.FLOAT,
int: onnx.AttributeProto.INT,
str: onnx.AttributeProto.STRING,
bool: onnx.AttributeProto.INT, # experimental
float: ir.AttributeType.FLOAT,
int: ir.AttributeType.INT,
str: ir.AttributeType.STRING,
bool: ir.AttributeType.INT, # experimental
}

# Map from python type to corresponding ONNX AttributeProto type,
# for repeated (i.e., list of) values
_LISTTYPE_TO_ATTRTYPE_MAP = {
float: onnx.AttributeProto.FLOATS,
int: onnx.AttributeProto.INTS,
str: onnx.AttributeProto.STRINGS,
bool: onnx.AttributeProto.INTS, # experimental
float: ir.AttributeType.FLOATS,
int: ir.AttributeType.INTS,
str: ir.AttributeType.STRINGS,
bool: ir.AttributeType.INTS, # experimental
}

_LIST_CONSTRUCTORS = frozenset([list, typing.List, typing.Sequence, collections.abc.Sequence])

# Map from ONNX AttributeProto type to its representation (in ONNX Script).
_ATTRTYPE_TO_REPR = {
onnx.AttributeProto.FLOAT: "float",
onnx.AttributeProto.INT: "int",
onnx.AttributeProto.STRING: "str",
onnx.AttributeProto.FLOATS: "Sequence[float]",
onnx.AttributeProto.INTS: "Sequence[int]",
onnx.AttributeProto.STRINGS: "Sequence[str]",
ir.AttributeType.FLOAT: "float",
ir.AttributeType.INT: "int",
ir.AttributeType.STRING: "str",
ir.AttributeType.FLOATS: "Sequence[float]",
ir.AttributeType.INTS: "Sequence[int]",
ir.AttributeType.STRINGS: "Sequence[str]",
}


def onnx_attr_type_to_onnxscript_repr(attr_type: onnx.AttributeProto.AttributeType) -> str:
def onnx_attr_type_to_onnxscript_repr(attr_type: ir.AttributeType) -> str:
if attr_type not in _ATTRTYPE_TO_REPR:
supported = ", ".join(
f"'{onnx.AttributeProto.AttributeType.Name(v)}'" for v in _ATTRTYPE_TO_REPR
Expand Down Expand Up @@ -95,26 +96,6 @@ def _is_primitive_attr_type(typeinfo: TypeAnnotationValue) -> bool:
return typeinfo in _PYTYPE_TO_ATTRTYPE_MAP


def pytype_to_attrtype(
pytype: TypeAnnotationValue,
) -> Optional[onnx.AttributeProto.AttributeType]:
pytype = _remove_annotation(pytype)
if pytype in _PYTYPE_TO_ATTRTYPE_MAP:
return _PYTYPE_TO_ATTRTYPE_MAP[pytype]
type_constructor = typing.get_origin(pytype)
# Remove Optional wrapper if present, which is represented as an Union[..., type(None)]
if type_constructor is typing.Union:
# Filter out type(None), since typing.Optional[X] evaluates to Union[X, type(None)]
args = [x for x in typing.get_args(pytype) if x is not type(None)]
if len(args) == 1:
return pytype_to_attrtype(args[0])
if type_constructor in _LIST_CONSTRUCTORS:
elt_type = typing.get_args(pytype)[0]
if elt_type in _LISTTYPE_TO_ATTRTYPE_MAP:
return _LISTTYPE_TO_ATTRTYPE_MAP[elt_type]
return None


def base_type_is_bool(pytype: TypeAnnotationValue) -> bool:
"""Returns True if base type of pytype is bool, False otherwise."""
pytype = _remove_annotation(pytype)
Expand Down
Loading
Loading