diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py index 61a8d32d..fa36a78e 100644 --- a/src/onnx_ir/_core.py +++ b/src/onnx_ir/_core.py @@ -82,6 +82,8 @@ _enums.DataType.INT4, _enums.DataType.UINT4, _enums.DataType.FLOAT4E2M1, + _enums.DataType.INT2, + _enums.DataType.UINT2, ) ) @@ -300,6 +302,16 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType) raise TypeError( f"The numpy array dtype must be uint8 or ml_dtypes.float4_e2m1fn (not {array.dtype}) for IR data type {dtype}." ) + if dtype == _enums.DataType.INT2: + if array.dtype not in (np.int8, np.uint8, ml_dtypes.int2): + raise TypeError( + f"The numpy array dtype must be int8 or uint8 or ml_dtypes.int2 (not {array.dtype}) for IR data type {dtype}." + ) + if dtype == _enums.DataType.UINT2: + if array.dtype not in (np.uint8, ml_dtypes.uint2): + raise TypeError( + f"The numpy array dtype must be uint8 or ml_dtypes.uint2 (not {array.dtype}) for IR data type {dtype}." + ) return try: @@ -347,6 +359,10 @@ def _maybe_view_np_array_with_ml_dtypes( return array.view(ml_dtypes.uint4) if dtype == _enums.DataType.FLOAT4E2M1: return array.view(ml_dtypes.float4_e2m1fn) + if dtype == _enums.DataType.INT2: + return array.view(ml_dtypes.int2) + if dtype == _enums.DataType.UINT2: + return array.view(ml_dtypes.uint2) return array @@ -365,7 +381,7 @@ def _create_np_array_for_byte_representation(tensor: Tensor) -> np.ndarray: """Create a numpy array for the byte representation of the tensor. This function is used for serializing the tensor to bytes. It handles the - special cases for 4-bit data types and endianness. + special cases for 2-bit and 4-bit data types and endianness. 
""" array = tensor.numpy() if tensor.dtype in { @@ -375,6 +391,12 @@ def _create_np_array_for_byte_representation(tensor: Tensor) -> np.ndarray: }: # Pack the array into int4 array = _type_casting.pack_4bitx2(array) + elif tensor.dtype in { + _enums.DataType.INT2, + _enums.DataType.UINT2, + }: + # Pack the array into int2 + array = _type_casting.pack_2bitx4(array) else: assert tensor.dtype.itemsize == array.itemsize, "Bug: The itemsize should match" if not _IS_LITTLE_ENDIAN: @@ -726,6 +748,8 @@ def _load(self): _enums.DataType.INT4, _enums.DataType.UINT4, _enums.DataType.FLOAT4E2M1, + _enums.DataType.INT2, + _enums.DataType.UINT2, }: # Use uint8 to read in the full byte. Otherwise ml_dtypes.int4 will clip the values dt = np.dtype(np.uint8).newbyteorder("<") @@ -1051,7 +1075,7 @@ def tofile(self, file) -> None: class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]): # pylint: disable=too-many-ancestors - """A tensor that stores 4bit datatypes in packed format. + """A tensor that stores 2bit and 4bit datatypes in packed format. .. versionadded:: 0.1.2 """ @@ -1077,7 +1101,7 @@ def __init__( Args: value: The backing data of the tensor. It can be a numpy array compatible object or a DLPack compatible object. The value MUST be packed in an integer dtype. - dtype: The data type of the tensor. Must be one of INT4, UINT4, FLOAT4E2M1. + dtype: The data type of the tensor. Must be one of INT2, UINT2, INT4, UINT4, FLOAT4E2M1. shape: The shape of the tensor. name: The name of the tensor. doc_string: The documentation string. 
@@ -1092,9 +1116,9 @@ def __init__( raise TypeError(f"Expected an array compatible object, got {type(value)}") self._shape = Shape(shape) self._shape.freeze() - if dtype.bitwidth != 4: + if dtype.bitwidth not in (2, 4): raise TypeError( - f"PackedTensor only supports INT4, UINT4, FLOAT4E2M1, but got {dtype}" + f"PackedTensor only supports INT2, UINT2, INT4, UINT4, FLOAT4E2M1, but got {dtype}" ) self._dtype = dtype self._raw = value @@ -1104,6 +1128,8 @@ def __init__( value.dtype == ml_dtypes.float4_e2m1fn or value.dtype == ml_dtypes.uint4 or value.dtype == ml_dtypes.int4 + or value.dtype == ml_dtypes.uint2 + or value.dtype == ml_dtypes.int2 ): raise TypeError( f"PackedTensor expects the value to be packed, but got {value.dtype} which is not packed. " diff --git a/src/onnx_ir/_core_test.py b/src/onnx_ir/_core_test.py index 51a6eda3..e22fb2d4 100644 --- a/src/onnx_ir/_core_test.py +++ b/src/onnx_ir/_core_test.py @@ -55,8 +55,11 @@ def test_init_requires_type_when_value_is_not_np_array(self): ("float8e5m2", np.uint8, ir.DataType.FLOAT8E5M2), ("float8e5m2fnuz", np.uint8, ir.DataType.FLOAT8E5M2FNUZ), ("float8e8m0", np.uint8, ir.DataType.FLOAT8E8M0), + ("int2", np.int8, ir.DataType.INT2), + ("int2_uint8", np.uint8, ir.DataType.INT2), ("int4", np.int8, ir.DataType.INT4), ("int4_uint8", np.uint8, ir.DataType.INT4), + ("uint2", np.uint8, ir.DataType.UINT2), ("uint4", np.uint8, ir.DataType.UINT4), ("float4e2m1", np.uint8, ir.DataType.FLOAT4E2M1), ] @@ -146,6 +149,38 @@ def test_tobytes(self): tensor = _core.Tensor(torch_tensor, dtype=ir.DataType.FLOAT) self.assertEqual(tensor.tobytes(), array.tobytes()) + def test_tobytes_returns_packed_data_for_int2(self): + array = np.array([-2, -1, 0, 1, 1, -2, 1], dtype=np.int8) + # Test array size not divisible by 4 + assert len(array) % 4 != 0 + tensor = _core.Tensor(array, dtype=ir.DataType.INT2) + # -2, -1, 0, 1 => [0b10, 0b11, 0b00, 0b01] => 0b01001110 = 0x4E + # 1, -2, 1, 0 (padding) => [0b01, 0b10, 0b01, 0b00] => 0b00011001 = 0x19 
+ self.assertEqual(tensor.tobytes(), b"\x4e\x19") + + def test_tobytes_returns_packed_data_for_int2_ml_dtypes(self): + array = np.array([-2, -1, 0, 1, 1, -2, 1], dtype=ml_dtypes.int2) + # Test array size not divisible by 4 + assert len(array) % 4 != 0 + tensor = _core.Tensor(array, dtype=ir.DataType.INT2) + self.assertEqual(tensor.tobytes(), b"\x4e\x19") + + def test_tobytes_returns_packed_data_for_uint2(self): + array = np.array([0, 1, 2, 3, 3, 2, 1], dtype=np.uint8) + # Test array size not divisible by 4 + assert len(array) % 4 != 0 + tensor = _core.Tensor(array, dtype=ir.DataType.UINT2) + # 0, 1, 2, 3 => 0b11100100 = 0xE4 + # 3, 2, 1, 0 (padding) => 0b00011011 = 0x1B + self.assertEqual(tensor.tobytes(), b"\xe4\x1b") + + def test_tobytes_returns_packed_data_for_uint2_ml_dtypes(self): + array = np.array([0, 1, 2, 3, 3, 2, 1], dtype=ml_dtypes.uint2) + # Test array size not divisible by 4 + assert len(array) % 4 != 0 + tensor = _core.Tensor(array, dtype=ir.DataType.UINT2) + self.assertEqual(tensor.tobytes(), b"\xe4\x1b") + def test_tobytes_returns_packed_data_for_int4(self): array = np.array([-8, -1, 0, 1, 2, 7, 1], dtype=np.int8) # Test odd sized array diff --git a/src/onnx_ir/_enums.py b/src/onnx_ir/_enums.py index 273fb8e8..832ee3bc 100644 --- a/src/onnx_ir/_enums.py +++ b/src/onnx_ir/_enums.py @@ -67,6 +67,8 @@ class DataType(enum.IntEnum): INT4 = 22 FLOAT4E2M1 = 23 FLOAT8E8M0 = 24 + INT2 = 25 + UINT2 = 26 @classmethod def from_numpy(cls, dtype: np.dtype) -> DataType: @@ -101,6 +103,10 @@ def from_numpy(cls, dtype: np.dtype) -> DataType: return DataType.INT4 if dtype.names == ("float4e2m1",): return DataType.FLOAT4E2M1 + if dtype.names == ("int2",): + return DataType.INT2 + if dtype.names == ("uint2",): + return DataType.UINT2 raise TypeError(f"Unsupported numpy data type: {dtype}") @classmethod @@ -329,6 +335,8 @@ def is_integer(self) -> bool: DataType.UINT64, DataType.UINT4, DataType.INT4, + DataType.INT2, + DataType.UINT2, } def is_signed(self) -> bool: @@ 
-354,6 +362,7 @@ def is_signed(self) -> bool: DataType.INT4, DataType.FLOAT4E2M1, DataType.FLOAT8E8M0, + DataType.INT2, } def is_string(self) -> bool: @@ -394,6 +403,8 @@ def __str__(self) -> str: DataType.INT4: 4, DataType.FLOAT4E2M1: 4, DataType.FLOAT8E8M0: 8, + DataType.INT2: 2, + DataType.UINT2: 2, } @@ -423,6 +434,8 @@ def __str__(self) -> str: np.dtype(ml_dtypes.int4): DataType.INT4, np.dtype(ml_dtypes.uint4): DataType.UINT4, np.dtype(ml_dtypes.float4_e2m1fn): DataType.FLOAT4E2M1, + np.dtype(ml_dtypes.int2): DataType.INT2, + np.dtype(ml_dtypes.uint2): DataType.UINT2, } # ONNX DataType to Numpy dtype. @@ -442,12 +455,14 @@ def __str__(self) -> str: DataType.FLOAT4E2M1: "f4e2m1", DataType.COMPLEX64: "c64", DataType.COMPLEX128: "c128", + DataType.INT2: "i2", DataType.INT4: "i4", DataType.INT8: "i8", DataType.INT16: "i16", DataType.INT32: "i32", DataType.INT64: "i64", DataType.BOOL: "b8", + DataType.UINT2: "u2", DataType.UINT4: "u4", DataType.UINT8: "u8", DataType.UINT16: "u16", diff --git a/src/onnx_ir/_enums_test.py b/src/onnx_ir/_enums_test.py index 474aec6c..e8649a4b 100644 --- a/src/onnx_ir/_enums_test.py +++ b/src/onnx_ir/_enums_test.py @@ -39,6 +39,10 @@ def test_enums_are_the_same_as_spec(self): self.assertEqual(_enums.DataType.FLOAT4E2M1, onnx.TensorProto.FLOAT4E2M1) if hasattr(onnx.TensorProto, "FLOAT8E8M0"): self.assertEqual(_enums.DataType.FLOAT8E8M0, onnx.TensorProto.FLOAT8E8M0) + if hasattr(onnx.TensorProto, "INT2"): + self.assertEqual(_enums.DataType.INT2, onnx.TensorProto.INT2) + if hasattr(onnx.TensorProto, "UINT2"): + self.assertEqual(_enums.DataType.UINT2, onnx.TensorProto.UINT2) self.assertEqual(_enums.DataType.UNDEFINED, onnx.TensorProto.UNDEFINED) @parameterized.parameterized.expand( @@ -75,6 +79,8 @@ def test_enums_are_the_same_as_spec(self): ("int4", np.dtype(ml_dtypes.int4), _enums.DataType.INT4), ("float4e2m1", np.dtype(ml_dtypes.float4_e2m1fn), _enums.DataType.FLOAT4E2M1), ("float8e8m0", np.dtype(ml_dtypes.float8_e8m0fnu), 
_enums.DataType.FLOAT8E8M0), + ("int2", np.dtype(ml_dtypes.int2), _enums.DataType.INT2), + ("uint2", np.dtype(ml_dtypes.uint2), _enums.DataType.UINT2), ] ) def test_from_numpy_takes_np_dtype_and_returns_data_type( diff --git a/src/onnx_ir/_type_casting.py b/src/onnx_ir/_type_casting.py index 23f563ac..5ada83bb 100644 --- a/src/onnx_ir/_type_casting.py +++ b/src/onnx_ir/_type_casting.py @@ -48,3 +48,42 @@ def unpack_4bitx2(data: npt.NDArray[np.uint8], dims: Sequence[int]) -> npt.NDArr result = result[:-1] result.resize(dims, refcheck=False) return result + + +def pack_2bitx4(array: np.ndarray) -> npt.NDArray[np.uint8]: + """Convert a numpy array to a flattened, packed int2/uint2 array. Elements must be in the correct range.""" + # Create a 1D copy + array_flat = array.ravel().view(np.uint8).copy() + size = array.size + padding = (4 - (size % 4)) % 4 + if padding > 0: + array_flat.resize([size + padding], refcheck=False) + array_flat &= 0x03 + array_flat[1::4] <<= 2 + array_flat[2::4] <<= 4 + array_flat[3::4] <<= 6 + return array_flat[0::4] | array_flat[1::4] | array_flat[2::4] | array_flat[3::4] # type: ignore[return-type] + + +def unpack_2bitx4(data: npt.NDArray[np.uint8], dims: Sequence[int]) -> npt.NDArray[np.uint8]: + """Convert a packed int2/uint2 array to an unpacked array with one 2-bit value per uint8 element. + + Args: + data: A uint8 numpy array holding four 2-bit values per byte, lowest bits first. + dims: The dimensions are used to reshape the unpacked buffer. + + Returns: + A uint8 numpy array reshaped to dims. 
+ """ + assert data.dtype == np.uint8, "Input data must be of type uint8" + result = np.empty([data.size * 4], dtype=data.dtype) + result[0::4] = data & np.uint8(0x03) + result[1::4] = (data & np.uint8(0x0C)) >> np.uint8(2) + result[2::4] = (data & np.uint8(0x30)) >> np.uint8(4) + result[3::4] = (data & np.uint8(0xC0)) >> np.uint8(6) + total_elements = int(np.prod(dims)) + if result.size > total_elements: + # handle padding due to element count not being a multiple of 4 + result = result[:total_elements] + result.resize(dims, refcheck=False) + return result diff --git a/src/onnx_ir/serde.py b/src/onnx_ir/serde.py index 82a6eda3..6618620e 100644 --- a/src/onnx_ir/serde.py +++ b/src/onnx_ir/serde.py @@ -390,6 +390,10 @@ def numpy(self) -> np.ndarray: return _type_casting.unpack_4bitx2( np.frombuffer(self._proto.raw_data, dtype=np.uint8), shape ).view(dtype.numpy()) + if dtype.bitwidth == 2: + return _type_casting.unpack_2bitx4( + np.frombuffer(self._proto.raw_data, dtype=np.uint8), shape + ).view(dtype.numpy()) return np.frombuffer( self._proto.raw_data, dtype=dtype.numpy().newbyteorder("<") ).reshape(shape) @@ -408,9 +412,11 @@ def numpy(self) -> np.ndarray: _enums.DataType.FLOAT8E8M0, _enums.DataType.INT16, _enums.DataType.INT32, + _enums.DataType.INT2, _enums.DataType.INT4, _enums.DataType.INT8, _enums.DataType.UINT16, + _enums.DataType.UINT2, _enums.DataType.UINT4, _enums.DataType.UINT8, }, f"Unsupported dtype {dtype} for int32_data" @@ -426,6 +432,10 @@ def numpy(self) -> np.ndarray: return _type_casting.unpack_4bitx2(array.astype(np.uint8), shape).view( dtype.numpy() ) + if dtype.bitwidth == 2: + return _type_casting.unpack_2bitx4(array.astype(np.uint8), shape).view( + dtype.numpy() + ) raise ValueError( f"Unsupported dtype {dtype} for int32_data with bitwidth {dtype.bitwidth}" ) @@ -507,11 +517,13 @@ def tobytes(self) -> bytes: _enums.DataType.FLOAT8E5M2, _enums.DataType.FLOAT8E5M2FNUZ, _enums.DataType.FLOAT8E8M0, + _enums.DataType.INT2, _enums.DataType.INT4, 
+ _enums.DataType.UINT2, _enums.DataType.UINT4, _enums.DataType.FLOAT4E2M1, }: - # uint4 and int4 values are already packed, even when stored as int32 + # uint2, uint4, int2 and int4 values are already packed, even when stored as int32 # so we don't need to pack them again return array.astype(_little_endian_dtype(np.uint8)).tobytes() assert self.dtype == _enums.DataType.INT32 diff --git a/src/onnx_ir/serde_test.py b/src/onnx_ir/serde_test.py index 1b78ed80..78ebf329 100644 --- a/src/onnx_ir/serde_test.py +++ b/src/onnx_ir/serde_test.py @@ -282,9 +282,21 @@ def test_tensor_proto_tensor_float8(self, _: str, dtype: int, np_dtype): ("INT32", onnx.TensorProto.INT32), ("INT64", onnx.TensorProto.INT64), ("INT4", onnx.TensorProto.INT4), + ("INT2", 25), # INT2 value ] ) def test_tensor_proto_tensor_int(self, _: str, dtype: int): + # INT2 is not yet supported in ONNX numpy_helper, so we handle it specially + if dtype == 25: # INT2 + # Create tensor proto manually since ONNX helper might not support this type yet + data_array = np.array([[-1, 0, 1]], dtype=ml_dtypes.int2) + # Create an IR tensor which will pack the data correctly + ir_tensor = ir.Tensor(data_array) + tensor_proto = serde.to_proto(ir_tensor) + tensor = serde.TensorProtoTensor(tensor_proto) + np.testing.assert_array_equal(tensor.numpy().view(ml_dtypes.int2), data_array) + return # Skip remaining tests for INT2 as ONNX doesn't support it yet + tensor_proto = onnx.helper.make_tensor("test_tensor", dtype, [1, 4], [-1, 0, 1, 8]) tensor = serde.TensorProtoTensor(tensor_proto) expected_array = onnx.numpy_helper.to_array( @@ -311,9 +323,21 @@ def test_tensor_proto_tensor_int(self, _: str, dtype: int): ("UINT32", onnx.TensorProto.UINT32), ("UINT64", onnx.TensorProto.UINT64), ("UINT4", onnx.TensorProto.UINT4), + ("UINT2", 26), # UINT2 value ] ) def test_tensor_proto_tensor_uint(self, _: str, dtype: int): + # UINT2 is not yet supported in ONNX numpy_helper, so we handle it specially + if dtype == 26: # UINT2 + # Create 
tensor proto manually since ONNX helper might not support this type yet + data_array = np.array([[0, 1, 2, 3]], dtype=ml_dtypes.uint2) + # Create an IR tensor which will pack the data correctly + ir_tensor = ir.Tensor(data_array) + tensor_proto = serde.to_proto(ir_tensor) + tensor = serde.TensorProtoTensor(tensor_proto) + np.testing.assert_array_equal(tensor.numpy().view(ml_dtypes.uint2), data_array) + return # Skip remaining tests for UINT2 as ONNX doesn't support it yet + tensor_proto = onnx.helper.make_tensor("test_tensor", dtype, [1, 3], [0, 1, 8]) tensor = serde.TensorProtoTensor(tensor_proto) expected_array = onnx.numpy_helper.to_array(tensor_proto) @@ -396,7 +420,9 @@ def test_tensor_proto_tensor_empty_tensor(self): ("FLOAT8E5M2", ir.DataType.FLOAT8E5M2), ("FLOAT8E5M2FNUZ", ir.DataType.FLOAT8E5M2FNUZ), ("FLOAT8E8M0", ir.DataType.FLOAT8E8M0), + ("UINT2", ir.DataType.UINT2), ("UINT4", ir.DataType.UINT4), + ("INT2", ir.DataType.INT2), ("INT4", ir.DataType.INT4), ("FLOAT4E2M1", ir.DataType.FLOAT4E2M1), ], diff --git a/src/onnx_ir/tensor_adapters.py b/src/onnx_ir/tensor_adapters.py index 8be17108..ed1cea25 100644 --- a/src/onnx_ir/tensor_adapters.py +++ b/src/onnx_ir/tensor_adapters.py @@ -80,6 +80,10 @@ def from_torch_dtype(dtype: torch.dtype) -> ir.DataType: if hasattr(torch, "float8_e8m0fnu"): # torch.float8_e8m0fnu is available in PyTorch 2.7+ _TORCH_DTYPE_TO_ONNX[torch.float8_e8m0fnu] = ir.DataType.FLOAT8E8M0 + if hasattr(torch, "int2"): + _TORCH_DTYPE_TO_ONNX[torch.int2] = ir.DataType.INT2 + if hasattr(torch, "uint2"): + _TORCH_DTYPE_TO_ONNX[torch.uint2] = ir.DataType.UINT2 if dtype not in _TORCH_DTYPE_TO_ONNX: raise TypeError( @@ -121,6 +125,10 @@ def to_torch_dtype(dtype: ir.DataType) -> torch.dtype: if hasattr(torch, "float8_e8m0fnu"): # torch.float8_e8m0fnu is available in PyTorch 2.7+ _ONNX_DTYPE_TO_TORCH[ir.DataType.FLOAT8E8M0] = torch.float8_e8m0fnu + if hasattr(torch, "int2"): + _ONNX_DTYPE_TO_TORCH[ir.DataType.INT2] = torch.int2 + if hasattr(torch, 
"uint2"): + _ONNX_DTYPE_TO_TORCH[ir.DataType.UINT2] = torch.uint2 if dtype not in _ONNX_DTYPE_TO_TORCH: if dtype == ir.DataType.FLOAT8E8M0: