# TorchLib function authoring guide
Updated: July 2023
Authors: @justinchuby @titaiwangms
TorchLib functions are pure data. This means we avoid defining runtime behavior as code in the functions.
The main goal of torchlib is to convert PyTorch models to ONNX models. Therefore, we first need to understand the function signatures in PyTorch, namely the ATen operators. `native_functions.yaml` defines all the native functions in PyTorch:
```yaml
- func: func_name(ArgType arg0[=default], ArgType arg1[=default], ...) -> Return
  variants: function, method
  dispatch:
    CPU: func_cpu
    CUDA: func_cuda
```
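For example, the entry for `aten::gather` (its ATen signature reappears in the docstring of the `aten_gather` example below) looks roughly like this; consult `native_functions.yaml` for the authoritative entry:

```yaml
- func: gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor
  # variants/dispatch fields omitted here
```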
Developers should pay careful attention to the `ArgType`: each `ArgType` maps to a different `TypeVar` in torchlib.
The decorator `torch_op` is used to register a function into torchlib:
```python
def torch_op(
    name: str | tuple[str, ...],
    *,
    registry: Optional[Registry] = None,
    trace_only: bool = False,
    private: bool = False,
    complex: bool = False,
) -> Callable[[FunctionType], onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction]:
    """Register a torch op.

    Args:
        name: Qualified ATen name of the function. E.g. "aten::relu", "aten::add.Tensor".
            Or a tuple of names e.g. ("aten::add.Scalar", "aten::add.Tensor").
            Default overloads should be specified by omitting the overload part,
            i.e. "aten::relu" instead of "aten::relu.default".
        registry: Registry to register the function to. If None, the default registry is used.
        trace_only: Whether the function should only be traced and not compiled.
        private: Whether the function is private (not directly exposed). It should
            be true for all functions with names starting with "_".
        complex: Whether the function supports complex.
    """
    ...
```
`trace_only` extends `script()` to support complicated control flow via the class `TracedOnnxFunction`: instead of compiling the whole function, including its control flow, into an `OnnxFunction`, it traces the function as a normal Python function to accommodate control flow that cannot be compiled.
- Name a function starting with the namespace it comes from. For example, `aten_abs` or `prims_abs`.
- Correctly annotate the inputs and attributes according to `native_functions.yaml`.
Use an existing `TypeVar` in `tensor_typing`, or create a new one, to match the `ArgType` from `native_functions.yaml`. In most cases, inputs should all be tensor types and attributes should be primitive types. However, depending on the implementation of the `OnnxFunction`, this varies case by case to align with the requirements of the ONNX operators used in the function.
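For reference, a minimal sketch of what such a `TypeVar` looks like; the exact constraint list below is illustrative, so see `tensor_typing.py` for the real definitions:

```python
from typing import TypeVar, Union

from onnxscript import BFLOAT16, DOUBLE, FLOAT, FLOAT16, INT8, INT16, INT32, INT64

# TReal matches real-valued tensor ArgTypes (illustrative bound only).
TReal = TypeVar(
    "TReal",
    bound=Union[BFLOAT16, FLOAT16, FLOAT, DOUBLE, INT8, INT16, INT32, INT64],
)
```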
To script a function, every calculation in the function should be executed by ONNX operators. A prefix of `opset{version}` (typically imported as `op`, as in the examples below) indicates where to get the operator. `OnnxFunction` also supports limited control flow and simplified coding, such as `if`, `for` loops, automatic constant wrapping, and automatic wrapping of basic arithmetic:
@torch_op("aten::gather")
def aten_gather(
self: TReal,
dim: int,
index: TInt,
sparse_grad: bool = False, # pylint: disable=unused-argument
) -> TReal:
"""gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor"""
# When (index) is empty, return (self)
if op.Size(op.Shape(index)) == 0: # Support control-flow
result = self
else:
if op.Size(op.Shape(self)) == 0: # 0 is auto-wrapping of op.Constant(value_float=[0])
self = op.Reshape(self, op.Constant(value_ints=[-1]))
if op.Size(index) == 0: # == is auto-wrapping on op.Equal()
result = op.CastLike(index, self)
else:
index = op.Cast(index, to=INT64.dtype)
result = op.GatherElements(self, index, axis=dim)
return resultThis kind of function is a pure python function covering OnnxFunction. The reason we need it is that the limited supports on coding can't solve the complicated situations for us in the operator. The needed reason could be: unsupported dictionary, unsupported len(), unsupported None check, etc ...
@torch_op("aten::layer_norm", trace_only=True)
def aten_layer_norm(
input: TReal,
normalized_shape: INT64,
weight: Optional[TReal] = None,
bias: Optional[TReal] = None,
eps: float = 1e-05,
) -> TReal:
"""layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"""
# trace_only to use Python to obtain start_axis
start_axis = -len(normalized_shape)
if weight is None: # Unsupported None check
one = op.Constant(value_float=1.0)
weight = op.Expand(one, op.Shape(input, start=start_axis))
if bias is None: # Unsupported None check
zero = op.Constant(value_float=0.0)
bias = op.Expand(zero, op.Shape(input, start=start_axis))
return _aten_layer_norm_onnx(input, weight, bias, axis=start_axis, eps=eps) # covers a private OnnxFunction
@torch_op("aten::layer_norm", private=True)
def _aten_layer_norm_onnx(
input: TReal,
weight: TReal,
bias: TReal,
axis: int,
eps: float = 1e-05,
) -> TReal:
"""layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"""
# TODO(justinchuby): Use OptionalHasElement after onnx/onnx#4982
result, _, _ = op.LayerNormalization(input, weight, bias, axis=axis, epsilon=eps)
return resultTo make sure the OnnxFunction/TracedOnnxFunction has valid implementation, we provide Op-level correctness test.
These tests use PyTorch's OpInfo mechanism to generate test cases for each operator. You may find all OpInfos in https://github.com/pytorch/pytorch/blob/7ec0d6f006fdd2c9b978dc6aa4923144684a3f51/torch/testing/_internal/common_methods_invocations.py#L8804.
1. To enable test cases for an operator, add a `TorchLibOpInfo` entry to `TORCH_LIB_OPINFO` in `ops_test_data.py` (a sketch follows this list). Explicitly specify `trace_only` if the op is trace_only. Specify `complex` if the function is designed for complex inputs. The `op_info_name` in `TorchLibOpInfo` needs to be unique in the `TORCH_LIB_OPINFO` list, but `complex=True` ops can share the same name with non-complex ops because they are tested separately.
2. Add `.skip` and/or `.xfail` to skip or xfail tests. Prefer xfail over skip when possible, because that allows us to monitor the behavior and update the test when it passes.
    - If a test now fails because of an xpass (some previous error has been fixed), remove the corresponding xfail.
3. If the sample inputs of the OpInfo need to be adjusted to fit the ATen signature, create an input wrangler function; see `_mean_input_wrangler` for an example (a sketch also follows this list).
4. To test different ONNX functions that are registered as overloads of the same op, use `ops_test_common.duplicate_opinfo` to create new OpInfos with new names and map each to one overload.
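As a hedged sketch of steps 1 and 2 (the entry names and keyword arguments below are illustrative; copy the exact patterns from `ops_test_data.py`, where `TorchLibOpInfo` is defined):

```python
import torch

from onnxscript.function_libs.torch_lib.ops import core as core_ops

TORCH_LIB_OPINFO = [
    # A scripted function: the OpInfo name first, then the TorchLib function.
    TorchLibOpInfo("gather", core_ops.aten_gather),
    # A trace_only function must say so explicitly. Prefer xfail over skip
    # so that an unexpected pass (xpass) surfaces once the bug is fixed.
    TorchLibOpInfo(
        "nn.functional.layer_norm", core_ops.aten_layer_norm, trace_only=True
    ).xfail(
        dtypes=(torch.float16,),
        reason="fixme: illustrative xfail entry",
    ),
]
```

For step 3, an input wrangler rewrites an OpInfo sample's args/kwargs to fit the ATen signature. A minimal sketch modeled on the signature of `_mean_input_wrangler` (the specific adjustment below is hypothetical):

```python
from typing import Any


def _example_input_wrangler(
    args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
    # Hypothetical adjustment: move a kwarg supplied by the OpInfo into the
    # positional argument the ATen signature expects.
    if "dim" in kwargs:
        args.append(kwargs.pop("dim"))
    return args, kwargs
```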