From 3322d21eaca527a00cd023a43a4a1a72d4c67355 Mon Sep 17 00:00:00 2001 From: Yu Li Date: Tue, 18 Nov 2025 17:24:23 -0600 Subject: [PATCH] add block-wise scaled int8 quantization based on QuantizedLayout mechanism add more tests by comparing with manual torch implementation add perf benchmarks fix errors caused by merging default no output quant fix unittest --- QUANTIZATION.md | 14 +- comfy/float.py | 2 + comfy/int8_kernels.py | 1194 ++++++++ comfy/ops.py | 8 +- comfy/quant_ops.py | 979 ++++++- comfy/weight_adapter/boft.py | 3 +- comfy/weight_adapter/glora.py | 3 +- comfy/weight_adapter/loha.py | 4 +- comfy/weight_adapter/lokr.py | 3 +- comfy/weight_adapter/lora.py | 3 +- comfy/weight_adapter/oft.py | 3 +- tests-unit/comfy_quant/test_quant_registry.py | 2523 ++++++++++++++++- 12 files changed, 4703 insertions(+), 36 deletions(-) create mode 100644 comfy/int8_kernels.py diff --git a/QUANTIZATION.md b/QUANTIZATION.md index 1693e13f32e2..177651a20235 100644 --- a/QUANTIZATION.md +++ b/QUANTIZATION.md @@ -124,6 +124,10 @@ We define 4 possible scaling parameters that should cover most recipes in the ne | Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale | |--------|---------------|--------------|----------------|-----------------|-------------| | float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) | +| int8_blockwise | int8 | float32 (per-block) | - | - | - | + +For int8_blockwise with block_size=128 and weight shape (N, K): +- weight_scale shape: (N//128, K//128) You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS). @@ -131,7 +135,9 @@ You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS). The metadata stored alongside the checkpoint contains: - **format_version**: String to define a version of the standard -- **layers**: A dictionary mapping layer names to their quantization format. The format string maps to the definitions found in `QUANT_ALGOS`. +- **layers**: A dictionary mapping layer names to their quantization configuration. Each layer's config is a dictionary with: + - **format**: Quantization format string that maps to the definitions found in `QUANT_ALGOS` + - **group_size** (optional): Block size for block-wise quantization schemes (e.g., int8_blockwise) Example: ```json @@ -139,9 +145,9 @@ Example: "_quantization_metadata": { "format_version": "1.0", "layers": { - "model.layers.0.mlp.up_proj": "float8_e4m3fn", - "model.layers.0.mlp.down_proj": "float8_e4m3fn", - "model.layers.1.mlp.up_proj": "float8_e4m3fn" + "model.layers.0.mlp.up_proj": {"format": "float8_e4m3fn"}, + "model.layers.0.mlp.down_proj": {"format": "int8_blockwise", "group_size": 128}, + "model.layers.1.mlp.up_proj": {"format": "int8_blockwise", "group_size": 256} } } } diff --git a/comfy/float.py b/comfy/float.py index 521316fd2fac..3caa22909d45 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -54,6 +54,8 @@ def stochastic_rounding(value, dtype, seed=0): return value.to(dtype=torch.float16) if dtype == torch.bfloat16: return value.to(dtype=torch.bfloat16) + if dtype == torch.int8: + return value.to(dtype=torch.int8) if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2: generator = torch.Generator(device=value.device) generator.manual_seed(seed) diff --git a/comfy/int8_kernels.py b/comfy/int8_kernels.py new file mode 100644 index 000000000000..8573bd5de4ee --- /dev/null +++ b/comfy/int8_kernels.py @@ -0,0 +1,1194 @@ +import torch +import triton +import triton.language as tl +from triton import Config +from typing import Tuple + + +""" +simplified explanation of the scaled int8 matmul algorithm +adopted from deepseek scaled FP8 matmul and jetfire paper +https://arxiv.org/abs/2403.12422 +https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py + + N dimension → + INT8 weights scaler per block + ┌-----┬-----┬─────┬─────┐ ┌-----┬-----┬─────┬─────┐ + : b00 : b01 : b02 | b03 | : : : | | + ├-----┼-----┼─────┼─────┤ :b_s00:b_s10:b_s20|b_s30| + K : b10 : b11 : b12 | b13 | : : : | | + dim ├-----┼-----┼─────┼─────┤ ├-----┼-----┼─────┼─────┤ + ↓ | b20 | b21 | b22 | b23 | | | | | | + ├─────┼─────┼─────┼─────┤ |b_s01|b_s11|b_s21|b_s31| + | b30 | b31 | b32 | b33 | | | | | | + └─────┴─────┴─────┴─────┘ └─────┴─────┴─────┴─────┘ + ┌-----┬-----┐ + : b00 : b01 : + ├─── blk ───┤ ├-----┼-----┤ + : b10 : b11 : + K dimension → └-----┴-----┘ + INT8 activations + ┌-----┬-----┬─────┬─────┐ ┌-----┬-----┐ ┌-----┬-----┐ ┌-----------┐ ┌-----┬-----┐ ┌-----┬-----┐ + : a00 : a01 : a02 | a03 | : a00 : a01 : : @ : @ : : a_s00 : : : : :acc00:acc01: + ├-----┼-----┼─────┼─────┤ ├-----┼-----┤ ├-----┼-----┤ * ├-----------┤ * :b_s00:b_s10: = ├-----┼-----┤ + M : a10 : a11 : a12 | a13 | : a10 : a11 : : @ : @ : : a_s10 : : : : :acc10:acc11: +dim ├-----┼-----┼─────┼─────┤ └-----┴-----┘ └-----┴-----┘ └-----------┘ └-----┴-----┘ └-----┴-----┘ + ↓ | a20 | a21 | a22 | a23 | INT8 matmul acc in INT32 rescale the FP32 intermediate accumulate + ├─────┼─────┼─────┼─────┤ then cast to FP32 "rank 1" hadamard scaler intermediate + | a30 | a31 | a32 | a33 | + └─────┴─────┴─────┴─────┘ + scaler per block + ┌-----------┬───────────┐ + : a_s00 : a_s01 | + ├-----------┼───────────┤ + : a_s10 : a_s11 | + ├-----------┼───────────┤ + | a_s20 | a_s21 | + ├───────────┼───────────┤ + | a_s30 | a_s31 | + └───────────┴───────────┘ +""" + + +@triton.jit +def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): + """ + Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`. + + Args: + x_ptr (triton.Pointer): Pointer to the input tensor. + y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored. + s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored. + BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance. + + Returns: + None + """ + pid = tl.program_id(axis=0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + offs).to(tl.float32) + amax = tl.max(tl.abs(x)) # reduction + # amax = tl.maximum(amax, 1e-4) # clamp to 1e-4 + s = amax / 127.0 + y = x / s + y = y.to(y_ptr.dtype.element_ty) + tl.store(y_ptr + offs, y) + tl.store(s_ptr + pid, s) + + +def act_quant( + x: torch.Tensor, block_size: int = 128 +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantizes the input tensor `x` using block-wise quantization. + + Args: + x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`. + block_size (int, optional): The size of the blocks to be used for quantization. Default is 128. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The quantized tensor with dtype `torch.int8`. + - A tensor of scaling factors with dtype `torch.float32`. + """ + assert x.is_contiguous(), "Input tensor must be contiguous" + assert ( + x.size(-1) % block_size == 0 + ), f"Last dimension size must be divisible by block_size (block_size={block_size})" + y = torch.empty_like(x, dtype=torch.int8) + s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32) + # Grid size should match number of scale elements (one program per block) + # Each program processes block_size elements and writes one scale value + num_programs = s.numel() # Number of blocks = number of scale elements + grid = lambda meta: (num_programs,) + act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size) + return y, s + + +@triton.jit +def act_dequant_kernel(x_ptr, s_ptr, y_ptr, BLOCK_SIZE: tl.constexpr): + """ + Dequantizes the input tensor `x_ptr` using scaling factors from `s_ptr`. + + Args: + x_ptr (triton.Pointer): Pointer to the quantized input tensor. + s_ptr (triton.Pointer): Pointer to the scaling factors. + y_ptr (triton.Pointer): Pointer to the output tensor where dequantized values will be stored. + BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance. + + Returns: + None + """ + pid = tl.program_id(axis=0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + offs).to(tl.float32) + s = tl.load(s_ptr + pid) + y = x * s + y = y.to(y_ptr.dtype.element_ty) + tl.store(y_ptr + offs, y) + + +def act_dequant( + x: torch.Tensor, s: torch.Tensor, block_size: int = 128, output_dtype: torch.dtype = None +) -> torch.Tensor: + """ + Dequantizes the activation tensor `x` using the provided scale tensor. + + Args: + x (torch.Tensor): The quantized activation tensor. Must be contiguous and its last dimension size must be divisible by `block_size`. + s (torch.Tensor): The scale tensor with shape (*batch_dims, last_dim // block_size). + block_size (int, optional): The block size to use for dequantization. Defaults to 128. + output_dtype (torch.dtype, optional): Target dtype for output. Defaults to torch.get_default_dtype(). + + Returns: + torch.Tensor: The dequantized activation tensor of the same shape as `x`. + """ + assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous" + assert ( + x.size(-1) % block_size == 0 + ), f"Last dimension size must be divisible by block_size (block_size={block_size})" + + if output_dtype is None: + output_dtype = torch.get_default_dtype() + + y = torch.empty_like(x, dtype=output_dtype) + # Grid size should match number of scale elements (one program per block) + num_programs = s.numel() # Number of blocks = number of scale elements + grid = lambda meta: (num_programs,) + act_dequant_kernel[grid](x, s, y, BLOCK_SIZE=block_size) + return y + + +@triton.jit +def weight_quant_kernel(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr): + """ + Quantizes weights using block-wise quantization. + + Args: + x_ptr (tl.pointer): Pointer to the input weights. + y_ptr (tl.pointer): Pointer to the output buffer for quantized weights. + s_ptr (tl.pointer): Pointer to the output buffer for scaling factors. + M (int): Number of rows in the weight matrix. + N (int): Number of columns in the weight matrix. + BLOCK_SIZE (tl.constexpr): Size of the block for quantization. + + Returns: + None + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + n = tl.cdiv(N, BLOCK_SIZE) + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs = offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + x = tl.load(x_ptr + offs, mask=mask, other=0.0).to(tl.float32) + + # Compute per-block absolute maximum + amax = tl.max(tl.abs(x)) + s = amax / 127.0 + #s = tl.maximum(s, 1e-8) # Prevent division by zero + + # Quantize + y = x / s + #y = tl.maximum(tl.minimum(y, 127.0), -127.0) # Clamp + y = y.to(y_ptr.dtype.element_ty) + + tl.store(y_ptr + offs, y, mask=mask) + tl.store(s_ptr + pid_m * n + pid_n, s) + + +def weight_quant( + x: torch.Tensor, block_size: int = 128 +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantizes the weight tensor using block-wise quantization. + + Args: + x (torch.Tensor): The weight tensor of shape (M, N). + block_size (int, optional): The block size to use for quantization. Defaults to 128. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The quantized tensor with dtype `torch.int8`. + - A tensor of scaling factors with shape (M//block_size, N//block_size) and dtype `torch.float32`. + + Raises: + AssertionError: If `x` is not contiguous or if its dimensions are not 2. + """ + assert x.is_contiguous(), "Input tensor must be contiguous" + assert x.dim() == 2, "Input tensor must have 2 dimensions" + M, N = x.size() + assert M % block_size == 0 and N % block_size == 0, \ + f"Dimensions must be divisible by block_size={block_size}, got shape {x.shape}" + + y = torch.empty_like(x, dtype=torch.int8) + s = x.new_empty(M // block_size, N // block_size, dtype=torch.float32) + + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_SIZE"]), + triton.cdiv(N, meta["BLOCK_SIZE"]), + ) + weight_quant_kernel[grid](x, y, s, M, N, BLOCK_SIZE=block_size) + return y, s + + +@triton.jit +def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): + """ + Dequantizes weights using the provided scaling factors and stores the result. + + Args: + x_ptr (tl.pointer): Pointer to the quantized weights. + s_ptr (tl.pointer): Pointer to the scaling factors. + y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights. + M (int): Number of rows in the weight matrix. + N (int): Number of columns in the weight matrix. + BLOCK_SIZE (tl.constexpr): Size of the block for tiling. + + Returns: + None + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + n = tl.cdiv(N, BLOCK_SIZE) + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs = offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) + s = tl.load(s_ptr + pid_m * n + pid_n) + y = x * s + tl.store(y_ptr + offs, y, mask=mask) + + +def weight_dequant( + x: torch.Tensor, s: torch.Tensor, block_size: int = 128, output_dtype: torch.dtype = None +) -> torch.Tensor: + """ + Dequantizes the given weight tensor using the provided scale tensor. + + Args: + x (torch.Tensor): The quantized weight tensor of shape (M, N). + s (torch.Tensor): The scale tensor of shape (M//block_size, N//block_size). + block_size (int, optional): The block size to use for dequantization. Defaults to 128. + output_dtype (torch.dtype, optional): Target dtype for output. Defaults to torch.get_default_dtype(). + + Returns: + torch.Tensor: The dequantized weight tensor of the same shape as `x`. + + Raises: + AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2. + """ + assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous" + assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions" + M, N = x.size() + + if output_dtype is None: + output_dtype = torch.get_default_dtype() + + y = torch.empty_like(x, dtype=output_dtype) + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_SIZE"]), + triton.cdiv(N, meta["BLOCK_SIZE"]), + ) + weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) + return y + + +# matmul intermediate block size is hardcoded to 128 +int8_gemm_configs = [ + Config( + {"BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n, "BLOCK_SIZE_K": 128}, + num_stages=num_stages, + num_warps=8, + ) + for block_m in [128, 256] # >= 128 for consistency with out_block_size + for block_n in [128, 256] # >= 128 required for out_block_size compatibility + for num_stages in [3, 4, 5] +] + + +#@triton.autotune(configs=int8_gemm_configs, key=["N", "K"]) +@triton.jit +def int8_gemm_kernel( + a_ptr, + b_ptr, + c_ptr, + a_s_ptr, + b_s_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + """ + Performs a matrix multiplication operation on INT8 matrices with scaling factors. + + Args: + a_ptr (tl.tensor): Pointer to the first input matrix A. + b_ptr (tl.tensor): Pointer to the second input matrix B. + c_ptr (tl.tensor): Pointer to the output matrix C. + a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A. + b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B. + M (int): Number of rows in matrix A and C. + N (tl.constexpr): Number of columns in matrix B and C. + K (tl.constexpr): Number of columns in matrix A and rows in matrix B. + BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension. + BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension. + BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension. + + Returns: + None + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + k = tl.cdiv(K, BLOCK_SIZE_K) + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] + b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] + a_s_ptrs = a_s_ptr + offs_m * k + + # FIXED: Weight scale indexing for 2D scale array (N_blocks, K_blocks) + # b_s has shape (N//BLOCK_SIZE_K, K//BLOCK_SIZE_K) stored in row-major + # For N tile pid_n, we need scales[pid_n, :] across K iterations + # Address calculation: scale[pid_n, i] = base + pid_n * stride + i + k_blocks = k # Number of K blocks for clarity + b_s_base = b_s_ptr + pid_n * k_blocks + + # Create accumulators outside the loop for better performance + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for i in range(k_blocks): + # Load int8 data - use other=0 (int) not 0.0 (float) to preserve int8 type + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0) + a_s = tl.load(a_s_ptrs) + # FIXED: Load single scalar weight scale for (pid_n, i) block pair + b_s = tl.load(b_s_base + i) + # INT8 matmul → INT32 acc, then cast to FP32 and apply per-block scaling + dot_prod = tl.dot(a, b, out_dtype=tl.int32) # int8 × int8 → int32 + accumulator += dot_prod.to(tl.float32) * a_s[:, None] * b_s + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + a_s_ptrs += 1 + c = accumulator.to(c_ptr.dtype.element_ty) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + tl.store(c_ptrs, c, mask=mask) + + +#@triton.autotune(configs=int8_gemm_configs, key=["N", "K"]) +@triton.jit +def int8_gemm_addmm_kernel( + a_ptr, + b_ptr, + c_ptr, + bias_ptr, + a_s_ptr, + b_s_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + HAS_BIAS: tl.constexpr, +): + """ + Fused INT8 matrix multiplication with bias addition (addmm). + Computes: C = A @ B + bias + + This kernel fuses the bias addition into the matmul, avoiding an extra memory write/read cycle. + + Args: + a_ptr (tl.tensor): Pointer to the first input matrix A (INT8). + b_ptr (tl.tensor): Pointer to the second input matrix B (INT8). + c_ptr (tl.tensor): Pointer to the output matrix C. + bias_ptr (tl.tensor): Pointer to the bias vector (1D, length N). + a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A. + b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B. + M (int): Number of rows in matrix A and C. + N (tl.constexpr): Number of columns in matrix B and C. + K (tl.constexpr): Number of columns in matrix A and rows in matrix B. + BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension. + BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension. + BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension. + HAS_BIAS (tl.constexpr): Whether bias is provided. + + Returns: + None + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + k = tl.cdiv(K, BLOCK_SIZE_K) + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] + b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] + a_s_ptrs = a_s_ptr + offs_m * k + + # FIXED: Weight scale indexing for 2D scale array (N_blocks, K_blocks) + # b_s has shape (N//BLOCK_SIZE_K, K//BLOCK_SIZE_K) stored in row-major + # For N tile pid_n, we need scales[pid_n, :] across K iterations + # Address calculation: scale[pid_n, i] = base + pid_n * stride + i + k_blocks = k # Number of K blocks for clarity + b_s_base = b_s_ptr + pid_n * k_blocks + + # Accumulate matmul result + # Create accumulators outside the loop for better performance + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for i in range(k_blocks): + # Load int8 data - use other=0 (int) not 0.0 (float) to preserve int8 type + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0) + a_s = tl.load(a_s_ptrs) + # FIXED: Load single scalar weight scale for (pid_n, i) block pair + b_s = tl.load(b_s_base + i) + # INT8 matmul → INT32 acc, then cast to FP32 and apply per-block scaling + dot_prod = tl.dot(a, b, out_dtype=tl.int32) # int8 × int8 → int32 + accumulator += dot_prod.to(tl.float32) * a_s[:, None] * b_s + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + a_s_ptrs += 1 + + # Add bias if provided (fused operation) + if HAS_BIAS: + bias_ptrs = bias_ptr + offs_n[None, :] + bias = tl.load(bias_ptrs, mask=offs_n[None, :] < N, other=0.0) + accumulator += bias # Broadcast bias across M dimension + + # Store result + c = accumulator.to(c_ptr.dtype.element_ty) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + tl.store(c_ptrs, c, mask=mask) + + +def int8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor): + """ + Perform a matrix multiplication using INT8 precision. + + Expected tensor shapes: + - a: [..., K] where ... can be any batch dimensions + - b: [N, K] (weight matrix in standard format, kernel transposes internally) + - a_s: [..., K//block_size] + - b_s: [N//block_size, K//block_size] + + Args: + a (torch.Tensor): The first input matrix, must be contiguous. + a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous. + b (torch.Tensor): The second input matrix [N, K], must be contiguous. + b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous. + + Returns: + torch.Tensor: The result of the matrix multiplication. + """ + assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous" + assert ( + a_s.is_contiguous() and b_s.is_contiguous() + ), "Scaling factor tensors must be contiguous" + assert b.dim() == 2, f"Expected b to be 2D, got shape {b.shape}" + + K = a.size(-1) + M = a.numel() // K + # b has shape [N, K], extract N from first dimension + N = b.shape[0] + + # Validate shapes + assert b.size(1) == K, f"Shape mismatch: b.shape={b.shape}, expected [..., {K}]" + + # Output tensor (same batch shape as input, last dim = N) + # let's use float16 as output dtype + c = a.new_empty(*a.size()[:-1], N, dtype=torch.float16) + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]), + triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + int8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K, BLOCK_SIZE_M=128, BLOCK_SIZE_N=128, BLOCK_SIZE_K=128) + return c + + +def int8_addmm( + a: torch.Tensor, + a_s: torch.Tensor, + b: torch.Tensor, + b_s: torch.Tensor, + bias: torch.Tensor = None +): + """ + Fused INT8 matrix multiplication with bias addition (addmm). + Computes: output = (a @ b) + bias + + Expected tensor shapes: + - a: [..., K] where ... can be any batch dimensions + - b: [N, K] (weight matrix in standard format, kernel transposes internally) + - a_s: [..., K//block_size] + - b_s: [N//block_size, K//block_size] + - bias: [N] (optional) + + This is more efficient than separate matmul + bias add operations as it: + 1. Avoids an extra memory write/read cycle + 2. Fuses the bias addition into the matmul kernel + 3. Better utilizes GPU memory bandwidth + + Args: + a (torch.Tensor): The first input matrix (INT8), must be contiguous. + a_s (torch.Tensor): The scaling factors for the first input matrix, must be contiguous. + b (torch.Tensor): The second input matrix (INT8) [N, K], must be contiguous. + b_s (torch.Tensor): The scaling factors for the second input matrix, must be contiguous. + bias (torch.Tensor, optional): The bias vector (1D, length N). If None, only matmul is performed. + + Returns: + torch.Tensor: The result of the fused matrix multiplication and bias addition. + + Example: + >>> a_int8, a_scale = act_quant(input_tensor, block_size=128) + >>> b_int8, b_scale = weight_quant(weight_tensor, block_size=128) + >>> bias = torch.randn(output_features) + >>> output = int8_addmm(a_int8, a_scale, b_int8, b_scale, bias) + """ + assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous" + assert ( + a_s.is_contiguous() and b_s.is_contiguous() + ), "Scaling factor tensors must be contiguous" + assert b.dim() == 2, f"Expected b to be 2D, got shape {b.shape}" + + K = a.size(-1) + M = a.numel() // K + # b has shape [N, K], extract N from first dimension + N = b.shape[0] + + # Validate shapes + assert b.size(1) == K, f"Shape mismatch: b.shape={b.shape}, expected [..., {K}]" + + # Output tensor (same batch shape as input, last dim = N) + # let's use float16 as output dtype + c = a.new_empty(*a.size()[:-1], N, dtype=torch.float16) + + # Handle bias + has_bias = bias is not None + if has_bias: + assert bias.is_contiguous(), "Bias tensor must be contiguous" + assert bias.dim() == 1 and bias.size(0) == N, \ + f"Bias must be 1D with length {N}, got shape {bias.shape}" + bias_ptr = bias + else: + # Create a dummy pointer (won't be used due to HAS_BIAS=False) + bias_ptr = c + + # Launch kernel + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]), + triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + int8_gemm_addmm_kernel[grid]( + a, b, c, bias_ptr, a_s, b_s, M, N, K, HAS_BIAS=has_bias, BLOCK_SIZE_M=128, BLOCK_SIZE_N=128, BLOCK_SIZE_K=128 + ) + return c + + +# ============================================================================== +# Fused INT8 GEMM + Quantization Kernels +# ============================================================================== +# +# Architecture Overview: +# ---------------------- +# 1. Kernels compute matmul and quantize PER-ROW for activation format +# - Each row gets its own scale for the N-range of the tile +# - Kernel output: c_scale shape is (M, N/BLOCK_SIZE_N) +# - BLOCK_SIZE_M, BLOCK_SIZE_N are tile sizes from autotuner (e.g., 16-64, 32-128) +# - This matches activation quantization: per-row, block-wise along N +# +# 2. Wrapper functions convert to final activation format +# - Kernel output: (M, N/BLOCK_SIZE_N) +# - Target format: (*batch_dims, N/out_block_size) +# - If BLOCK_SIZE_N == out_block_size: already correct, just reshape +# - If BLOCK_SIZE_N != out_block_size: replicate or merge scales +# +# 3. Benefits: +# - Accurate: per-row scales match activation quantization format +# - Efficient: single max reduction per row per tile +# - Compatible: direct output in activation format +# - Better precision: each row has independent scales +# +# ============================================================================== + +#@triton.autotune(configs=int8_gemm_configs, key=["N", "K"]) +@triton.heuristics({ + 'NUM_BLOCKS': lambda args: args["BLOCK_SIZE_N"] // args["out_block_size"], +}) +@triton.jit +def int8_gemm_quant_kernel( + a_ptr, + b_ptr, + c_ptr, + c_s_ptr, + a_s_ptr, + b_s_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + out_block_size: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + NUM_BLOCKS: tl.constexpr, +): + """ + Fused INT8 matrix multiplication with output quantization. + Computes: C_int8, C_scale = quantize(A @ B) + + This kernel fuses matmul and block-wise quantization in a single pass. + Quantizes at out_block_size granularity (like act_quant_kernel). + + Args: + a_ptr: Pointer to INT8 activations + b_ptr: Pointer to INT8 weights + c_ptr: Pointer to INT8 output + c_s_ptr: Pointer to output scales (shape: M x N/out_block_size) + a_s_ptr: Pointer to activation scales + b_s_ptr: Pointer to weight scales + M: Number of rows in A and C + N: Number of columns in B and C + K: Inner dimension (columns in A, rows in B) + out_block_size: Block size for output quantization + BLOCK_SIZE_M/N/K: Tile sizes for matmul + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + k = tl.cdiv(K, BLOCK_SIZE_K) + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] + b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] + a_s_ptrs = a_s_ptr + offs_m * k + + # FIXED: Weight scale indexing for 2D scale array (N_blocks, K_blocks) + # b_s has shape (N//BLOCK_SIZE_K, K//BLOCK_SIZE_K) stored in row-major + # For N tile pid_n, we need scales[pid_n, :] across K iterations + k_blocks = k # Number of K blocks for clarity + b_s_base = b_s_ptr + pid_n * k_blocks + + # Accumulate matmul result + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for i in range(k_blocks): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0) + a_s = tl.load(a_s_ptrs) + # FIXED: Load single scalar weight scale for (pid_n, i) block pair + b_s = tl.load(b_s_base + i) + dot_prod = tl.dot(a, b, out_dtype=tl.int32) + accumulator += dot_prod.to(tl.float32) * a_s[:, None] * b_s + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + a_s_ptrs += 1 + + # Quantize in activation format: per-row, block-wise at out_block_size granularity + # Reshape accumulator to separate blocks: (BLOCK_SIZE_M, BLOCK_SIZE_N) -> (BLOCK_SIZE_M, NUM_BLOCKS, out_block_size) + accumulator_reshaped = tl.reshape(accumulator, (BLOCK_SIZE_M, NUM_BLOCKS, out_block_size)) + + # Compute max per block: reduce over out_block_size dimension + # Shape: (BLOCK_SIZE_M, NUM_BLOCKS) + block_max = tl.max(tl.abs(accumulator_reshaped), axis=2) + block_scale = tl.maximum(block_max / 127.0, 1e-8) + + # Reshape scales for broadcasting: (BLOCK_SIZE_M, NUM_BLOCKS) -> (BLOCK_SIZE_M, NUM_BLOCKS, 1) + block_scale_broadcast = tl.reshape(block_scale, (BLOCK_SIZE_M, NUM_BLOCKS, 1)) + + # Quantize: accumulator -> int8 + quantized = accumulator_reshaped / block_scale_broadcast + quantized = tl.maximum(tl.minimum(quantized, 127.0), -127.0) + quantized_int8 = quantized.to(c_ptr.dtype.element_ty) + + # Reshape back to 2D: (BLOCK_SIZE_M, NUM_BLOCKS, out_block_size) -> (BLOCK_SIZE_M, BLOCK_SIZE_N) + quantized_int8 = tl.reshape(quantized_int8, (BLOCK_SIZE_M, BLOCK_SIZE_N)) + + # Store quantized output + offs_m_actual = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n_actual = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask = (offs_m_actual[:, None] < M) & (offs_n_actual[None, :] < N) + c_ptrs = c_ptr + offs_m_actual[:, None] * N + offs_n_actual[None, :] + tl.store(c_ptrs, quantized_int8, mask=mask) + + # Store scales: (BLOCK_SIZE_M, NUM_BLOCKS) scales for this tile + # Scale layout: (M, N//out_block_size) - matches activation format directly! + # This tile covers M range [pid_m*BLOCK_SIZE_M : (pid_m+1)*BLOCK_SIZE_M] + # N range [pid_n*BLOCK_SIZE_N : (pid_n+1)*BLOCK_SIZE_N] + # N block indices: [pid_n * NUM_BLOCKS : (pid_n+1) * NUM_BLOCKS] + n_scale_stride = N // out_block_size # Total number of N blocks + offs_m_scale = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n_scale = pid_n * NUM_BLOCKS + tl.arange(0, NUM_BLOCKS) + scale_ptrs = c_s_ptr + offs_m_scale[:, None] * n_scale_stride + offs_n_scale[None, :] + scale_mask = (offs_m_scale[:, None] < M) & (offs_n_scale[None, :] < n_scale_stride) + tl.store(scale_ptrs, block_scale, mask=scale_mask) + + +#@triton.autotune(configs=int8_gemm_configs, key=["N", "K"]) +@triton.heuristics({ + 'NUM_BLOCKS': lambda args: args["BLOCK_SIZE_N"] // args["out_block_size"], +}) +@triton.jit +def int8_gemm_addmm_quant_kernel( + a_ptr, + b_ptr, + c_ptr, + c_s_ptr, + bias_ptr, + a_s_ptr, + b_s_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + out_block_size: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + NUM_BLOCKS: tl.constexpr, + HAS_BIAS: tl.constexpr, +): + """ + Fused INT8 matrix multiplication with bias addition and output quantization. + Computes: C_int8, C_scale = quantize(A @ B + bias) + + This kernel fuses matmul, bias addition, and block-wise quantization. + Quantizes at out_block_size granularity (like act_quant_kernel). + + Args: + a_ptr: Pointer to INT8 activations + b_ptr: Pointer to INT8 weights + c_ptr: Pointer to INT8 output + c_s_ptr: Pointer to output scales (shape: M x N/out_block_size) + bias_ptr: Pointer to bias vector + a_s_ptr: Pointer to activation scales + b_s_ptr: Pointer to weight scales + M: Number of rows in A and C + N: Number of columns in B and C + K: Inner dimension + out_block_size: Block size for output quantization + BLOCK_SIZE_M/N/K: Tile sizes for matmul + HAS_BIAS: Whether bias is provided + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + k = tl.cdiv(K, BLOCK_SIZE_K) + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] + b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] + a_s_ptrs = a_s_ptr + offs_m * k + + # FIXED: Weight scale indexing for 2D scale array (N_blocks, K_blocks) + # b_s has shape (N//BLOCK_SIZE_K, K//BLOCK_SIZE_K) stored in row-major + # For N tile pid_n, we need scales[pid_n, :] across K iterations + k_blocks = k # Number of K blocks for clarity + b_s_base = b_s_ptr + pid_n * k_blocks + + # Accumulate matmul result + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for i in range(k_blocks): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0) + a_s = tl.load(a_s_ptrs) + # FIXED: Load single scalar weight scale for (pid_n, i) block pair + b_s = tl.load(b_s_base + i) + dot_prod = tl.dot(a, b, out_dtype=tl.int32) + accumulator += dot_prod.to(tl.float32) * a_s[:, None] * b_s + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + a_s_ptrs += 1 + + # Add bias if provided + if HAS_BIAS: + bias_ptrs = bias_ptr + offs_n[None, :] + bias = tl.load(bias_ptrs, mask=offs_n[None, :] < N, other=0.0) + accumulator += bias + + # Quantize in activation format: per-row, block-wise at out_block_size granularity + # Reshape accumulator to separate blocks: (BLOCK_SIZE_M, BLOCK_SIZE_N) -> (BLOCK_SIZE_M, NUM_BLOCKS, out_block_size) + accumulator_reshaped = tl.reshape(accumulator, (BLOCK_SIZE_M, NUM_BLOCKS, out_block_size)) + + # Compute max per block: reduce over out_block_size dimension + # Shape: (BLOCK_SIZE_M, NUM_BLOCKS) + block_max = tl.max(tl.abs(accumulator_reshaped), axis=2) + block_scale = tl.maximum(block_max / 127.0, 1e-8) + + # Reshape scales for broadcasting: (BLOCK_SIZE_M, NUM_BLOCKS) -> (BLOCK_SIZE_M, NUM_BLOCKS, 1) + block_scale_broadcast = tl.reshape(block_scale, (BLOCK_SIZE_M, NUM_BLOCKS, 1)) + + # Quantize: accumulator -> int8 + quantized = accumulator_reshaped / block_scale_broadcast + quantized = tl.maximum(tl.minimum(quantized, 127.0), -127.0) + quantized_int8 = quantized.to(c_ptr.dtype.element_ty) + + # Reshape back to 2D: (BLOCK_SIZE_M, NUM_BLOCKS, out_block_size) -> (BLOCK_SIZE_M, BLOCK_SIZE_N) + quantized_int8 = tl.reshape(quantized_int8, (BLOCK_SIZE_M, BLOCK_SIZE_N)) + + # Store quantized output + offs_m_actual = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n_actual = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask = (offs_m_actual[:, None] < M) & (offs_n_actual[None, :] < N) + c_ptrs = c_ptr + offs_m_actual[:, None] * N + offs_n_actual[None, :] + tl.store(c_ptrs, quantized_int8, mask=mask) + + # Store scales: (BLOCK_SIZE_M, NUM_BLOCKS) scales for this tile + # Scale layout: (M, N//out_block_size) - matches activation format directly! + # This tile covers M range [pid_m*BLOCK_SIZE_M : (pid_m+1)*BLOCK_SIZE_M] + # N range [pid_n*BLOCK_SIZE_N : (pid_n+1)*BLOCK_SIZE_N] + # N block indices: [pid_n * NUM_BLOCKS : (pid_n+1) * NUM_BLOCKS] + n_scale_stride = N // out_block_size # Total number of N blocks + offs_m_scale = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n_scale = pid_n * NUM_BLOCKS + tl.arange(0, NUM_BLOCKS) + scale_ptrs = c_s_ptr + offs_m_scale[:, None] * n_scale_stride + offs_n_scale[None, :] + scale_mask = (offs_m_scale[:, None] < M) & (offs_n_scale[None, :] < n_scale_stride) + tl.store(scale_ptrs, block_scale, mask=scale_mask) + + +def int8_gemm_quant( + a: torch.Tensor, + a_s: torch.Tensor, + b: torch.Tensor, + b_s: torch.Tensor, + out_block_size: int = 128 +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Fused INT8 GEMM with output quantization. + Computes: C_int8, C_scale = quantize(A @ B) + + This avoids materializing the full-precision intermediate result. + + The kernel produces scales in activation format directly: (*batch_dims, N/out_block_size). + + Args: + a: INT8 activations [..., K] + a_s: Activation scales [..., K//block_size] + b: INT8 weights [N, K] + b_s: Weight scales [N//block_size, K//block_size] + out_block_size: Block size for output quantization (default: 128) + + Returns: + Tuple of (quantized output INT8, output scales in activation format) + """ + assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous" + assert a_s.is_contiguous() and b_s.is_contiguous(), "Scaling tensors must be contiguous" + assert b.dim() == 2, f"Expected b to be 2D, got shape {b.shape}" + + K = a.size(-1) + M = a.numel() // K + N = b.shape[0] + batch_shape = a.size()[:-1] + + assert b.size(1) == K, f"Shape mismatch: b.shape={b.shape}, expected [..., {K}]" + assert N % out_block_size == 0, f"N={N} must be divisible by out_block_size={out_block_size}" + + # Allocate output tensors + c = a.new_empty(*batch_shape, N, dtype=torch.int8) + + # Allocate scales in activation format directly: (M, N//out_block_size) + n_blocks = N // out_block_size + c_s = a.new_empty(M, n_blocks, dtype=torch.float32) + + # Launch kernel + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]), + triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + + int8_gemm_quant_kernel[grid]( + a, b, c, c_s, a_s, b_s, M, N, K, out_block_size, BLOCK_SIZE_M=128, BLOCK_SIZE_N=128, BLOCK_SIZE_K=128 + ) + + # Reshape scales to match batch dimensions: (M, n_blocks) -> (*batch_dims, n_blocks) + if len(batch_shape) > 0: + c_s = c_s.reshape(*batch_shape, n_blocks) + + return c, c_s + + +def int8_addmm_quant( + a: torch.Tensor, + a_s: torch.Tensor, + b: torch.Tensor, + b_s: torch.Tensor, + bias: torch.Tensor = None, + out_block_size: int = 128 +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Fused INT8 addmm with output quantization. + Computes: C_int8, C_scale = quantize(A @ B + bias) + + This fuses matmul, bias addition, and quantization in a single kernel pass. + + The kernel produces scales in activation format directly: (*batch_dims, N/out_block_size). + + Args: + a: INT8 activations [..., K] + a_s: Activation scales [..., K//block_size] + b: INT8 weights [N, K] + b_s: Weight scales [N//block_size, K//block_size] + bias: Optional bias vector [N] + out_block_size: Block size for output quantization (default: 128) + + Returns: + Tuple of (quantized output INT8, output scales in activation format) + """ + assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous" + assert a_s.is_contiguous() and b_s.is_contiguous(), "Scaling tensors must be contiguous" + assert b.dim() == 2, f"Expected b to be 2D, got shape {b.shape}" + + K = a.size(-1) + M = a.numel() // K + N = b.shape[0] + batch_shape = a.size()[:-1] + + assert b.size(1) == K, f"Shape mismatch: b.shape={b.shape}, expected [..., {K}]" + assert N % out_block_size == 0, f"N={N} must be divisible by out_block_size={out_block_size}" + + # Allocate output tensors + c = a.new_empty(*batch_shape, N, dtype=torch.int8) + + # Allocate scales in activation format directly: (M, N//out_block_size) + n_blocks = N // out_block_size + c_s = a.new_empty(M, n_blocks, dtype=torch.float32) + + # Handle bias + has_bias = bias is not None + if has_bias: + assert bias.is_contiguous(), "Bias tensor must be contiguous" + assert bias.dim() == 1 and bias.size(0) == N, \ + f"Bias must be 1D with length {N}, got shape {bias.shape}" + bias_ptr = bias + else: + bias_ptr = c # Dummy pointer + + # Launch kernel + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]), + triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + + int8_gemm_addmm_quant_kernel[grid]( + a, b, c, c_s, bias_ptr, a_s, b_s, M, N, K, out_block_size, HAS_BIAS=has_bias, BLOCK_SIZE_M=128, BLOCK_SIZE_N=128, BLOCK_SIZE_K=128 + ) + + # Reshape scales to match batch dimensions: (M, n_blocks) -> (*batch_dims, n_blocks) + if len(batch_shape) > 0: + c_s = c_s.reshape(*batch_shape, n_blocks) + + return c, c_s + + +# ============================================================================== +# INT8 GELU Kernel +# ============================================================================== + +# Autotuning configs for GELU kernel +# Note: BLOCK_N must be >= quantization block_size (typically 128) and divisible by it +# BLOCK_M can be any size since we don't block in M dimension for activations +int8_gelu_configs = [ + Config( + {"BLOCK_M": block_m, "BLOCK_N": block_n}, + num_stages=num_stages, + num_warps=num_warps, + ) + for block_m in [64, 128, 256] + for block_n in [128, 256] # Must be >= block_size and divisible by it + for num_stages in [2, 3, 4] + for num_warps in [4, 8] +] + + +#@triton.autotune(configs=int8_gelu_configs, key=["M", "N"]) +@triton.heuristics({ + 'BLOCK_SM': lambda args: args["BLOCK_M"], # For activations, no blocking in M dimension + 'BLOCK_SN': lambda args: args["BLOCK_N"] // args["BLOCK_SIZE"], +}) +@triton.jit +def int8_gelu_kernel( + output_ptr, + output_scale_ptr, + input_ptr, + input_scale_ptr, + M, + N: tl.constexpr, + SM, + SN: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SM: tl.constexpr, + BLOCK_SN: tl.constexpr, +): + """ + Fused INT8 GELU with block-wise quantization. + + Computes: output_int8, output_scale = quantize(gelu(dequantize(input_int8, input_scale))) + + For activation quantization, we only block along the last dimension (N). + Each row gets its own set of scales along N. + + Scale tensor layout: + - Input scales: (M, N // BLOCK_SIZE) - one scale per row per block in N + - Within each tile (BLOCK_M x BLOCK_N), we load (BLOCK_M, BLOCK_N // BLOCK_SIZE) scales + + This kernel: + 1. Loads INT8 input and its block-wise scales + 2. Dequantizes to float + 3. Applies GELU activation + 4. Quantizes output back to INT8 with new block-wise scales + + Args: + output_ptr: Pointer to INT8 output tensor + output_scale_ptr: Pointer to output scales + input_ptr: Pointer to INT8 input tensor + input_scale_ptr: Pointer to input scales + M: Number of rows + N: Number of columns + SM: Number of rows in scale tensor (= M for activations) + SN: Number of scale blocks in N dimension (= N // BLOCK_SIZE) + BLOCK_SIZE: Quantization block size (e.g., 128) + BLOCK_M: Tile size in M dimension + BLOCK_N: Tile size in N dimension + BLOCK_SM: Number of rows per tile (= BLOCK_M for activations) + BLOCK_SN: Number of scale blocks per tile in N dimension (= BLOCK_N // BLOCK_SIZE) + """ + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_m = pid // NUM_BLOCK_N + pid_n = pid % NUM_BLOCK_N + + # Offsets for data + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # Load input data + input_ptrs = input_ptr + offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + input_data = tl.load(input_ptrs, mask=mask, other=0).to(tl.int32) + + # Load input scales + # Scale dimensions: (SM, SN) where SM = M, SN = N // BLOCK_SIZE + # For this tile: load (BLOCK_M, BLOCK_N // BLOCK_SIZE) scales + offs_sm = pid_m * BLOCK_SM + tl.arange(0, BLOCK_SM) + offs_sn = pid_n * BLOCK_SN + tl.arange(0, BLOCK_SN) + scale_ptrs = input_scale_ptr + offs_sm[:, None] * SN + offs_sn[None, :] + scale_mask = (offs_sm[:, None] < SM) & (offs_sn[None, :] < SN) + input_scales = tl.load(scale_ptrs, mask=scale_mask, other=1.0) + + # Reshape for broadcasting + # Data: (BLOCK_M, BLOCK_N) -> (BLOCK_M, BLOCK_SN, BLOCK_SIZE) + # Scales: (BLOCK_M, BLOCK_SN) -> (BLOCK_M, BLOCK_SN, 1) + input_data = tl.reshape(input_data, (BLOCK_M, BLOCK_SN, BLOCK_SIZE)) + input_scales = tl.reshape(input_scales, (BLOCK_M, BLOCK_SN, 1)) + + # Dequantize + input_fp32 = input_data.to(tl.float32) * input_scales + + # Apply GELU: 0.5 * x * (1 + erf(x / sqrt(2))) + sqrt_2 = 1.41421356237 + erf_input = input_fp32 / sqrt_2 + erf_val = tl.math.erf(erf_input) + gelu_output = input_fp32 * 0.5 * (1.0 + erf_val) + + # Compute output scales per block + # Shape: (BLOCK_M, BLOCK_SN, BLOCK_SIZE) -> (BLOCK_M, BLOCK_SN) + abs_output = tl.abs(gelu_output) + max_val = tl.max(abs_output, axis=2) # Reduce over BLOCK_SIZE dimension + output_scales = tl.maximum(max_val / 127.0, 1e-8) + + # Reshape scales for broadcasting: (BLOCK_M, BLOCK_SN) -> (BLOCK_M, BLOCK_SN, 1) + output_scales_broadcast = tl.reshape(output_scales, (BLOCK_M, BLOCK_SN, 1)) + + # Quantize output + quantized = gelu_output / output_scales_broadcast + quantized = tl.maximum(tl.minimum(quantized, 127.0), -127.0) + quantized_int8 = quantized.to(tl.int8) + + # Reshape back to 2D + quantized_int8 = tl.reshape(quantized_int8, (BLOCK_M, BLOCK_N)) + + # Store quantized output + output_ptrs = output_ptr + offs_m[:, None] * N + offs_n[None, :] + tl.store(output_ptrs, quantized_int8, mask=mask) + + # Store output scales + output_scale_ptrs = output_scale_ptr + offs_sm[:, None] * SN + offs_sn[None, :] + tl.store(output_scale_ptrs, output_scales, mask=scale_mask) + + +def int8_gelu( + x: torch.Tensor, + s_x: torch.Tensor, + block_size: int = 128 +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Fused INT8 GELU activation with block-wise quantization. + + Computes: y_int8, y_scale = quantize(gelu(dequantize(x, s_x))) + + This avoids materializing the full-precision intermediate result. + + Args: + x: INT8 input tensor of any shape + s_x: Input scales with shape (*batch_dims, last_dim // block_size) + block_size: Quantization block size (default: 128) + + Returns: + Tuple of (quantized output INT8, output scales) + + Note: + The kernel requires tile sizes >= block_size. This is automatically + handled by the autotuner, which uses BLOCK_M, BLOCK_N >= 128. + """ + assert x.is_contiguous(), "Input tensor must be contiguous" + assert s_x.is_contiguous(), "Scale tensor must be contiguous" + assert x.size(-1) % block_size == 0, \ + f"Last dimension must be divisible by block_size={block_size}" + assert block_size == 128, \ + f"Only block_size=128 is currently supported in autotuner configs (got {block_size})" + + # Handle multi-dimensional tensors by reshaping to 2D + original_shape = x.shape + batch_shape = original_shape[:-1] + N = original_shape[-1] + + if x.dim() > 2: + x = x.reshape(-1, N) + s_x = s_x.reshape(-1, s_x.size(-1)) + + M = x.size(0) + SM = M # For activations, we don't block in M dimension + SN = N // block_size + + # Allocate output tensors + y = torch.empty_like(x, dtype=torch.int8) + s_y = torch.empty_like(s_x, dtype=torch.float32) + + # Launch kernel + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), + ) + + int8_gelu_kernel[grid]( + y, s_y, x, s_x, + M, N, SM, SN, + BLOCK_SIZE=block_size, BLOCK_M=128, BLOCK_N=128, BLOCK_SM=128 + ) + + # Reshape back to original batch dimensions + if len(batch_shape) > 0: + y = y.reshape(*batch_shape, N) + s_y = s_y.reshape(*batch_shape, SN) + + return y, s_y diff --git a/comfy/ops.py b/comfy/ops.py index a0ff4e8f1710..9543334e6b85 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -599,11 +599,17 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, self.layout_type = qconfig["comfy_tensor_layout"] weight_scale_key = f"{prefix}weight_scale" + # Check for per-layer group_size override, otherwise use default from QUANT_ALGOS + layer_config = MixedPrecisionOps._layer_quant_config[layer_name] + group_size = layer_config.get("group_size", qconfig.get("group_size", None)) + layout_params = { 'scale': state_dict.pop(weight_scale_key, None), 'orig_dtype': MixedPrecisionOps._compute_dtype, - 'block_size': qconfig.get("group_size", None), + 'block_size': group_size, } + if qconfig.get("asymmetric_layout", False): + layout_params['is_weight'] = True if layout_params['scale'] is not None: manually_loaded_keys.append(weight_scale_key) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index d2f3e7397839..ff07c76da931 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -6,6 +6,23 @@ _LAYOUT_REGISTRY = {} _GENERIC_UTILS = {} +# Try to import Triton-based INT8 kernels +try: + from .int8_kernels import ( + act_quant as triton_act_quant, + act_dequant as triton_act_dequant, + weight_quant as triton_weight_quant, + weight_dequant as triton_weight_dequant, + int8_gemm as triton_int8_gemm, + int8_addmm as triton_int8_addmm, + int8_gemm_quant as triton_int8_gemm_quant, + int8_addmm_quant as triton_int8_addmm_quant + ) + _HAS_TRITON_INT8 = True +except ImportError: + _HAS_TRITON_INT8 = False + logging.warning("Triton INT8 kernels not available, using PyTorch fallback") + def register_layout_op(torch_op, layout_type): """ @@ -212,9 +229,21 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): return handler(func, args, kwargs) # Step 3: Fallback to dequantization - if isinstance(args[0] if args else None, QuantizedTensor): - logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}") - return cls._dequant_and_fallback(func, args, kwargs) + #if isinstance(args[0] if args else None, QuantizedTensor): + #logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}, args={args}") + + to_return = cls._dequant_and_fallback(func, args, kwargs) + + return to_return + + def data_ptr(self): + return self._qdata.data_ptr() + + def is_pinned(self): + return self._qdata.is_pinned() + + def is_contiguous(self): + return self._qdata.is_contiguous() @classmethod def _dequant_and_fallback(cls, func, args, kwargs): @@ -229,14 +258,6 @@ def dequant_arg(arg): new_kwargs = dequant_arg(kwargs) return func(*new_args, **new_kwargs) - def data_ptr(self): - return self._qdata.data_ptr() - - def is_pinned(self): - return self._qdata.is_pinned() - - def is_contiguous(self): - return self._qdata.is_contiguous() # ============================================================================== # Generic Utilities (Layout-Agnostic Operations) @@ -328,6 +349,17 @@ def generic_to_dtype_layout(func, args, kwargs): ) return func(*args, **kwargs) +@register_generic_util(torch.ops.aten.to.dtype) +def generic_to_dtype(func, args, kwargs): + """Handle .to(dtype) calls - dtype conversion only.""" + src = args[0] + if isinstance(src, QuantizedTensor): + # For dtype-only conversion, just change the orig_dtype, no real cast is needed + target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype') + src._layout_params["orig_dtype"] = target_dtype + return src + return func(*args, **kwargs) + @register_generic_util(torch.ops.aten.copy_.default) def generic_copy_(func, args, kwargs): @@ -347,18 +379,6 @@ def generic_copy_(func, args, kwargs): return func(*args, **kwargs) -@register_generic_util(torch.ops.aten.to.dtype) -def generic_to_dtype(func, args, kwargs): - """Handle .to(dtype) calls - dtype conversion only.""" - src = args[0] - if isinstance(src, QuantizedTensor): - # For dtype-only conversion, just change the orig_dtype, no real cast is needed - target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype') - src._layout_params["orig_dtype"] = target_dtype - return src - return func(*args, **kwargs) - - @register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default) def generic_has_compatible_shallow_copy_type(func, args, kwargs): return True @@ -431,16 +451,261 @@ def dequantize(qdata, scale, orig_dtype, **kwargs): def get_plain_tensors(cls, qtensor): return qtensor._qdata, qtensor._layout_params['scale'] + +# ============================================================================== +# Block-Wise INT8 Layout + Operation Handlers +# ============================================================================== +class BlockWiseINT8Layout(QuantizedLayout): + """ + Block-wise INT8 quantization layout. + + Storage format: + - qdata: INT8 tensor (torch.int8) + - scale: Per-block scaling factors (float32) + - block_size: Size of quantization blocks (default 128) + - orig_dtype: Original dtype before quantization (for casting back) + - is_weight: Whether this is a weight tensor (affects blocking dimension) + + Asymmetric blocking: + - Weights: blocks partition along first dimension (M) and second dimension (N) + scale shape: (M//block_size, N//block_size) + - Activations: blocks partition along last dimension (K) + scale shape: (*batch_dims, K//block_size) + """ + + @classmethod + def quantize(cls, tensor, scale=None, block_size=128, is_weight=False, **kwargs): + """ + Quantize a tensor to INT8 with block-wise scaling. + + Args: + tensor: Input tensor to quantize + scale: Optional pre-computed scaling factors + block_size: Size of quantization blocks (default 128) + is_weight: If True, block along both dimensions (for weights) + If False, block along last dimension only (for activations) + + Returns: + Tuple of (quantized_data, layout_params) + """ + orig_dtype = tensor.dtype + + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + + if is_weight: + # Weight quantization: block-wise along both M and N dimensions + # Expected shape: (M, N) + assert tensor.dim() == 2, f"Weight tensor must be 2D, got shape {tensor.shape}" + M, N = tensor.shape + assert M % block_size == 0 and N % block_size == 0, \ + f"Dimensions must be divisible by block_size={block_size}, got shape {tensor.shape}" + + # Use Triton kernel if available AND tensor is on CUDA + if _HAS_TRITON_INT8 and scale is None and tensor.is_cuda: + try: + qdata, scale = triton_weight_quant(tensor, block_size=block_size) + except Exception as e: + # don't fall back, raise, for easier debugging + logging.warning(f"Triton weight_quant failed: {e}, falling back to PyTorch") + raise e + # qdata, scale = cls._weight_quantize_pytorch(tensor, block_size) + else: + qdata, scale = cls._weight_quantize_pytorch(tensor, block_size, scale) + + else: + # Activation quantization: block-wise along last dimension (K) + # Can handle any shape: (*batch_dims, K) + K = tensor.shape[-1] + assert K % block_size == 0, \ + f"Last dimension must be divisible by block_size={block_size}, got {K}" + + # Use Triton kernel if available AND tensor is on CUDA + # ignore input scale for now + # TODO: why do we need input scale? + if _HAS_TRITON_INT8 and tensor.is_cuda: + try: + qdata, scale = triton_act_quant(tensor, block_size=block_size) + except Exception as e: + logging.warning(f"Triton act_quant failed: {e}, falling back to PyTorch") + qdata, scale = cls._activation_quantize_pytorch(tensor, block_size) + else: + qdata, scale = cls._activation_quantize_pytorch(tensor, block_size, scale) + + layout_params = { + 'scale': scale.to(torch.float32), + 'block_size': block_size, + 'is_weight': is_weight, + 'orig_dtype': orig_dtype + } + + return qdata, layout_params + + @staticmethod + def _weight_quantize_pytorch(tensor, block_size, scale=None): + """PyTorch fallback for weight quantization""" + M, N = tensor.shape + # Reshape to (M//block_size, block_size, N//block_size, block_size) + tensor_blocked = tensor.reshape(M // block_size, block_size, N // block_size, block_size) + # Permute to (M//block_size, N//block_size, block_size, block_size) + tensor_blocked = tensor_blocked.permute(0, 2, 1, 3) + + if scale is None: + # Compute per-block absolute maximum + amax = tensor_blocked.abs().amax(dim=(-2, -1)) + scale = amax / 127.0 + scale = torch.maximum(scale, torch.tensor(1e-8, device=scale.device, dtype=scale.dtype)) + + # Broadcast scale for division: (M//block_size, N//block_size, 1, 1) + scale_broadcast = scale.unsqueeze(-1).unsqueeze(-1) + tensor_scaled = tensor_blocked / scale_broadcast + + # Clamp and convert to int8 + tensor_scaled = torch.clamp(tensor_scaled, -127.0, 127.0) + qdata = tensor_scaled.to(torch.int8) + + # Reshape back to original shape + qdata = qdata.permute(0, 2, 1, 3).reshape(M, N) + return qdata, scale + + @staticmethod + def _activation_quantize_pytorch(tensor, block_size, scale=None): + """PyTorch fallback for activation quantization""" + K = tensor.shape[-1] + batch_shape = tensor.shape[:-1] + tensor_blocked = tensor.reshape(*batch_shape, K // block_size, block_size) + + if scale is None: + # Compute per-block absolute maximum + amax = tensor_blocked.abs().amax(dim=-1) + scale = amax / 127.0 + scale = torch.maximum(scale, torch.tensor(1e-8, device=scale.device, dtype=scale.dtype)) + + # Broadcast scale for division + scale_broadcast = scale.unsqueeze(-1) + tensor_scaled = tensor_blocked / scale_broadcast + + # Clamp and convert to int8 + tensor_scaled = torch.clamp(tensor_scaled, -127.0, 127.0) + qdata = tensor_scaled.to(torch.int8) + + # Reshape back to original shape + qdata = qdata.reshape(tensor.shape) + return qdata, scale + + @staticmethod + def dequantize(qdata, scale, block_size, is_weight=False, orig_dtype=None, output_dtype=None, **kwargs): + """ + Dequantize INT8 tensor back to original precision. + + Args: + qdata: Quantized INT8 tensor + scale: Per-block scaling factors + block_size: Size of quantization blocks + is_weight: Whether this is a weight tensor + orig_dtype: Target dtype for dequantization + + Returns: + Dequantized tensor in orig_dtype + """ + if not qdata.is_contiguous(): + qdata = qdata.contiguous() + if not scale.is_contiguous(): + scale = scale.contiguous() + + if is_weight: + # Weight dequantization + if _HAS_TRITON_INT8 and qdata.dim() == 2 and qdata.is_cuda: + try: + dequant = triton_weight_dequant(qdata, scale, block_size=block_size, output_dtype=output_dtype if output_dtype is not None else orig_dtype) + return dequant + except Exception as e: + logging.warning(f"Triton weight_dequant failed: {e}, falling back to PyTorch") + raise e + + # PyTorch fallback + M, N = qdata.shape + # Ensure scale has the correct shape for weight dequantization + expected_scale_shape = (M // block_size, N // block_size) + if scale.shape != expected_scale_shape: + expected_numel = (M // block_size) * (N // block_size) + if scale.numel() == expected_numel: + scale = scale.reshape(expected_scale_shape) + else: + raise RuntimeError( + f"Weight dequant scale shape mismatch: scale.shape={scale.shape}, expected {expected_scale_shape}" + ) + qdata_blocked = qdata.reshape(M // block_size, block_size, N // block_size, block_size) + qdata_blocked = qdata_blocked.permute(0, 2, 1, 3) + scale_broadcast = scale.unsqueeze(-1).unsqueeze(-1) + dequant = qdata_blocked.to(orig_dtype) * scale_broadcast + dequant = dequant.permute(0, 2, 1, 3).reshape(M, N) + else: + # Activation dequantization + if _HAS_TRITON_INT8 and qdata.is_cuda: + try: + dequant = triton_act_dequant(qdata, scale, block_size=block_size, output_dtype=output_dtype if output_dtype is not None else orig_dtype) + return dequant + except Exception as e: + logging.warning(f"Triton act_dequant failed: {e}, falling back to PyTorch") + raise e + + # PyTorch fallback + batch_shape = qdata.shape[:-1] + K = qdata.shape[-1] + # Ensure scale has the correct shape for activation dequantization + expected_scale_shape = (*batch_shape, K // block_size) + if scale.shape != expected_scale_shape: + expected_numel = 1 + for dim in expected_scale_shape: + expected_numel *= dim + if scale.numel() == expected_numel: + scale = scale.reshape(expected_scale_shape) + else: + raise RuntimeError( + f"Activation dequant scale shape mismatch: scale.shape={scale.shape}, expected {expected_scale_shape}" + ) + qdata_blocked = qdata.reshape(*batch_shape, K // block_size, block_size) + scale_broadcast = scale.unsqueeze(-1) + dequant = qdata_blocked.to(orig_dtype) * scale_broadcast + dequant = dequant.reshape(qdata.shape) + + return dequant + + @classmethod + def get_plain_tensors(cls, qtensor): + """ + Extract raw tensors for computation. + + Returns: + Tuple of (qdata, scale, block_size, is_weight) + """ + return ( + qtensor._qdata, + qtensor._layout_params['scale'], + qtensor._layout_params['block_size'], + qtensor._layout_params['is_weight'] + ) + + QUANT_ALGOS = { "float8_e4m3fn": { "storage_t": torch.float8_e4m3fn, "parameters": {"weight_scale", "input_scale"}, "comfy_tensor_layout": "TensorCoreFP8Layout", }, + "int8_blockwise": { + "storage_t": torch.int8, + "parameters": {"weight_scale", "input_scale"}, + "comfy_tensor_layout": "BlockWiseINT8Layout", + "group_size": 128, # Default block size, + "asymmetric_layout": True, + }, } LAYOUTS = { "TensorCoreFP8Layout": TensorCoreFP8Layout, + "BlockWiseINT8Layout": BlockWiseINT8Layout, } @@ -570,3 +835,671 @@ def fp8_func(func, args, kwargs): ar[0] = plain_input return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params) return func(*args, **kwargs) + + +# ============================================================================== +# Block-Wise INT8 Operation Handlers +# ============================================================================== + +def _int8_gemm_pytorch_fallback(a_int8, a_scale, b_int8, b_scale, block_size, bias=None): + """ + PyTorch fallback for INT8 matrix multiplication: dequantize and use standard matmul. + + Args: + a_int8: INT8 activations, shape (*batch, K) + a_scale: Activation scales, shape (*batch, K//block_size) + b_int8: INT8 weights, shape (N, K) - standard PyTorch weight format + b_scale: Weight scales, shape (N//block_size, K//block_size) + block_size: Block size for quantization + bias: Optional bias vector, shape (N,) + + Returns: + Output in float32, shape (*batch, N) + """ + K = a_int8.shape[-1] + batch_shape = a_int8.shape[:-1] + N = b_int8.shape[0] + + # Dequantize activations + # Ensure a_scale has the correct shape - it should be (*batch_shape, K // block_size) + expected_scale_shape = (*batch_shape, K // block_size) + if a_scale.shape != expected_scale_shape: + # Try to reshape if the number of elements matches + expected_numel = 1 + for dim in expected_scale_shape: + expected_numel *= dim + if a_scale.numel() == expected_numel: + a_scale = a_scale.reshape(expected_scale_shape) + else: + raise RuntimeError( + f"Scale shape mismatch: a_scale.shape={a_scale.shape}, expected {expected_scale_shape}. " + + f"a_int8.shape={a_int8.shape}, K={K}, block_size={block_size}" + ) + + a_blocked = a_int8.reshape(*batch_shape, K // block_size, block_size) + a_scale_broadcast = a_scale.unsqueeze(-1) + a_fp32 = a_blocked.to(torch.float32) * a_scale_broadcast + a_fp32 = a_fp32.reshape(*batch_shape, K) + + # Dequantize weights + # b_int8 is in (N, K) format (standard weight format), b_scale is in (N//block_size, K//block_size) format + expected_weight_scale_shape = (N // block_size, K // block_size) + if b_scale.shape != expected_weight_scale_shape: + # Try to reshape if the number of elements matches + expected_weight_numel = (N // block_size) * (K // block_size) + if b_scale.numel() == expected_weight_numel: + b_scale = b_scale.reshape(expected_weight_scale_shape) + else: + raise RuntimeError( + f"Weight scale shape mismatch: b_scale.shape={b_scale.shape}, expected {expected_weight_scale_shape}. " + + f"b_int8.shape={b_int8.shape}, N={N}, K={K}, block_size={block_size}" + ) + + # Dequantize weight: (N, K) -> blocks -> dequantize -> (N, K) + b_blocked = b_int8.reshape(N // block_size, block_size, K // block_size, block_size) + b_blocked = b_blocked.permute(0, 2, 1, 3) # (N//bs, K//bs, bs, bs) + b_scale_broadcast = b_scale.unsqueeze(-1).unsqueeze(-1) + b_fp32 = b_blocked.to(torch.float32) * b_scale_broadcast + b_fp32 = b_fp32.permute(0, 2, 1, 3).reshape(N, K) # Back to (N, K) + + output = torch.nn.functional.linear(a_fp32, b_fp32, bias) + return output + + +def _int8_gemm_triton_or_fallback(a_int8, a_scale, b_int8, b_scale, block_size, bias=None, out_quant=False): + """ + INT8 matrix multiplication with optional fused bias using Triton kernels or PyTorch fallback. + + Args: + a_int8: INT8 activations, shape (*batch, K) + a_scale: Activation scales, shape (*batch, K//block_size) + b_int8: INT8 weights, shape (N, K) - standard PyTorch weight format + b_scale: Weight scales, shape (N//block_size, K//block_size) + block_size: Block size for quantization + bias: Optional bias vector, shape (N,) + out_quant: If True, return quantized output (INT8 + scales) instead of float + + Returns: + If out_quant=False: Output in float16/float32, shape (*batch, N) + If out_quant=True: Tuple of (output_int8, output_scale) + """ + K = a_int8.shape[-1] + batch_shape = a_int8.shape[:-1] + # b_int8 is weight in (N, K) format (standard PyTorch weight format) + N = b_int8.shape[0] + assert b_int8.shape[1] == K, f"Weight shape mismatch: expected b_int8.shape[1]={K}, got {b_int8.shape[1]}" + + # Try Triton kernel first (only if tensors are on CUDA) + if _HAS_TRITON_INT8 and a_int8.is_cuda: + try: + # int8_gemm/int8_addmm expects: (a, a_s, b, b_s, [bias]) + # a: (*batch, K), a_s: (*batch, K//block_size) + # b: (N, K), b_s: (N//block_size, K//block_size) + # Triton kernels transpose b internally + + # Reshape activations to 2D for int8_gemm + a_2d = a_int8.reshape(-1, K).contiguous() + a_scale_2d = a_scale.reshape(-1, a_scale.shape[-1]).contiguous() + + # Ensure weight tensors are contiguous + b_int8_c = b_int8.contiguous() + b_scale_c = b_scale.contiguous() + + # Call appropriate Triton kernel based on out_quant flag + if out_quant: + # Use fused matmul + quantization kernels + if bias is not None: + # Fused addmm + quantization + output_2d, output_scale_2d = triton_int8_addmm_quant( + a_2d, a_scale_2d, b_int8_c, b_scale_c, bias, out_block_size=block_size + ) + else: + # Fused gemm + quantization + output_2d, output_scale_2d = triton_int8_gemm_quant( + a_2d, a_scale_2d, b_int8_c, b_scale_c, out_block_size=block_size + ) + + # Reshape back to original batch shape + output = output_2d.reshape(*batch_shape, N) + output_scale = output_scale_2d.reshape(*batch_shape, N // block_size) + return output, output_scale + else: + # Standard float output + if bias is not None: + # Use fused addmm kernel + output_2d = triton_int8_addmm(a_2d, a_scale_2d, b_int8_c, b_scale_c, bias) + else: + # Use standard gemm kernel + output_2d = triton_int8_gemm(a_2d, a_scale_2d, b_int8_c, b_scale_c) + + # Reshape back to original batch shape + output = output_2d.reshape(*batch_shape, N) + return output + except Exception as e: + logging.warning(f"Triton int8_gemm/addmm failed: {e}, falling back to PyTorch") + raise e + + # Use PyTorch fallback + fallback_output = _int8_gemm_pytorch_fallback(a_int8, a_scale, b_int8, b_scale, block_size, bias) + + # If out_quant is requested, quantize the fallback output + if out_quant: + # Use PyTorch activation quantization on the output + from .int8_kernels import act_quant + try: + output_int8, output_scale = act_quant(fallback_output, block_size=block_size) + return output_int8, output_scale + except: + # Fallback to CPU quantization if Triton not available + output_int8, output_scale = BlockWiseINT8Layout._activation_quantize_pytorch( + fallback_output, block_size + ) + return output_int8, output_scale + + return fallback_output + + +@register_layout_op(torch.ops.aten.linear.default, "BlockWiseINT8Layout") +def int8_linear(func, args, kwargs): + """ + Block-wise INT8 linear operation handler with fused Triton kernel support. + + Supports: + - Both quantized input and weight (uses Triton int8_addmm with fused bias) + - Mixed precision (quantized weight, float input) + - Optional quantized output via out_dtype and out_quant parameters + """ + input_tensor = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + + # Case 1: Both input and weight are quantized + if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): + + # Extract quantized data + a_int8, a_scale, a_block_size, a_is_weight = BlockWiseINT8Layout.get_plain_tensors(input_tensor) + b_int8, b_scale, b_block_size, b_is_weight = BlockWiseINT8Layout.get_plain_tensors(weight) + + # Verify configurations + assert not a_is_weight, "Input tensor should not be marked as weight" + assert b_is_weight, "Weight tensor should be marked as weight" + assert a_block_size == b_block_size, f"Block sizes must match: {a_block_size} vs {b_block_size}" + + orig_dtype = input_tensor._layout_params['orig_dtype'] + out_dtype = kwargs.get('out_dtype', orig_dtype) + out_quant = kwargs.get('out_quant', False) # Whether to return quantized output + + # Weight is already in (N, K) format (standard PyTorch weight format) + # Pass out_quant to _int8_gemm_triton_or_fallback for fused matmul+quant + result = _int8_gemm_triton_or_fallback( + a_int8, a_scale, b_int8, b_scale, a_block_size, + bias=bias, out_quant=out_quant + ) + + # Handle quantized vs float output + if out_quant: + # Result is (output_int8, output_scale) tuple + output_int8, output_scale = result + + # Wrap in QuantizedTensor + layout_params = { + 'scale': output_scale, + 'block_size': a_block_size, + 'is_weight': False, + 'orig_dtype': out_dtype + } + return QuantizedTensor(output_int8, "BlockWiseINT8Layout", layout_params) + else: + # Result is float tensor + output = result + # Convert to target dtype if needed + if output.dtype != out_dtype: + output = output.to(out_dtype) + return output + + # Case 2: Fallback - dequantize and use standard linear + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() + if isinstance(input_tensor, QuantizedTensor): + input_tensor = input_tensor.dequantize() + + return torch.nn.functional.linear(input_tensor, weight, bias) + + +@register_layout_op(torch.ops.aten.mm.default, "BlockWiseINT8Layout") +def int8_mm(func, args, kwargs): + """Block-wise INT8 matrix multiplication handler with Triton kernel support.""" + input_tensor = args[0] + weight = args[1] + + if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): + a_int8, a_scale, a_block_size, a_is_weight = BlockWiseINT8Layout.get_plain_tensors(input_tensor) + b_int8, b_scale, b_block_size, b_is_weight = BlockWiseINT8Layout.get_plain_tensors(weight) + + assert a_block_size == b_block_size, f"Block sizes must match: {a_block_size} vs {b_block_size}" + + # Note: For mm, we expect both to be 2D + # If input is marked as weight (2D blocking), we need different logic + # For simplicity, dequantize if configurations don't match expected pattern + if a_is_weight or not b_is_weight: + logging.warning("INT8 mm: Unexpected tensor configurations, falling back to dequantization") + return func(input_tensor.dequantize(), weight.dequantize()) + + orig_dtype = input_tensor._layout_params['orig_dtype'] + out_dtype = kwargs.get('out_dtype', orig_dtype) + out_quant = kwargs.get('out_quant', False) # Whether to return quantized output (default: True) + + # Check if weight needs to be transposed to (N, K) format + # For mm: input is (M, K), weight should be (N, K) for the kernel + K = a_int8.shape[-1] + if b_int8.shape[0] == K and b_int8.shape[1] != K: + # Weight is in (K, N) format (transposed), transpose back to (N, K) + b_int8 = b_int8.t().contiguous() + b_scale = b_scale.t().contiguous() + + result = _int8_gemm_triton_or_fallback( + a_int8, a_scale, b_int8, b_scale, a_block_size, + bias=None, out_quant=out_quant + ) + + # Handle quantized vs float output + if out_quant: + # Result is (output_int8, output_scale) tuple + output_int8, output_scale = result + + # Wrap in QuantizedTensor + layout_params = { + 'scale': output_scale, + 'block_size': a_block_size, + 'is_weight': False, + 'orig_dtype': out_dtype + } + return QuantizedTensor(output_int8, "BlockWiseINT8Layout", layout_params) + else: + # Result is float tensor + output = result + # Convert to target dtype if needed + if output.dtype != out_dtype: + output = output.to(out_dtype) + return output + + # Fallback + a = list(args) + if isinstance(args[0], QuantizedTensor): + a[0] = args[0].dequantize() + if isinstance(args[1], QuantizedTensor): + a[1] = args[1].dequantize() + return func(*a, **kwargs) + + +@register_layout_op(torch.ops.aten.addmm.default, "BlockWiseINT8Layout") +def int8_addmm(func, args, kwargs): + """ + Block-wise INT8 addmm operation handler with fused Triton kernel support. + addmm: out = beta * input + alpha * (mat1 @ mat2) + + This uses the fused int8_addmm kernel which combines matmul and bias addition + in a single pass for better performance. + + Args: + args[0]: bias tensor + args[1]: mat1 (input) + args[2]: mat2 (weight) + """ + bias = args[0] + input_tensor = args[1] + weight = args[2] + + # Case 1: Both input and weight are quantized + if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): + # Extract quantized data + a_int8, a_scale, a_block_size, a_is_weight = BlockWiseINT8Layout.get_plain_tensors(input_tensor) + b_int8, b_scale, b_block_size, b_is_weight = BlockWiseINT8Layout.get_plain_tensors(weight) + + # Verify configurations + assert a_block_size == b_block_size, f"Block sizes must match: {a_block_size} vs {b_block_size}" + + orig_dtype = input_tensor._layout_params['orig_dtype'] + out_dtype = kwargs.get('out_dtype', orig_dtype) + out_quant = kwargs.get('out_quant', False) # Whether to return quantized output + + # PyTorch's F.linear internally calls addmm(bias, input, weight.t()) + # So weight arrives in (K, N) format (transposed), need to transpose back to (N, K) + # Check if weight is transposed by comparing dimensions with input + K = a_int8.shape[-1] + if b_is_weight and b_int8.shape[0] == K: + # Weight is in (K, N) format (transposed), transpose back to (N, K) + # The transpose handler also transposed the scale, so we need to transpose it back too + b_int8 = b_int8.t().contiguous() + b_scale = b_scale.t().contiguous() + + # Use fused Triton kernel (combines matmul + bias + optional quant) + result = _int8_gemm_triton_or_fallback( + a_int8, a_scale, b_int8, b_scale, a_block_size, + bias=bias, out_quant=out_quant + ) + + # Handle quantized vs float output + if out_quant: + # Result is (output_int8, output_scale) tuple + output_int8, output_scale = result + + # Wrap in QuantizedTensor + layout_params = { + 'scale': output_scale, + 'block_size': a_block_size, + 'is_weight': False, + 'orig_dtype': out_dtype + } + return QuantizedTensor(output_int8, "BlockWiseINT8Layout", layout_params) + else: + # Result is float tensor + output = result + # Convert to target dtype if needed + if output.dtype != out_dtype: + output = output.to(out_dtype) + return output + + # Fallback: dequantize and use standard addmm + a = list(args) + if isinstance(args[0], QuantizedTensor): + a[0] = args[0].dequantize() + if isinstance(args[1], QuantizedTensor): + a[1] = args[1].dequantize() + if isinstance(args[2], QuantizedTensor): + a[2] = args[2].dequantize() + + return func(*a, **kwargs) + + +@register_layout_op(torch.ops.aten.view.default, "BlockWiseINT8Layout") +def int8_view(func, args, kwargs): + """Handle view operations for INT8 tensors.""" + input_tensor = args[0] + if isinstance(input_tensor, QuantizedTensor): + # For view, we need to be careful with block structure + # For safety, we'll allow these ops but note that they might break block alignment + plain_input = input_tensor._qdata + ar = list(args) + ar[0] = plain_input + transformed = func(*ar, **kwargs) + + # Return new QuantizedTensor with same layout params + # Note: This assumes the transformation preserves block structure + return QuantizedTensor(transformed, "BlockWiseINT8Layout", input_tensor._layout_params) + return func(*args, **kwargs) + + +@register_layout_op(torch.ops.aten.t.default, "BlockWiseINT8Layout") +def int8_transpose(func, args, kwargs): + """Handle transpose operations for INT8 tensors.""" + input_tensor = args[0] + if isinstance(input_tensor, QuantizedTensor): + # Transpose the quantized data + plain_input = input_tensor._qdata + ar = list(args) + ar[0] = plain_input + transformed = func(*ar, **kwargs) + + # For weight tensors, we need to transpose the scale tensor as well + new_layout_params = input_tensor._layout_params.copy() + if new_layout_params.get('is_weight', False): + # Transpose the scale tensor to match the transposed weight + new_layout_params['scale'] = new_layout_params['scale'].t().contiguous() + + # Return new QuantizedTensor with updated layout params + return QuantizedTensor(transformed, "BlockWiseINT8Layout", new_layout_params) + return func(*args, **kwargs) + + +@register_layout_op(torch.ops.aten.transpose.int, "BlockWiseINT8Layout") +def int8_transpose_int(func, args, kwargs): + """ + Handle general transpose operations for INT8 tensors. + + torch.transpose(input, dim0, dim1) swaps two dimensions. + + For BlockWiseINT8Layout: + - Activations: quantized along last dimension, scale shape is (*batch_dims, K//block_size) + If we swap the last dimension, we need to adjust scale handling + - Weights: quantized in 2D blocks (M, N), scale shape is (M//block_size, N//block_size) + If we swap dimensions on a 2D weight, transpose the scale tensor too + """ + input_tensor = args[0] + dim0 = args[1] if len(args) > 1 else kwargs.get('dim0', 0) + dim1 = args[2] if len(args) > 2 else kwargs.get('dim1', 1) + + if isinstance(input_tensor, QuantizedTensor): + # Transpose the quantized data + plain_input = input_tensor._qdata + ar = list(args) + ar[0] = plain_input + transformed = func(*ar, **kwargs) + + # Copy layout params + new_layout_params = input_tensor._layout_params.copy() + is_weight = new_layout_params.get('is_weight', False) + + # Normalize dimensions to positive indices + ndim = plain_input.ndim + if dim0 < 0: + dim0 = ndim + dim0 + if dim1 < 0: + dim1 = ndim + dim1 + + # Handle scale tensor transposition + if is_weight: + # For weight tensors (2D with block-wise quantization in both dims) + # If we're transposing the two dimensions of a 2D tensor, transpose scales too + if ndim == 2 and set([dim0, dim1]) == {0, 1}: + # Transposing a 2D weight tensor (M, N) -> (N, M) + # Scale goes from (M//block_size, N//block_size) -> (N//block_size, M//block_size) + new_layout_params['scale'] = new_layout_params['scale'].t().contiguous() + else: + # For higher dimensional weight tensors or partial transposes, + # we may need more complex scale handling + # For now, log a warning as this is an uncommon case + logging.warning( + f"Transpose on weight tensor with dims ({dim0}, {dim1}) and shape {plain_input.shape}. " + f"Scale tensor may need adjustment for correct behavior." + ) + else: + # For activation tensors, block-wise quantization is along last dimension + # If we're swapping the last dimension, this changes the quantization structure + last_dim = ndim - 1 + if dim0 == last_dim or dim1 == last_dim: + # The last dimension is being moved, which affects quantization blocks + # This is a complex case - for safety, we could: + # 1. Dequantize, transpose, requantize (safest but slower) + # 2. Try to adjust scale tensor (complex, error-prone) + # For now, log a warning and proceed with transposing the scale tensor + # The scale tensor dimensions follow the input dimensions except the last + # which is divided by block_size + + # Determine how to transpose the scale tensor + # Scale shape is (*batch_dims, K//block_size) where K is the last dim of input + # When we transpose input dims, we need to transpose scale dims accordingly + # But the last scale dim always corresponds to the quantization blocks + + # Simple heuristic: if transposing involves last dim and input has 3+ dims, + # we transpose the corresponding scale dimensions + scale = new_layout_params['scale'] + if scale.ndim >= 2: + # Map input dimensions to scale dimensions + # Scale has shape (*batch_dims, K//block_size) + # If input has shape (*batch_dims, K), scale maps batch_dims directly + # and last dim is K//block_size + + # For transpose, if we swap dims d0 and d1 in input: + # - If d1 is last_dim (K), then in scale it's still last (K//block_size) + # - If d0 is last_dim, same applies + # - If neither is last_dim, transpose applies to batch dimensions + + if dim1 == last_dim: + # Swapping some batch dim with the last dim + # In scale, this means swapping that batch dim with last scale dim + scale_dim0 = dim0 # Same batch dimension + scale_dim1 = scale.ndim - 1 # Last dim of scale (K//block_size) + new_layout_params['scale'] = scale.transpose(scale_dim0, scale_dim1).contiguous() + elif dim0 == last_dim: + # Swapping last dim with some batch dim + scale_dim0 = scale.ndim - 1 # Last dim of scale + scale_dim1 = dim1 # Same batch dimension + new_layout_params['scale'] = scale.transpose(scale_dim0, scale_dim1).contiguous() + else: + # Swapping two batch dimensions (not involving last dim) + # Transpose the same dimensions in scale + new_layout_params['scale'] = scale.transpose(dim0, dim1).contiguous() + else: + logging.warning( + f"Transpose involves last dimension but scale tensor has shape {scale.shape}. " + f"Scale tensor may need adjustment." + ) + else: + # Transposing batch dimensions that don't affect the quantized dimension + # Transpose the same dimensions in scale tensor + scale = new_layout_params['scale'] + if scale.ndim > max(dim0, dim1): + new_layout_params['scale'] = scale.transpose(dim0, dim1).contiguous() + + # Return new QuantizedTensor with updated layout params + return QuantizedTensor(transformed, "BlockWiseINT8Layout", new_layout_params) + + return func(*args, **kwargs) + + +@register_layout_op(torch.ops.aten.gelu.default, "BlockWiseINT8Layout") +def int8_gelu(func, args, kwargs): + """ + Block-wise INT8 GELU activation handler with fused Triton kernel support. + + Supports quantized input -> GELU -> quantized output in a single fused kernel. + This avoids materializing full-precision intermediate results. + """ + input_tensor = args[0] + + # Case 1: Input is quantized - use fused kernel + if isinstance(input_tensor, QuantizedTensor): + # Extract quantized data + qdata, scale, block_size, is_weight = BlockWiseINT8Layout.get_plain_tensors(input_tensor) + + orig_dtype = input_tensor._layout_params['orig_dtype'] + + # Determine if we should use Triton kernel + if _HAS_TRITON_INT8 and qdata.is_cuda: + try: + # Import the Triton kernel + from .int8_kernels import int8_gelu as triton_int8_gelu + + # Call fused kernel + output_qdata, output_scale = triton_int8_gelu(qdata, scale, block_size=block_size) + + # Wrap result in QuantizedTensor + layout_params = { + 'scale': output_scale.to(torch.float32), + 'block_size': block_size, + 'is_weight': False, # Output is always activation format + 'orig_dtype': orig_dtype + } + return QuantizedTensor(output_qdata, "BlockWiseINT8Layout", layout_params) + + except Exception as e: + logging.warning(f"Triton int8_gelu failed: {e}, falling back to dequantization") + # Fall through to dequantization fallback + + # Fallback: dequantize, apply GELU, quantize + fp_input = input_tensor.dequantize() + fp_output = torch.nn.functional.gelu(fp_input) + + # Quantize output + output_qdata, output_layout_params = BlockWiseINT8Layout.quantize( + fp_output, + block_size=block_size, + is_weight=False + ) + output_layout_params['orig_dtype'] = orig_dtype + + return QuantizedTensor(output_qdata, "BlockWiseINT8Layout", output_layout_params) + + # Case 2: Input is not quantized - use standard GELU + return func(*args, **kwargs) + + +@register_layout_op(torch.ops.aten.add_.Tensor, "BlockWiseINT8Layout") +def int8_add_(func, args, kwargs): + """ + Block-wise INT8 in-place addition handler for LoRA application. + + This operation is typically used when applying LoRA to weight matrices. + Since speed is not critical for this operation: + - If target is a weight: dequantize, add, then requantize as weight + - Otherwise: dequantize and fallback to regular addition + + Args: + args[0]: Target tensor (self) to be modified in-place + args[1]: Tensor to add + """ + target = args[0] + + if isinstance(target, QuantizedTensor): + # Extract quantization parameters + _, _, block_size, is_weight = BlockWiseINT8Layout.get_plain_tensors(target) + + # Only handle the weight case specially + if is_weight: + other = args[1] + orig_dtype = target._layout_params['orig_dtype'] + + # Dequantize target + target_fp = target.dequantize() + + # Dequantize other if it's also quantized + if isinstance(other, QuantizedTensor): + other_fp = other.dequantize() + else: + other_fp = other + + # Perform addition + result_fp = target_fp + other_fp + + # Requantize as weight + result_qdata, result_layout_params = BlockWiseINT8Layout.quantize( + result_fp, + block_size=block_size, + is_weight=True + ) + result_layout_params['orig_dtype'] = orig_dtype + + # Update target in-place by copying the new quantized data + target._qdata.copy_(result_qdata) + target._layout_params['scale'].copy_(result_layout_params['scale']) + return target + + # For non-weight tensors or non-quantized tensors, use standard fallback + return QuantizedTensor._dequant_and_fallback(func, args, kwargs) + + +@register_layout_op(torch.ops.aten.to.dtype, "BlockWiseINT8Layout") +def int8_to_dtype(func, args, kwargs): + """ + Block-wise INT8 dtype conversion handler. + + This operation handles .to(dtype) calls on quantized tensors. + - If converting to torch.int8, do nothing (already in INT8 format) + - Otherwise, dequantize and fallback + + Args: + args[0]: Input tensor + args[1]: Target dtype + """ + input_tensor = args[0] + target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype', None) + + if isinstance(input_tensor, QuantizedTensor): + # If target dtype is int8, the tensor is already in INT8 format + if target_dtype == torch.int8: + # No conversion needed, return as-is + return input_tensor + + # For any other dtype or non-quantized tensors, use standard fallback + return QuantizedTensor._dequant_and_fallback(func, args, kwargs) diff --git a/comfy/weight_adapter/boft.py b/comfy/weight_adapter/boft.py index b2a2f1bd46be..e6389ce6b1c0 100644 --- a/comfy/weight_adapter/boft.py +++ b/comfy/weight_adapter/boft.py @@ -4,6 +4,7 @@ import torch import comfy.model_management from .base import WeightAdapterBase, weight_decompose +from comfy.quant_ops import QuantizedTensor class BOFTAdapter(WeightAdapterBase): @@ -109,7 +110,7 @@ def calculate_weight( if dora_scale is not None: weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) else: - weight += function((strength * lora_diff).type(weight.dtype)) + weight += function((strength * lora_diff).type(weight.dtype if not isinstance(weight, QuantizedTensor) else torch.float32)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight diff --git a/comfy/weight_adapter/glora.py b/comfy/weight_adapter/glora.py index 939abbba5845..60d66a71ec96 100644 --- a/comfy/weight_adapter/glora.py +++ b/comfy/weight_adapter/glora.py @@ -4,6 +4,7 @@ import torch import comfy.model_management from .base import WeightAdapterBase, weight_decompose +from comfy.quant_ops import QuantizedTensor class GLoRAAdapter(WeightAdapterBase): @@ -87,7 +88,7 @@ def calculate_weight( if dora_scale is not None: weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + weight += function(((strength * alpha) * lora_diff).type(weight.dtype if not isinstance(weight, QuantizedTensor) else torch.float32)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight diff --git a/comfy/weight_adapter/loha.py b/comfy/weight_adapter/loha.py index 0abb2d4033fd..7c6a34bbc9bf 100644 --- a/comfy/weight_adapter/loha.py +++ b/comfy/weight_adapter/loha.py @@ -4,7 +4,7 @@ import torch import comfy.model_management from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose - +from comfy.quant_ops import QuantizedTensor class HadaWeight(torch.autograd.Function): @staticmethod @@ -226,7 +226,7 @@ def calculate_weight( if dora_scale is not None: weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + weight += function(((strength * alpha) * lora_diff).type(weight.dtype if not isinstance(weight, QuantizedTensor) else torch.float32)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight diff --git a/comfy/weight_adapter/lokr.py b/comfy/weight_adapter/lokr.py index 9b2aff2d7f86..11f52ac282f2 100644 --- a/comfy/weight_adapter/lokr.py +++ b/comfy/weight_adapter/lokr.py @@ -9,6 +9,7 @@ weight_decompose, factorization, ) +from comfy.quant_ops import QuantizedTensor class LokrDiff(WeightAdapterTrainBase): @@ -214,7 +215,7 @@ def calculate_weight( if dora_scale is not None: weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + weight += function(((strength * alpha) * lora_diff).type(weight.dtype if not isinstance(weight, QuantizedTensor) else torch.float32)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py index 3cc60bb1b754..7ba244d56fe6 100644 --- a/comfy/weight_adapter/lora.py +++ b/comfy/weight_adapter/lora.py @@ -10,6 +10,7 @@ pad_tensor_to_shape, tucker_weight_from_conv, ) +from comfy.quant_ops import QuantizedTensor class LoraDiff(WeightAdapterTrainBase): @@ -206,7 +207,7 @@ def calculate_weight( function, ) else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + weight += function(((strength * alpha) * lora_diff).type(weight.dtype if not isinstance(weight, QuantizedTensor) else torch.float32)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight diff --git a/comfy/weight_adapter/oft.py b/comfy/weight_adapter/oft.py index c0aab96353ca..a67e35a6abe3 100644 --- a/comfy/weight_adapter/oft.py +++ b/comfy/weight_adapter/oft.py @@ -4,6 +4,7 @@ import torch import comfy.model_management from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization +from comfy.quant_ops import QuantizedTensor class OFTDiff(WeightAdapterTrainBase): @@ -155,7 +156,7 @@ def calculate_weight( if dora_scale is not None: weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) else: - weight += function((strength * lora_diff).type(weight.dtype)) + weight += function((strength * lora_diff).type(weight.dtype if not isinstance(weight, QuantizedTensor) else torch.float32)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py index 9cb54ede8026..f2689a7408cc 100644 --- a/tests-unit/comfy_quant/test_quant_registry.py +++ b/tests-unit/comfy_quant/test_quant_registry.py @@ -2,6 +2,8 @@ import torch import sys import os +import time +import gc # Add comfy to path sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) @@ -13,8 +15,16 @@ def has_gpu(): if not has_gpu(): args.cpu = True -from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout +from comfy.quant_ops import ( + QuantizedTensor, + TensorCoreFP8Layout, + BlockWiseINT8Layout, + _int8_gemm_pytorch_fallback, + _int8_gemm_triton_or_fallback +) +# set TRITON_SKIP_AUTOTUNING=1 to skip autotuning +os.environ['TRITON_SKIP_AUTOTUNING'] = '1' class TestQuantizedTensor(unittest.TestCase): """Test the QuantizedTensor subclass with FP8 layout""" @@ -186,5 +196,2516 @@ def test_unsupported_op_dequantizes(self): self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large") +class TestBlockWiseINT8Layout(unittest.TestCase): + """Test the BlockWiseINT8Layout implementation""" + + def test_weight_quantize_dequantize(self): + """Test weight quantization and dequantization""" + # Create a weight tensor (M, N) with dimensions divisible by 128 + weight = torch.randn(256, 512, dtype=torch.float32) + block_size = 128 + + # Quantize as weight + qdata, layout_params = BlockWiseINT8Layout.quantize( + weight, + block_size=block_size, + is_weight=True + ) + + # Check quantized data + self.assertEqual(qdata.dtype, torch.int8) + self.assertEqual(qdata.shape, weight.shape) + + # Check scale shape: (M//block_size, N//block_size) + expected_scale_shape = (256 // block_size, 512 // block_size) + self.assertEqual(layout_params['scale'].shape, expected_scale_shape) + self.assertEqual(layout_params['block_size'], block_size) + self.assertTrue(layout_params['is_weight']) + self.assertEqual(layout_params['orig_dtype'], torch.float32) + + # Dequantize + dequantized = BlockWiseINT8Layout.dequantize(qdata, **layout_params) + + # Check reconstruction quality + self.assertEqual(dequantized.dtype, torch.float32) + self.assertEqual(dequantized.shape, weight.shape) + + # INT8 has limited precision, so we use a relaxed tolerance + max_error = (dequantized - weight).abs().max() + mean_error = (dequantized - weight).abs().mean() + self.assertLess(mean_error, 0.1) # Mean error should be reasonable for INT8 + + def test_activation_quantize_dequantize(self): + """Test activation quantization and dequantization""" + # Create an activation tensor with batch dimensions + activation = torch.randn(4, 16, 512, dtype=torch.float32) + block_size = 128 + + # Quantize as activation + qdata, layout_params = BlockWiseINT8Layout.quantize( + activation, + block_size=block_size, + is_weight=False + ) + + # Check quantized data + self.assertEqual(qdata.dtype, torch.int8) + self.assertEqual(qdata.shape, activation.shape) + + # Check scale shape: (*batch_dims, K//block_size) + expected_scale_shape = (4, 16, 512 // block_size) + self.assertEqual(layout_params['scale'].shape, expected_scale_shape) + self.assertEqual(layout_params['block_size'], block_size) + self.assertFalse(layout_params['is_weight']) + + # Dequantize + dequantized = BlockWiseINT8Layout.dequantize(qdata, **layout_params) + + # Check reconstruction + self.assertEqual(dequantized.shape, activation.shape) + mean_error = (dequantized - activation).abs().mean() + self.assertLess(mean_error, 0.1) + + def test_quantized_tensor_creation(self): + """Test creating QuantizedTensor with BlockWiseINT8Layout""" + weight = torch.randn(256, 512, dtype=torch.float32) + + qt = QuantizedTensor.from_float( + weight, + "BlockWiseINT8Layout", + block_size=128, + is_weight=True + ) + + self.assertIsInstance(qt, QuantizedTensor) + self.assertEqual(qt.dtype, torch.int8) + self.assertEqual(qt.shape, weight.shape) + self.assertEqual(qt._layout_type, "BlockWiseINT8Layout") + + # Test dequantization + dequantized = qt.dequantize() + self.assertEqual(dequantized.dtype, torch.float32) + mean_error = (dequantized - weight).abs().mean() + self.assertLess(mean_error, 0.1) + + +class TestBlockWiseINT8Operations(unittest.TestCase): + """Test operations with BlockWiseINT8 quantized tensors""" + + def test_linear_operation(self): + """Test linear operation with quantized weight and activation""" + torch.manual_seed(42) + + # Create test data + batch_size = 4 + seq_len = 16 + in_features = 256 + out_features = 512 + block_size = 128 + + # Input activation + input_fp32 = torch.randn(batch_size, seq_len, in_features, dtype=torch.float32) + + # Weight (note: linear expects weight as (out_features, in_features)) + weight_fp32 = torch.randn(out_features, in_features, dtype=torch.float32) + bias = torch.randn(out_features, dtype=torch.float32) + + # Quantize both + input_q = QuantizedTensor.from_float( + input_fp32, + "BlockWiseINT8Layout", + block_size=block_size, + is_weight=False + ) + + weight_q = QuantizedTensor.from_float( + weight_fp32, + "BlockWiseINT8Layout", + block_size=block_size, + is_weight=True + ) + + # Compute quantized linear + output_q = torch.nn.functional.linear(input_q, weight_q, bias) + + # Compute reference (full precision) + output_ref = torch.nn.functional.linear(input_fp32, weight_fp32, bias) + + # Compare results + self.assertEqual(output_q.shape, output_ref.shape) + + # INT8 quantization introduces error, but should be reasonable + mean_rel_error = ((output_q - output_ref).abs() / (output_ref.abs() + 1e-6)).mean() + self.assertLess(mean_rel_error, 0.2) # 20% relative error tolerance + + def test_clone_operation(self): + """Test clone operation on INT8 quantized tensor""" + weight = torch.randn(256, 512, dtype=torch.float32) + + qt = QuantizedTensor.from_float( + weight, + "BlockWiseINT8Layout", + block_size=128, + is_weight=True + ) + + qt_cloned = qt.clone() + + self.assertIsInstance(qt_cloned, QuantizedTensor) + self.assertEqual(qt_cloned.shape, qt.shape) + self.assertEqual(qt_cloned._layout_type, "BlockWiseINT8Layout") + self.assertIsNot(qt_cloned._qdata, qt._qdata) + + def test_detach_operation(self): + """Test detach operation on INT8 quantized tensor""" + weight = torch.randn(256, 512, dtype=torch.float32) + + qt = QuantizedTensor.from_float( + weight, + "BlockWiseINT8Layout", + block_size=128, + is_weight=True + ) + + qt_detached = qt.detach() + + self.assertIsInstance(qt_detached, QuantizedTensor) + self.assertEqual(qt_detached.shape, qt.shape) + self.assertEqual(qt_detached._layout_type, "BlockWiseINT8Layout") + + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_device_transfer(self): + """Test moving INT8 quantized tensor to different devices""" + weight = torch.randn(256, 512, dtype=torch.float32) + + qt = QuantizedTensor.from_float( + weight, + "BlockWiseINT8Layout", + block_size=128, + is_weight=True + ) + + # Move to CPU (should be no-op if already on CPU) + qt_cpu = qt.to('cpu') + + self.assertIsInstance(qt_cpu, QuantizedTensor) + self.assertEqual(qt_cpu.device.type, 'cpu') + self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu') + + def test_mixed_precision_fallback(self): + """Test mixed precision: quantized weight with float input""" + torch.manual_seed(42) + + input_fp32 = torch.randn(4, 256, dtype=torch.float32) + weight_fp32 = torch.randn(512, 256, dtype=torch.float32) + + # Only quantize weight + weight_q = QuantizedTensor.from_float( + weight_fp32, + "BlockWiseINT8Layout", + block_size=128, + is_weight=True + ) + + # Linear with float input and quantized weight + output = torch.nn.functional.linear(input_fp32, weight_q) + + # Should work via fallback + output_ref = torch.nn.functional.linear(input_fp32, weight_fp32) + + # With mixed precision fallback (dequantize weight), error should be small + mean_error = (output - output_ref).abs().mean() + self.assertLess(mean_error, 0.3) + + +class TestBlockWiseINT8EdgeCases(unittest.TestCase): + """Test edge cases and error handling for INT8 quantization""" + + def test_dimension_alignment(self): + """Test that dimensions must be divisible by block_size""" + # Try to quantize with misaligned dimensions + weight = torch.randn(200, 300, dtype=torch.float32) # Not divisible by 128 + + with self.assertRaises(AssertionError): + BlockWiseINT8Layout.quantize(weight, block_size=128, is_weight=True) + + def test_weight_must_be_2d(self): + """Test that weight quantization requires 2D tensors""" + weight_3d = torch.randn(4, 256, 512, dtype=torch.float32) + + with self.assertRaises(AssertionError): + BlockWiseINT8Layout.quantize(weight_3d, block_size=128, is_weight=True) + + def test_different_block_sizes(self): + """Test quantization with different block sizes""" + for block_size in [64, 128, 256]: + weight = torch.randn(512, 512, dtype=torch.float32) + + qdata, layout_params = BlockWiseINT8Layout.quantize( + weight, + block_size=block_size, + is_weight=True + ) + + expected_scale_shape = (512 // block_size, 512 // block_size) + self.assertEqual(layout_params['scale'].shape, expected_scale_shape) + + # Verify dequantization works + dequantized = BlockWiseINT8Layout.dequantize(qdata, **layout_params) + self.assertEqual(dequantized.shape, weight.shape) + + +class TestBlockWiseINT8Precision(unittest.TestCase): + """Precision tests for BlockWiseINT8Layout operations""" + + def test_weight_quantization_matches_manual_calculation(self): + """Test that weight quantization matches manual PyTorch calculation""" + device = torch.device('cuda' if has_gpu() else 'cpu') + torch.manual_seed(42) + + M, N = 256, 512 + block_size = 128 + weight = torch.randn(M, N, dtype=torch.float32, device=device) + + # Manual PyTorch calculation for weight quantization + # Weight shape: (M, N), blocks: (M//block_size, N//block_size) + weight_reshaped = weight.reshape(M // block_size, block_size, N // block_size, block_size) + weight_blocks = weight_reshaped.permute(0, 2, 1, 3) # (M//bs, N//bs, bs, bs) + + # Calculate scale per block: amax / 127.0 + amax = weight_blocks.abs().amax(dim=(2, 3), keepdim=False) # (M//bs, N//bs) + scale_manual = amax / 127.0 + scale_manual = torch.maximum(scale_manual, torch.tensor(1e-8, device=device, dtype=weight.dtype)) + + # Quantize: divide by scale and clamp to [-127, 127] + weight_blocks_scaled = weight_blocks / scale_manual.unsqueeze(-1).unsqueeze(-1) + int8_manual = torch.clamp(weight_blocks_scaled, -127.0, 127.0).to(torch.int8) + int8_manual = int8_manual.permute(0, 2, 1, 3).reshape(M, N) + + # Use BlockWiseINT8Layout.quantize + qdata, layout_params = BlockWiseINT8Layout.quantize( + weight, + block_size=block_size, + is_weight=True + ) + + # Compare int8 values + self.assertEqual(qdata.shape, int8_manual.shape) + self.assertEqual(qdata.dtype, torch.int8) + matches = (qdata == int8_manual).float().mean().item() + self.assertGreater(matches, 0.95, f"Only {matches*100:.2f}% of int8 values match") + + # Compare scales + self.assertEqual(layout_params['scale'].shape, scale_manual.shape) + scale_diff = (layout_params['scale'] - scale_manual).abs().mean().item() + scale_rel_diff = (scale_diff / (scale_manual.abs().mean().item() + 1e-8)) + self.assertLess(scale_rel_diff, 0.01, f"Scale relative difference too high: {scale_rel_diff}") + + def test_activation_quantization_matches_manual_calculation(self): + """Test that activation quantization matches manual PyTorch calculation""" + device = torch.device('cuda' if has_gpu() else 'cpu') + torch.manual_seed(42) + + batch_size = 4 + seq_len = 16 + K = 512 + block_size = 128 + activation = torch.randn(batch_size, seq_len, K, dtype=torch.float32, device=device) + + # Manual PyTorch calculation for activation quantization + # Activation shape: (*batch_dims, K), scale shape: (*batch_dims, K//block_size) + orig_shape = activation.shape + batch_dims = orig_shape[:-1] + + # Reshape to expose blocks in last dimension + activation_reshaped = activation.reshape(*batch_dims, K // block_size, block_size) + + # Calculate scale per block: amax / 127.0 + amax = activation_reshaped.abs().amax(dim=-1, keepdim=False) # (*batch_dims, K//block_size) + scale_manual = amax / 127.0 + scale_manual = torch.maximum(scale_manual, torch.tensor(1e-8, device=device, dtype=activation.dtype)) + + # Quantize: divide by scale and clamp to [-127, 127] + activation_scaled = activation_reshaped / scale_manual.unsqueeze(-1) + int8_manual = torch.clamp(activation_scaled, -127.0, 127.0).to(torch.int8) + int8_manual = int8_manual.reshape(orig_shape) + + # Use BlockWiseINT8Layout.quantize + qdata, layout_params = BlockWiseINT8Layout.quantize( + activation, + block_size=block_size, + is_weight=False + ) + + # Compare int8 values + self.assertEqual(qdata.shape, int8_manual.shape) + self.assertEqual(qdata.dtype, torch.int8) + matches = (qdata == int8_manual).float().mean().item() + self.assertGreater(matches, 0.95, f"Only {matches*100:.2f}% of int8 values match") + + # Compare scales + self.assertEqual(layout_params['scale'].shape, scale_manual.shape) + scale_diff = (layout_params['scale'] - scale_manual).abs().mean().item() + scale_rel_diff = (scale_diff / (scale_manual.abs().mean().item() + 1e-8)) + self.assertLess(scale_rel_diff, 0.01, f"Scale relative difference too high: {scale_rel_diff}") + + def test_dequantization_matches_manual_calculation(self): + """Test that dequantization matches manual PyTorch calculation""" + device = torch.device('cuda' if has_gpu() else 'cpu') + torch.manual_seed(42) + + # Test weight dequantization + M, N = 256, 512 + block_size = 128 + weight = torch.randn(M, N, dtype=torch.float32, device=device) + + # Quantize + qdata, layout_params = BlockWiseINT8Layout.quantize( + weight, + block_size=block_size, + is_weight=True + ) + + # Manual dequantization for weight + scale = layout_params['scale'] # (M//bs, N//bs) + int8_data = qdata # (M, N) + orig_dtype = layout_params['orig_dtype'] + + # Reshape to blocks + int8_reshaped = int8_data.reshape(M // block_size, block_size, N // block_size, block_size) + int8_blocks = int8_reshaped.permute(0, 2, 1, 3) # (M//bs, N//bs, bs, bs) + + # Dequantize: int8 * scale (no division by 127) + fp_blocks = int8_blocks.to(orig_dtype) * scale.unsqueeze(-1).unsqueeze(-1) + dequant_manual = fp_blocks.permute(0, 2, 1, 3).reshape(M, N) + + # Use BlockWiseINT8Layout.dequantize + dequant_layout = BlockWiseINT8Layout.dequantize(qdata, **layout_params) + + # Compare + diff = (dequant_layout - dequant_manual).abs().max().item() + self.assertLess(diff, 1e-5, f"Dequantization differs by {diff}") + + # Test activation dequantization + batch_size = 4 + seq_len = 16 + K = 512 + activation = torch.randn(batch_size, seq_len, K, dtype=torch.float32, device=device) + + qdata_act, layout_params_act = BlockWiseINT8Layout.quantize( + activation, + block_size=block_size, + is_weight=False + ) + + # Manual dequantization for activation + scale_act = layout_params_act['scale'] # (batch_size, seq_len, K//bs) + int8_data_act = qdata_act # (batch_size, seq_len, K) + orig_dtype_act = layout_params_act['orig_dtype'] + + # Reshape + int8_reshaped_act = int8_data_act.reshape(batch_size, seq_len, K // block_size, block_size) + # Dequantize: int8 * scale (no division by 127) + fp_blocks_act = int8_reshaped_act.to(orig_dtype_act) * scale_act.unsqueeze(-1) + dequant_manual_act = fp_blocks_act.reshape(batch_size, seq_len, K) + + # Use BlockWiseINT8Layout.dequantize + dequant_layout_act = BlockWiseINT8Layout.dequantize(qdata_act, **layout_params_act) + + # Compare + diff_act = (dequant_layout_act - dequant_manual_act).abs().max().item() + self.assertLess(diff_act, 1e-5, f"Activation dequantization differs by {diff_act}") + + def test_triton_linear_matches_pytorch_fallback(self): + """Test that Triton kernel INT8 GEMM matches PyTorch INT8 GEMM fallback""" + device = torch.device('cuda' if has_gpu() else 'cpu') + torch.manual_seed(42) + + batch_size = 4 + seq_len = 16 + in_features = 512 + out_features = 1024 + block_size = 128 + + # Create original float tensors + input_fp = torch.randn(batch_size, seq_len, in_features, dtype=torch.float32, device=device) + weight_fp = torch.randn(out_features, in_features, dtype=torch.float32, device=device) + bias = torch.randn(out_features, dtype=torch.float32, device=device) + + # Quantize to get int8 data and scales + input_q = QuantizedTensor.from_float( + input_fp, + "BlockWiseINT8Layout", + block_size=block_size, + is_weight=False + ) + + weight_q = QuantizedTensor.from_float( + weight_fp, + "BlockWiseINT8Layout", + block_size=block_size, + is_weight=True + ) + + # Extract int8 data and scales + a_int8, a_scale, _, _ = BlockWiseINT8Layout.get_plain_tensors(input_q) + b_int8, b_scale, _, _ = BlockWiseINT8Layout.get_plain_tensors(weight_q) + + # Call Triton/fallback version (will use Triton on GPU if available) + output_triton = _int8_gemm_triton_or_fallback( + a_int8, a_scale, b_int8, b_scale, block_size, bias=bias, out_quant=False + ) + + # Call PyTorch fallback directly + output_pytorch = _int8_gemm_pytorch_fallback( + a_int8, a_scale, b_int8, b_scale, block_size, bias=bias + ) + + # Convert both to float32 for fair comparison (Triton outputs float16, PyTorch outputs float32) + output_triton_fp32 = output_triton.to(torch.float32) + output_pytorch_fp32 = output_pytorch.to(torch.float32) + + # These should match very closely (same int8 inputs, same computation) + abs_diff = (output_triton_fp32 - output_pytorch_fp32).abs() + mean_abs_diff = abs_diff.mean().item() + max_abs_diff = abs_diff.max().item() + + # Use relative error to account for float16 precision limits + rel_diff = abs_diff / (output_pytorch_fp32.abs() + 1e-6) + mean_rel_diff = rel_diff.mean().item() + + # Since both compute the same INT8 GEMM from same inputs, differences should be tiny + self.assertLess(mean_rel_diff, 1e-3, + f"Triton and PyTorch INT8 GEMM differ too much: mean_rel={mean_rel_diff:.6f}, mean_abs={mean_abs_diff:.6f}, max={max_abs_diff:.6f}") + + def test_triton_linear_from_raw_int8_and_scales(self): + """Test INT8 GEMM from manually created int8 data and scales - compare 3 methods""" + device = torch.device('cuda' if has_gpu() else 'cpu') + if not has_gpu(): + self.skipTest("This test requires GPU (Triton kernels)") + torch.manual_seed(123) + + batch_size = 2 + seq_len = 8 + in_features = 256 + out_features = 512 + block_size = 128 + + # Manually create int8 data and scales for input (activation) + # Input shape: (batch_size, seq_len, in_features) + input_int8 = torch.randint(-127, 127, (batch_size, seq_len, in_features), + dtype=torch.int8, device=device) + input_scale = torch.rand(batch_size, seq_len, in_features // block_size, + dtype=torch.float32, device=device) * 0.1 + + input_layout_params = { + 'scale': input_scale, + 'block_size': block_size, + 'is_weight': False, + 'orig_dtype': torch.float32 + } + input_q = QuantizedTensor(input_int8, "BlockWiseINT8Layout", input_layout_params) + + # Manually create int8 data and scales for weight + # Weight shape: (out_features, in_features) + weight_int8 = torch.randint(-127, 127, (out_features, in_features), + dtype=torch.int8, device=device) + weight_scale = torch.rand(out_features // block_size, in_features // block_size, + dtype=torch.float32, device=device) * 0.1 + + weight_layout_params = { + 'scale': weight_scale, + 'block_size': block_size, + 'is_weight': True, + 'orig_dtype': torch.float32 + } + weight_q = QuantizedTensor(weight_int8, "BlockWiseINT8Layout", weight_layout_params) + + # Bias + bias = torch.randn(out_features, dtype=torch.float32, device=device) + + # Method 1: Call INT8 GEMM via Triton/fallback + output_triton = _int8_gemm_triton_or_fallback( + input_int8, input_scale, weight_int8, weight_scale, block_size, bias=bias, out_quant=False + ) + + # Method 2: Call PyTorch INT8 GEMM fallback directly + output_pytorch = _int8_gemm_pytorch_fallback( + input_int8, input_scale, weight_int8, weight_scale, block_size, bias=bias + ) + + # Method 3: Dequantize and use standard torch.nn.functional.linear + input_dequant = input_q.dequantize() + weight_dequant = weight_q.dequantize() + output_dequant = torch.nn.functional.linear(input_dequant, weight_dequant, bias) + + # Convert all to float32 for fair comparison + output_triton_fp32 = output_triton.to(torch.float32) + output_pytorch_fp32 = output_pytorch.to(torch.float32) + output_dequant_fp32 = output_dequant.to(torch.float32) + + # Compare Method 1 vs Method 2: Triton vs PyTorch INT8 GEMM + self.assertEqual(output_triton.shape, output_pytorch.shape) + abs_diff_12 = (output_triton_fp32 - output_pytorch_fp32).abs() + mean_abs_diff_12 = abs_diff_12.mean().item() + max_abs_diff_12 = abs_diff_12.max().item() + + # Use relative error since Triton outputs float16 which has limited precision for large values + rel_diff_12 = abs_diff_12 / (output_pytorch_fp32.abs() + 1e-6) + mean_rel_diff_12 = rel_diff_12.mean().item() + + # Same int8 data → both INT8 GEMMs should produce nearly identical results + # Use 0.1% relative error tolerance to account for float16 precision limits + self.assertLess(mean_rel_diff_12, 1e-3, + f"Triton and PyTorch INT8 GEMM differ: mean_rel={mean_rel_diff_12:.6f}, mean_abs={mean_abs_diff_12:.6f}, max_abs={max_abs_diff_12:.6f}") + + # Compare Method 1 vs Method 3: Triton INT8 GEMM vs Dequant+Float Linear + self.assertEqual(output_triton.shape, output_dequant.shape) + abs_diff_13 = (output_triton_fp32 - output_dequant_fp32).abs() + mean_abs_diff_13 = abs_diff_13.mean().item() + max_abs_diff_13 = abs_diff_13.max().item() + + # Use relative error for float16 precision limits + rel_diff_13 = abs_diff_13 / (output_dequant_fp32.abs() + 1e-6) + mean_rel_diff_13 = rel_diff_13.mean().item() + + # INT8 GEMM should match dequant+float linear (both compute the same thing) + self.assertLess(mean_rel_diff_13, 1e-3, + f"Triton INT8 GEMM and dequant+float differ: mean_rel={mean_rel_diff_13:.6f}, mean_abs={mean_abs_diff_13:.6f}, max_abs={max_abs_diff_13:.6f}") + + # Compare Method 2 vs Method 3: PyTorch INT8 GEMM vs Dequant+Float Linear + abs_diff_23 = (output_pytorch_fp32 - output_dequant_fp32).abs() + mean_abs_diff_23 = abs_diff_23.mean().item() + max_abs_diff_23 = abs_diff_23.max().item() + + # Use relative error + rel_diff_23 = abs_diff_23 / (output_dequant_fp32.abs() + 1e-6) + mean_rel_diff_23 = rel_diff_23.mean().item() + + # PyTorch INT8 GEMM should also match dequant+float linear + self.assertLess(mean_rel_diff_23, 1e-3, + f"PyTorch INT8 GEMM and dequant+float differ: mean_rel={mean_rel_diff_23:.6f}, mean_abs={mean_abs_diff_23:.6f}, max_abs={max_abs_diff_23:.6f}") + + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_triton_vs_pytorch_linear_implementation(self): + """Compare Triton kernel vs PyTorch fallback implementation directly""" + torch.manual_seed(42) + device = torch.device('cuda') + + batch_size = 8 + seq_len = 32 + in_features = 1024 + out_features = 2048 + block_size = 128 + + # Create test data + input_fp = torch.randn(batch_size, seq_len, in_features, dtype=torch.float32, device=device) + weight_fp = torch.randn(out_features, in_features, dtype=torch.float32, device=device) + bias = torch.randn(out_features, dtype=torch.float32, device=device) + + # Quantize + input_q = QuantizedTensor.from_float(input_fp, "BlockWiseINT8Layout", + block_size=block_size, is_weight=False) + weight_q = QuantizedTensor.from_float(weight_fp, "BlockWiseINT8Layout", + block_size=block_size, is_weight=True) + + # Extract quantized data + a_int8, a_scale, a_block_size, _ = BlockWiseINT8Layout.get_plain_tensors(input_q) + b_int8, b_scale, b_block_size, _ = BlockWiseINT8Layout.get_plain_tensors(weight_q) + + # Call Triton version (via _int8_gemm_triton_or_fallback) + # Note: This may still use Triton for quant fusion even with out_quant=False + output_triton = _int8_gemm_triton_or_fallback( + a_int8, a_scale, b_int8, b_scale, block_size, bias=bias, out_quant=False + ) + + # Call PyTorch fallback directly + output_pytorch = _int8_gemm_pytorch_fallback( + a_int8, a_scale, b_int8, b_scale, block_size, bias=bias + ) + + # Compare Triton vs PyTorch fallback implementations + triton_pytorch_diff = (output_triton - output_pytorch).abs().mean().item() + + # These should match very closely since both compute the same operation + self.assertLess(triton_pytorch_diff, 1e-2, + f"Triton and PyTorch implementations differ: {triton_pytorch_diff}") + + # Also test via high-level API (which may return quantized output) + output_api = torch.nn.functional.linear(input_q, weight_q, bias) + if isinstance(output_api, QuantizedTensor): + output_api_dequant = output_api.dequantize() + else: + output_api_dequant = output_api + + # Compare API with PyTorch fallback (more lenient since API might use different path) + api_pytorch_diff = (output_api_dequant - output_pytorch).abs().mean().item() + self.assertLess(api_pytorch_diff, 0.5, + f"API and PyTorch implementations differ: {api_pytorch_diff}") + + def test_int8_gemm_with_block_size_128(self): + """Test INT8 GEMM with block_size=128 (standard size for Triton kernels)""" + device = torch.device('cuda' if has_gpu() else 'cpu') + torch.manual_seed(42) + + batch_size = 4 + seq_len = 16 + in_features = 512 + out_features = 512 + block_size = 128 + + # Create test data + input_fp = torch.randn(batch_size, seq_len, in_features, + dtype=torch.float32, device=device) + weight_fp = torch.randn(out_features, in_features, + dtype=torch.float32, device=device) + bias = torch.randn(out_features, dtype=torch.float32, device=device) + + # Quantize to get int8 data + input_q = QuantizedTensor.from_float( + input_fp, "BlockWiseINT8Layout", + block_size=block_size, is_weight=False + ) + weight_q = QuantizedTensor.from_float( + weight_fp, "BlockWiseINT8Layout", + block_size=block_size, is_weight=True + ) + + # Extract int8 and scales + a_int8, a_scale, _, _ = BlockWiseINT8Layout.get_plain_tensors(input_q) + b_int8, b_scale, _, _ = BlockWiseINT8Layout.get_plain_tensors(weight_q) + + # Run Triton/fallback INT8 GEMM + output_triton = _int8_gemm_triton_or_fallback( + a_int8, a_scale, b_int8, b_scale, block_size, bias=bias, out_quant=False + ) + + # Run PyTorch INT8 GEMM fallback + output_pytorch = _int8_gemm_pytorch_fallback( + a_int8, a_scale, b_int8, b_scale, block_size, bias=bias + ) + + # Convert both to float32 for fair comparison (Triton outputs float16, PyTorch outputs float32) + output_triton_fp32 = output_triton.to(torch.float32) + output_pytorch_fp32 = output_pytorch.to(torch.float32) + + # Compare using relative error + abs_diff = (output_triton_fp32 - output_pytorch_fp32).abs() + mean_abs_diff = abs_diff.mean().item() + rel_diff = abs_diff / (output_pytorch_fp32.abs() + 1e-6) + mean_rel_diff = rel_diff.mean().item() + + self.assertLess(mean_rel_diff, 1e-3, + f"Triton and PyTorch INT8 GEMM differ: mean_rel={mean_rel_diff:.6f}, mean_abs={mean_abs_diff:.6f}") + + def test_end_to_end_quantization_accuracy(self): + """Test end-to-end: quantize → INT8 GEMM → output accuracy vs float baseline""" + device = torch.device('cuda' if has_gpu() else 'cpu') + torch.manual_seed(42) + + batch_size = 4 + seq_len = 16 + in_features = 512 + out_features = 1024 + block_size = 128 + + # Create float tensors + input_fp = torch.randn(batch_size, seq_len, in_features, dtype=torch.float32, device=device) + weight_fp = torch.randn(out_features, in_features, dtype=torch.float32, device=device) + bias = torch.randn(out_features, dtype=torch.float32, device=device) + + # Float baseline + output_float = torch.nn.functional.linear(input_fp, weight_fp, bias) + + # Quantize → INT8 GEMM path + input_q = QuantizedTensor.from_float(input_fp, "BlockWiseINT8Layout", + block_size=block_size, is_weight=False) + weight_q = QuantizedTensor.from_float(weight_fp, "BlockWiseINT8Layout", + block_size=block_size, is_weight=True) + + # Get int8 data and scales + a_int8, a_scale, _, _ = BlockWiseINT8Layout.get_plain_tensors(input_q) + b_int8, b_scale, _, _ = BlockWiseINT8Layout.get_plain_tensors(weight_q) + + # Run INT8 GEMM + output_int8 = _int8_gemm_triton_or_fallback( + a_int8, a_scale, b_int8, b_scale, block_size, bias=bias, out_quant=False + ) + + # Convert to float32 for fair comparison (Triton outputs float16) + output_int8_fp32 = output_int8.to(torch.float32) + output_float_fp32 = output_float.to(torch.float32) + + # Compare with float baseline + abs_error = (output_int8_fp32 - output_float_fp32).abs() + mean_abs_error = abs_error.mean().item() + rel_error = abs_error / (output_float_fp32.abs() + 1e-6) + mean_rel_error = rel_error.mean().item() + + # This error is from quantization, not from INT8 GEMM implementation + # INT8 quantization can have ~5-20% relative error depending on data distribution + self.assertLess(mean_rel_error, 0.25, + f"Quantization error too high: {mean_rel_error:.4f}") + + def test_basic_weight_quantization(self): + """Test basic weight quantization precision""" + device = torch.device('cuda' if has_gpu() else 'cpu') + weight = torch.randn(256, 512, dtype=torch.float32, device=device) + + qt = QuantizedTensor.from_float( + weight, + "BlockWiseINT8Layout", + block_size=128, + is_weight=True + ) + + self.assertEqual(qt.shape, weight.shape) + self.assertEqual(qt.dtype, torch.int8) + + dequantized = qt.dequantize() + error = (dequantized - weight).abs().mean() + self.assertLess(error, 0.1, "Mean reconstruction error too high") + + def test_large_activation_quantization(self): + """Test activation quantization with larger tensor""" + device = torch.device('cuda' if has_gpu() else 'cpu') + activation = torch.randn(16, 128, 4096, dtype=torch.float32, device=device) + + qt = QuantizedTensor.from_float( + activation, + "BlockWiseINT8Layout", + block_size=128, + is_weight=False + ) + + self.assertEqual(qt.shape, activation.shape) + self.assertEqual(qt.dtype, torch.int8) + + dequantized = qt.dequantize() + error = (dequantized - activation).abs().mean() + self.assertLess(error, 0.1, "Mean reconstruction error too high") + + def test_quantized_linear_precision(self): + """Test quantized linear operation precision""" + torch.manual_seed(42) + device = torch.device('cuda' if has_gpu() else 'cpu') + + batch_size = 16 + seq_len = 128 + in_features = 2048 + out_features = 2048 + block_size = 128 + + input_fp32 = torch.randn(batch_size, seq_len, in_features, dtype=torch.float32, device=device) + weight_fp32 = torch.randn(out_features, in_features, dtype=torch.float32, device=device) + bias = torch.randn(out_features, dtype=torch.float32, device=device) + + # Quantize both + input_q = QuantizedTensor.from_float( + input_fp32, + "BlockWiseINT8Layout", + block_size=block_size, + is_weight=False + ) + + weight_q = QuantizedTensor.from_float( + weight_fp32, + "BlockWiseINT8Layout", + block_size=block_size, + is_weight=True + ) + + # Compute quantized linear (returns QuantizedTensor by default) + output_q = torch.nn.functional.linear(input_q, weight_q, bias) + output_q = QuantizedTensor.from_float(output_q, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + + self.assertIsInstance(output_q, QuantizedTensor, "Default output should be QuantizedTensor") + + # Dequantize for comparison + output_dequant = output_q.dequantize() + + # Compute reference + output_ref = torch.nn.functional.linear(input_fp32, weight_fp32, bias) + + self.assertEqual(output_dequant.shape, output_ref.shape) + + mean_rel_error = ((output_dequant - output_ref).abs() / (output_ref.abs() + 1e-6)).mean() + self.assertLess(mean_rel_error, 0.2, "Mean relative error too high") + + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_triton_vs_pytorch_precision(self): + """Compare Triton kernel vs PyTorch fallback precision""" + # Check if Triton is available + try: + from comfy.int8_kernels import int8_gemm as triton_int8_gemm + has_triton = True + except ImportError: + self.skipTest("Triton kernels not available") + + torch.manual_seed(42) + device = torch.device('cuda') + + batch_size = 4 + seq_len = 16 + in_features = 256 + out_features = 512 + block_size = 128 + + # Create test data + input_fp32 = torch.randn(batch_size, seq_len, in_features, dtype=torch.float32, device=device) + weight_fp32 = torch.randn(out_features, in_features, dtype=torch.float32, device=device) + bias = torch.randn(out_features, dtype=torch.float32, device=device) + + # Quantize + input_q = QuantizedTensor.from_float( + input_fp32, + "BlockWiseINT8Layout", + block_size=block_size, + is_weight=False + ) + + weight_q = QuantizedTensor.from_float( + weight_fp32, + "BlockWiseINT8Layout", + block_size=block_size, + is_weight=True + ) + + # Extract quantized data + a_int8, a_scale, _, _ = BlockWiseINT8Layout.get_plain_tensors(input_q) + b_int8, b_scale, _, _ = BlockWiseINT8Layout.get_plain_tensors(weight_q) + + # Run Triton version (via _int8_gemm_triton_or_fallback) + output_triton = _int8_gemm_triton_or_fallback(a_int8, a_scale, b_int8, b_scale, block_size, bias) + + # Run PyTorch fallback directly + output_pytorch = _int8_gemm_pytorch_fallback(a_int8, a_scale, b_int8, b_scale, block_size, bias) + + # Compute reference + output_ref = torch.nn.functional.linear(input_fp32, weight_fp32, bias) + + # Compare errors + error_triton = ((output_triton - output_ref).abs() / (output_ref.abs() + 1e-6)).mean() + error_pytorch = ((output_pytorch - output_ref).abs() / (output_ref.abs() + 1e-6)).mean() + error_between = (output_triton - output_pytorch).abs().mean() + + self.assertLess(error_triton, 0.2, "Triton error too high") + self.assertLess(error_pytorch, 0.2, "PyTorch error too high") + self.assertLess(error_between, 4e-3, "Triton and PyTorch implementations differ") + + # Test via high-level API (torch dispatch) + output_dispatch = torch.nn.functional.linear(input_q, weight_q, bias) + + # Dequantize if needed + if isinstance(output_dispatch, QuantizedTensor): + output_dispatch_fp32 = output_dispatch.dequantize() + else: + output_dispatch_fp32 = output_dispatch + + # Compare with reference + error_dispatch = ((output_dispatch_fp32 - output_ref).abs() / (output_ref.abs() + 1e-6)).mean() + self.assertLess(error_dispatch, 0.2, "Torch dispatch error too high") + + # Compare dispatch output with low-level Triton output + error_dispatch_vs_triton = (output_dispatch_fp32 - output_triton).abs().mean() + self.assertLess(error_dispatch_vs_triton, 0.2, "Dispatch differs from low-level implementation") + + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_int8_vs_fp8_precision(self): + """Compare INT8 vs FP8 precision""" + # Check if FP8 is available + try: + test_tensor = torch.randn(16, 16, device='cuda', dtype=torch.float32) + _ = test_tensor.to(torch.float8_e4m3fn) + except (RuntimeError, AttributeError): + self.skipTest("FP8 dtypes not supported on this system") + + torch.manual_seed(42) + device = torch.device('cuda') + + batch_size = 16 + seq_len = 128 + in_features = 2048 + out_features = 2048 + block_size = 128 + + # Create test data + input_fp32 = torch.randn(batch_size, seq_len, in_features, dtype=torch.float32, device=device) + weight_fp32 = torch.randn(out_features, in_features, dtype=torch.float32, device=device) + bias = torch.randn(out_features, dtype=torch.float32, device=device) + + # Quantize with INT8 + input_int8 = QuantizedTensor.from_float( + input_fp32, + "BlockWiseINT8Layout", + block_size=block_size, + is_weight=False + ) + + weight_int8 = QuantizedTensor.from_float( + weight_fp32, + "BlockWiseINT8Layout", + block_size=block_size, + is_weight=True + ) + + # Quantize with FP8 + input_fp8 = QuantizedTensor.from_float( + input_fp32, + "TensorCoreFP8Layout", + dtype=torch.float8_e4m3fn + ) + + weight_fp8 = QuantizedTensor.from_float( + weight_fp32, + "TensorCoreFP8Layout", + dtype=torch.float8_e4m3fn + ) + + # Compute outputs + output_int8_q = torch.nn.functional.linear(input_int8, weight_int8, bias) + output_int8 = output_int8_q.dequantize() if isinstance(output_int8_q, QuantizedTensor) else output_int8_q + + # FP8 doesn't support fused bias, so add it manually + output_fp8 = torch.nn.functional.linear(input_fp8, weight_fp8, None) + if bias is not None: + output_fp8 = output_fp8 + bias + if isinstance(output_fp8, QuantizedTensor): + output_fp8 = output_fp8.dequantize() + + output_ref = torch.nn.functional.linear(input_fp32, weight_fp32, bias) + + # Compare precision + error_int8 = ((output_int8 - output_ref).abs() / (output_ref.abs() + 1e-6)).mean() + error_fp8 = ((output_fp8 - output_ref).abs() / (output_ref.abs() + 1e-6)).mean() + error_between = (output_int8 - output_fp8).abs().mean() + + self.assertLess(error_int8, 0.2, "INT8 error too high") + self.assertLess(error_fp8, 0.4, "FP8 error too high") + + # Memory usage comparison + int8_memory = input_int8._qdata.element_size() * input_int8._qdata.numel() + \ + weight_int8._qdata.element_size() * weight_int8._qdata.numel() + fp8_memory = input_fp8._qdata.element_size() * input_fp8._qdata.numel() + \ + weight_fp8._qdata.element_size() * weight_fp8._qdata.numel() + fp32_memory = input_fp32.element_size() * input_fp32.numel() + \ + weight_fp32.element_size() * weight_fp32.numel() + + self.assertLess(int8_memory, fp32_memory, "INT8 should use less memory than FP32") + self.assertLess(fp8_memory, fp32_memory, "FP8 should use less memory than FP32") + + def test_output_types(self): + """Test output types for all registered operations""" + device = torch.device('cuda' if has_gpu() else 'cpu') + torch.manual_seed(42) + + batch_size = 4 + seq_len = 16 + in_features = 256 + out_features = 512 + block_size = 128 + + # Create test data + input_fp32 = torch.randn(batch_size, seq_len, in_features, dtype=torch.float32, device=device) + weight_fp32 = torch.randn(out_features, in_features, dtype=torch.float32, device=device) + bias = torch.randn(out_features, dtype=torch.float32, device=device) + + # Quantize with INT8 + input_int8 = QuantizedTensor.from_float( + input_fp32, + "BlockWiseINT8Layout", + block_size=block_size, + is_weight=False + ) + weight_int8 = QuantizedTensor.from_float( + weight_fp32, + "BlockWiseINT8Layout", + block_size=block_size, + is_weight=True + ) + + # Test 1: linear with quantized output (default) + output = torch.nn.functional.linear(input_int8, weight_int8, bias) + output = QuantizedTensor.from_float(output, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + self.assertIsInstance(output, QuantizedTensor, "Default output should be QuantizedTensor") + self.assertEqual(output.layout_type, "BlockWiseINT8Layout") + + # Test 2: linear with explicit dequantization + output_q = torch.nn.functional.linear(input_int8, weight_int8, bias) + output_reg = output_q.dequantize() + self.assertNotIsInstance(output_reg, QuantizedTensor, "Dequantized output should be regular tensor") + + # Test 3: mm operation (2D input) - default quantized output + input_2d = input_fp32.reshape(-1, in_features) + input_int8_2d = QuantizedTensor.from_float(input_2d, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + weight_int8_t = weight_int8.t() + + output_mm = torch.mm(input_int8_2d, weight_int8_t) + output_mm = QuantizedTensor.from_float(output_mm, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + self.assertIsInstance(output_mm, QuantizedTensor, "Default mm output should be QuantizedTensor") + self.assertEqual(output_mm.layout_type, "BlockWiseINT8Layout") + + # Test 4: addmm operation - default quantized output + output_addmm = torch.addmm(bias, input_int8_2d, weight_int8_t) + output_addmm = QuantizedTensor.from_float(output_addmm, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + self.assertIsInstance(output_addmm, QuantizedTensor, "Default addmm output should be QuantizedTensor") + self.assertEqual(output_addmm.layout_type, "BlockWiseINT8Layout") + + # Test 5: view operation preserves quantization + view_result = input_int8.view(batch_size * seq_len, in_features) + self.assertIsInstance(view_result, QuantizedTensor, "view should preserve QuantizedTensor") + self.assertEqual(view_result.layout_type, "BlockWiseINT8Layout") + + # Test 6: transpose operation preserves quantization + transpose_result = weight_int8.t() + self.assertIsInstance(transpose_result, QuantizedTensor, "transpose should preserve QuantizedTensor") + self.assertEqual(transpose_result.layout_type, "BlockWiseINT8Layout") + + # Test 7: clone operation preserves quantization + clone_result = input_int8.clone() + self.assertIsInstance(clone_result, QuantizedTensor, "clone should preserve QuantizedTensor") + self.assertEqual(clone_result.layout_type, "BlockWiseINT8Layout") + + # Test 8: detach operation preserves quantization + detach_result = input_int8.detach() + self.assertIsInstance(detach_result, QuantizedTensor, "detach should preserve QuantizedTensor") + self.assertEqual(detach_result.layout_type, "BlockWiseINT8Layout") + + +class TestBlockWiseINT8GELU(unittest.TestCase): + """Test INT8 block-wise GELU activation""" + + def test_int8_gelu_basic(self): + """Test basic GELU operation with INT8 quantized tensors""" + device = torch.device('cuda' if has_gpu() else 'cpu') + + batch_size = 2 + seq_len = 512 + hidden_dim = 2048 + block_size = 128 + + # Create random input tensor + x = torch.randn(batch_size, seq_len, hidden_dim, dtype=torch.float16, device=device) + + # Compute reference output (full precision) + with torch.no_grad(): + reference_output = torch.nn.functional.gelu(x) + + # Quantize input + x_quant = QuantizedTensor.from_float(x, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + + # Apply GELU (should use fused kernel) + with torch.no_grad(): + output_quant = torch.nn.functional.gelu(x_quant) + + if isinstance(output_quant, QuantizedTensor): + output_fp = output_quant.dequantize() + else: + output_fp = output_quant + + self.assertEqual(output_fp.shape, reference_output.shape) + + # Compute error metrics + relative_error = (torch.norm(output_fp - reference_output) / torch.norm(reference_output)).item() + self.assertLess(relative_error, 0.1, f"Relative error too high: {relative_error}") + + def test_int8_gelu_2d(self): + """Test GELU with 2D tensors""" + device = torch.device('cuda' if has_gpu() else 'cpu') + + M, N = 256, 2048 + block_size = 128 + + x = torch.randn(M, N, dtype=torch.float16, device=device) + reference_output = torch.nn.functional.gelu(x) + + # Quantize and apply GELU + x_quant = QuantizedTensor.from_float(x, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + + with torch.no_grad(): + output_quant = torch.nn.functional.gelu(x_quant) + + if isinstance(output_quant, QuantizedTensor): + output_fp = output_quant.dequantize() + else: + output_fp = output_quant + + relative_error = (torch.norm(output_fp - reference_output) / torch.norm(reference_output)).item() + self.assertLess(relative_error, 0.1, f"Relative error too high: {relative_error}") + + def test_int8_gelu_different_shapes(self): + """Test GELU with various tensor shapes""" + device = torch.device('cuda' if has_gpu() else 'cpu') + block_size = 128 + + test_shapes = [ + (128, 1024), # 2D + (4, 512, 2048), # 3D + (2, 8, 128, 1024), # 4D + ] + + for shape in test_shapes: + with self.subTest(shape=shape): + x = torch.randn(*shape, dtype=torch.float16, device=device) + reference_output = torch.nn.functional.gelu(x) + + # Quantize and apply GELU + x_quant = QuantizedTensor.from_float(x, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + + with torch.no_grad(): + output_quant = torch.nn.functional.gelu(x_quant) + + if isinstance(output_quant, QuantizedTensor): + output_fp = output_quant.dequantize() + else: + output_fp = output_quant + + relative_error = (torch.norm(output_fp - reference_output) / torch.norm(reference_output)).item() + self.assertLess(relative_error, 0.1, f"Relative error too high for shape {shape}: {relative_error}") + + +class TestBlockWiseINT8QuantFusion(unittest.TestCase): + """Test fused INT8 matmul + quantization kernels""" + + @unittest.skip("out_quant parameter not yet implemented in torch ops") + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_int8_linear_with_out_quant(self): + """Test INT8 linear operation with fused output quantization""" + batch_size = 4 + seq_len = 256 + input_dim = 1024 + output_dim = 2048 + block_size = 128 + + # Create input tensor + input_fp = torch.randn(batch_size, seq_len, input_dim, dtype=torch.float16, device='cuda') + weight_fp = torch.randn(output_dim, input_dim, dtype=torch.float16, device='cuda') + bias = torch.randn(output_dim, dtype=torch.float16, device='cuda') + + # Quantize input and weight + input_q = QuantizedTensor.from_float( + input_fp, + "BlockWiseINT8Layout", + block_size=block_size, + is_weight=False + ) + + weight_q = QuantizedTensor.from_float( + weight_fp, + "BlockWiseINT8Layout", + block_size=block_size, + is_weight=True + ) + + # Test 1: Regular linear (float output) + output_float = torch.ops.aten.linear.default(input_q, weight_q, bias) + self.assertIsNotNone(output_float) + self.assertEqual(output_float.shape, (batch_size, seq_len, output_dim)) + + # Test 2: Linear with fused output quantization (out_quant=True) + output_quant = torch.ops.aten.linear.default( + input_q, weight_q, bias + ) + + self.assertIsInstance(output_quant, QuantizedTensor, "Output should be QuantizedTensor when out_quant=True") + self.assertEqual(output_quant._layout_type, "BlockWiseINT8Layout") + + # Verify scale shape matches activation format + expected_scale_shape = (batch_size, seq_len, output_dim // block_size) + actual_scale_shape = output_quant._layout_params['scale'].shape + self.assertEqual(actual_scale_shape, expected_scale_shape, "Scale shape should match activation format") + + # Dequantize and compare + output_dequant = output_quant.dequantize() + self.assertEqual(output_dequant.shape, (batch_size, seq_len, output_dim)) + + # Compare with float output + diff = (output_float - output_dequant).abs() + max_diff = diff.max().item() + mean_diff = diff.mean().item() + relative_error = (diff / (output_float.abs() + 1e-6)).mean().item() + + self.assertLess(relative_error, 0.15, f"Relative error too high: {relative_error}") + + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_int8_addmm_with_out_quant(self): + """Test INT8 addmm operation with fused output quantization""" + M, K, N = 512, 1024, 2048 + block_size = 128 + + # Create tensors + input_fp = torch.randn(M, K, dtype=torch.float16, device='cuda') + weight_fp = torch.randn(N, K, dtype=torch.float16, device='cuda') + bias = torch.randn(N, dtype=torch.float16, device='cuda') + + # Quantize + input_q = QuantizedTensor.from_float( + input_fp, "BlockWiseINT8Layout", + block_size=block_size, is_weight=False + ) + weight_q = QuantizedTensor.from_float( + weight_fp, "BlockWiseINT8Layout", + block_size=block_size, is_weight=True + ) + + # Test with out_quant=True + output_quant = torch.ops.aten.addmm.default( + bias, input_q, weight_q.t() + ) + + output_quant = QuantizedTensor.from_float(output_quant, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + self.assertIsInstance(output_quant, QuantizedTensor, "Output should be QuantizedTensor when out_quant=True") + self.assertEqual(output_quant.shape, (M, N)) + self.assertEqual(output_quant._layout_type, "BlockWiseINT8Layout") + + # Verify it can be dequantized + output_dequant = output_quant.dequantize() + self.assertEqual(output_dequant.shape, (M, N)) + self.assertEqual(output_dequant.dtype, torch.float16) + + +# Benchmark tests (skipped by default) +class TestBlockWiseINT8Benchmarks(unittest.TestCase): + """Performance benchmark tests for BlockWiseINT8Layout""" + + @unittest.skip("perf benchmark only") + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_runtime_comparison(self): + """Benchmark INT8 quantized ops via torch dispatch (high-level API)""" + device = torch.device('cuda') + torch.manual_seed(42) + + # More comprehensive test configurations + test_configs = [ + {"name": "Tiny", "batch": 2, "seq": 8, "in_feat": 128, "out_feat": 256, "block": 64}, + {"name": "Small", "batch": 4, "seq": 16, "in_feat": 256, "out_feat": 512, "block": 128}, + {"name": "Medium", "batch": 8, "seq": 32, "in_feat": 512, "out_feat": 1024, "block": 128}, + {"name": "Large", "batch": 16, "seq": 64, "in_feat": 1024, "out_feat": 2048, "block": 128}, + {"name": "XL", "batch": 32, "seq": 128, "in_feat": 2048, "out_feat": 4096, "block": 128}, + {"name": "XXL", "batch": 64, "seq": 256, "in_feat": 4096, "out_feat": 4096, "block": 128}, + ] + + n_warmup = 10 + n_iters = 200 # More iterations for better averaging + + print(f"\nWarmup iterations: {n_warmup}") + print(f"Benchmark iterations: {n_iters}\n") + + # Check if Triton is available + try: + from comfy.int8_kernels import int8_gemm as triton_int8_gemm + print("✓ Using Triton INT8 kernels (optimized path)\n") + except ImportError: + print("⚠ Using PyTorch fallback (Triton not available)\n") + + results = [] + + for config in test_configs: + name = config["name"] + batch_size = config["batch"] + seq_len = config["seq"] + in_features = config["in_feat"] + out_features = config["out_feat"] + block_size = config["block"] + + print(f"{name}: batch={batch_size}, seq={seq_len}, in={in_features}, out={out_features}, block={block_size}") + + # Calculate FLOPS for this configuration + m = batch_size * seq_len + k = in_features + n = out_features + flops = 2 * m * n * k # 2 for multiply-add + + try: + # Create test data + input_fp32 = torch.randn(batch_size, seq_len, in_features, dtype=torch.float32, device=device) + weight_fp32 = torch.randn(out_features, in_features, dtype=torch.float32, device=device) + bias = torch.randn(out_features, dtype=torch.float32, device=device) + + # Quantize using high-level API + input_int8 = QuantizedTensor.from_float(input_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + weight_int8 = QuantizedTensor.from_float(weight_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=True) + + # Warm up - test full dispatch path + for _ in range(n_warmup): + _ = torch.nn.functional.linear(input_int8, weight_int8, bias) + _ = torch.nn.functional.linear(input_fp32, weight_fp32, bias) + torch.cuda.synchronize() + torch.cuda.empty_cache() + + # Benchmark INT8 via torch dispatch (includes dispatch overhead + quantized output) + int8_times = [] + for _ in range(n_iters): + start = time.time() + output = torch.nn.functional.linear(input_int8, weight_int8, bias) + torch.cuda.synchronize() + int8_times.append((time.time() - start) * 1000) + + # Also benchmark with dequantization to FP32 output (more realistic for some use cases) + int8_dequant_times = [] + for _ in range(n_iters): + start = time.time() + output = torch.nn.functional.linear(input_int8, weight_int8, bias) + if isinstance(output, QuantizedTensor): + output = output.dequantize() + torch.cuda.synchronize() + int8_dequant_times.append((time.time() - start) * 1000) + + # Benchmark FP32 reference + fp32_times = [] + for _ in range(n_iters): + start = time.time() + _ = torch.nn.functional.linear(input_fp32, weight_fp32, bias) + torch.cuda.synchronize() + fp32_times.append((time.time() - start) * 1000) + + # Convert to torch tensors for statistics + int8_times = torch.tensor(int8_times) + int8_dequant_times = torch.tensor(int8_dequant_times) + fp32_times = torch.tensor(fp32_times) + + # Calculate statistics + int8_mean = int8_times.mean().item() + int8_std = int8_times.std().item() + int8_min = int8_times.min().item() + + int8_dequant_mean = int8_dequant_times.mean().item() + int8_dequant_std = int8_dequant_times.std().item() + int8_dequant_min = int8_dequant_times.min().item() + + fp32_mean = fp32_times.mean().item() + fp32_std = fp32_times.std().item() + fp32_min = fp32_times.min().item() + + speedup_int8 = fp32_mean / int8_mean + speedup_int8_dequant = fp32_mean / int8_dequant_mean + + print(f" INT8 (quantized out): {int8_mean:.3f}±{int8_std:.3f} ms (min: {int8_min:.3f} ms) [{flops/int8_mean/1e9:.2f} GFLOPS]") + print(f" INT8 (dequant out): {int8_dequant_mean:.3f}±{int8_dequant_std:.3f} ms (min: {int8_dequant_min:.3f} ms) [{flops/int8_dequant_mean/1e9:.2f} GFLOPS]") + print(f" FP32 reference: {fp32_mean:.3f}±{fp32_std:.3f} ms (min: {fp32_min:.3f} ms) [{flops/fp32_mean/1e9:.2f} GFLOPS]") + print(f" Speedup (INT8 quantized/FP32): {speedup_int8:.2f}x") + print(f" Speedup (INT8 dequant/FP32): {speedup_int8_dequant:.2f}x") + print(f" Dequant overhead: {((int8_dequant_mean - int8_mean) / int8_mean * 100):.1f}%\n") + + results.append({ + "name": name, + "int8_mean": int8_mean, + "int8_dequant_mean": int8_dequant_mean, + "fp32_mean": fp32_mean, + "speedup_int8": speedup_int8, + "speedup_int8_dequant": speedup_int8_dequant, + "flops": flops, + }) + + # Clean up memory after each configuration + del input_fp32, weight_fp32, bias, input_int8, weight_int8 + if 'int8_times' in locals(): + del int8_times, int8_dequant_times, fp32_times + gc.collect() + torch.cuda.empty_cache() + + except RuntimeError as e: + if "out of memory" in str(e): + print(f" ⚠ OOM - skipping this configuration\n") + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + else: + raise + + # Print summary + print("\n" + "=" * 60) + print("Summary:") + print("=" * 60) + for result in results: + print(f"{result['name']:8s}: INT8 {result['int8_mean']:.3f}ms, " + f"INT8+dequant {result['int8_dequant_mean']:.3f}ms, " + f"FP32 {result['fp32_mean']:.3f}ms, " + f"Speedup: {result['speedup_int8']:.2f}x (quantized), {result['speedup_int8_dequant']:.2f}x (dequant)") + + # Assertions for unittest + self.assertGreater(len(results), 0, "Should have collected benchmark results") + + @unittest.skip("perf benchmark only") + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_int8_vs_fp8_runtime(self): + """Benchmark INT8 vs FP8 runtime with comprehensive configs""" + # Check if FP8 is available + try: + test_tensor = torch.randn(16, 16, device='cuda', dtype=torch.float32) + _ = test_tensor.to(torch.float8_e4m3fn) + has_fp8 = True + except (RuntimeError, AttributeError): + has_fp8 = False + + if not has_fp8: + print("⚠ FP8 dtypes not supported on this system, skipping comparison") + self.skipTest("FP8 not supported") + return + + device = torch.device('cuda') + torch.manual_seed(42) + + # More comprehensive test configurations + test_configs = [ + {"name": "Tiny", "batch": 2, "seq": 8, "in_feat": 128, "out_feat": 256, "block": 64}, + {"name": "Small", "batch": 4, "seq": 16, "in_feat": 256, "out_feat": 512, "block": 128}, + {"name": "Medium", "batch": 8, "seq": 32, "in_feat": 512, "out_feat": 1024, "block": 128}, + {"name": "Large", "batch": 16, "seq": 64, "in_feat": 1024, "out_feat": 2048, "block": 128}, + {"name": "XL", "batch": 32, "seq": 128, "in_feat": 2048, "out_feat": 4096, "block": 128}, + {"name": "XXL", "batch": 64, "seq": 256, "in_feat": 4096, "out_feat": 4096, "block": 128}, + {"name": "XXXL", "batch": 128, "seq": 512, "in_feat": 4096, "out_feat": 4096, "block": 128}, + ] + + n_warmup = 10 + n_iters = 200 # More iterations for better averaging + + print(f"\nWarmup iterations: {n_warmup}") + print(f"Benchmark iterations: {n_iters}") + print("Note: INT8 uses fused bias, FP8 adds bias separately\n") + + results = [] + + for config in test_configs: + name = config["name"] + batch_size = config["batch"] + seq_len = config["seq"] + in_features = config["in_feat"] + out_features = config["out_feat"] + block_size = config["block"] + + print(f"{name}: batch={batch_size}, seq={seq_len}, in={in_features}, out={out_features}, block={block_size}") + + # Calculate FLOPS for this configuration + m = batch_size * seq_len + k = in_features + n = out_features + flops = 2 * m * n * k # 2 for multiply-add + + try: + # Create test data + input_fp32 = torch.randn(batch_size, seq_len, in_features, dtype=torch.float32, device=device) + weight_fp32 = torch.randn(out_features, in_features, dtype=torch.float32, device=device) + bias = torch.randn(out_features, dtype=torch.float32, device=device) + + # Quantize with INT8 + input_int8 = QuantizedTensor.from_float(input_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + weight_int8 = QuantizedTensor.from_float(weight_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=True) + + # Quantize with FP8 + input_fp8 = QuantizedTensor.from_float(input_fp32, "TensorCoreFP8Layout", dtype=torch.float8_e4m3fn) + weight_fp8 = QuantizedTensor.from_float(weight_fp32, "TensorCoreFP8Layout", dtype=torch.float8_e4m3fn) + + # Warm up + for _ in range(n_warmup): + _ = torch.nn.functional.linear(input_int8, weight_int8, bias) + out_fp8 = torch.nn.functional.linear(input_fp8, weight_fp8, None) + if bias is not None: + _ = out_fp8 + bias + _ = torch.nn.functional.linear(input_fp32, weight_fp32, bias) + torch.cuda.synchronize() + + # Benchmark INT8 (with fused bias) - collect all times + int8_times = [] + for _ in range(n_iters): + start = time.time() + _ = torch.nn.functional.linear(input_int8, weight_int8, bias) + torch.cuda.synchronize() + int8_times.append((time.time() - start) * 1000) + + # Benchmark FP8 (bias added separately) + fp8_times = [] + for _ in range(n_iters): + start = time.time() + out_fp8 = torch.nn.functional.linear(input_fp8, weight_fp8, None) + if bias is not None: + _ = out_fp8 + bias + torch.cuda.synchronize() + fp8_times.append((time.time() - start) * 1000) + + # Benchmark FP32 reference + fp32_times = [] + for _ in range(n_iters): + start = time.time() + _ = torch.nn.functional.linear(input_fp32, weight_fp32, bias) + torch.cuda.synchronize() + fp32_times.append((time.time() - start) * 1000) + + # Convert to torch tensors for statistics + int8_times = torch.tensor(int8_times) + fp8_times = torch.tensor(fp8_times) + fp32_times = torch.tensor(fp32_times) + + # Calculate statistics + int8_mean = int8_times.mean().item() + int8_std = int8_times.std().item() + int8_min = int8_times.min().item() + + fp8_mean = fp8_times.mean().item() + fp8_std = fp8_times.std().item() + fp8_min = fp8_times.min().item() + + fp32_mean = fp32_times.mean().item() + fp32_std = fp32_times.std().item() + fp32_min = fp32_times.min().item() + + speedup_int8 = fp32_mean / int8_mean + speedup_fp8 = fp32_mean / fp8_mean + int8_vs_fp8 = fp8_mean / int8_mean + + print(f" INT8 (fused bias): {int8_mean:.3f}±{int8_std:.3f} ms (min: {int8_min:.3f} ms) [{flops/int8_mean/1e9:.2f} GFLOPS]") + print(f" FP8 (sep. bias): {fp8_mean:.3f}±{fp8_std:.3f} ms (min: {fp8_min:.3f} ms) [{flops/fp8_mean/1e9:.2f} GFLOPS]") + print(f" FP32 (fused bias): {fp32_mean:.3f}±{fp32_std:.3f} ms (min: {fp32_min:.3f} ms) [{flops/fp32_mean/1e9:.2f} GFLOPS]") + print(f" Speedup (INT8/FP32): {speedup_int8:.2f}x") + print(f" Speedup (FP8/FP32): {speedup_fp8:.2f}x") + + if int8_mean < fp8_mean: + print(f" ✓ INT8 is {int8_vs_fp8:.2f}x faster than FP8\n") + else: + print(f" ✓ FP8 is {1/int8_vs_fp8:.2f}x faster than INT8\n") + + results.append({ + "name": name, + "int8_mean": int8_mean, + "fp8_mean": fp8_mean, + "fp32_mean": fp32_mean, + "speedup_int8": speedup_int8, + "speedup_fp8": speedup_fp8, + "int8_vs_fp8": int8_vs_fp8, + "flops": flops, + }) + + # Clean up memory after each configuration + del input_fp32, weight_fp32, bias, input_int8, weight_int8 + if has_fp8: + del input_fp8, weight_fp8 + if 'int8_times' in locals(): + del int8_times, fp8_times, fp32_times + gc.collect() + torch.cuda.empty_cache() + + except RuntimeError as e: + if "out of memory" in str(e): + print(f" ⚠ OOM - skipping this configuration\n") + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + else: + raise + + # Print summary + print("\n" + "=" * 60) + print("Summary:") + print("=" * 60) + for result in results: + print(f"{result['name']:8s}: INT8 {result['int8_mean']:.3f}ms, " + f"FP8 {result['fp8_mean']:.3f}ms, " + f"FP32 {result['fp32_mean']:.3f}ms, " + f"Speedup (INT8/FP32): {result['speedup_int8']:.2f}x, " + f"(FP8/FP32): {result['speedup_fp8']:.2f}x") + + # Assertions for unittest + self.assertGreater(len(results), 0, "Should have collected benchmark results") + + @unittest.skip("perf benchmark only") + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_quantization_dequantization_runtime(self): + """Benchmark quantization and dequantization operations""" + device = torch.device('cuda') + torch.manual_seed(42) + + n_warmup = 5 + n_iters = 100 + + print(f"\nWarmup iterations: {n_warmup}") + print(f"Benchmark iterations: {n_iters}\n") + + # Test configurations - various tensor sizes + test_configs = [ + {"name": "Small Weight", "shape": (512, 512), "is_weight": True, "block": 128}, + {"name": "Medium Weight", "shape": (2048, 2048), "is_weight": True, "block": 128}, + {"name": "Large Weight", "shape": (4096, 4096), "is_weight": True, "block": 128}, + {"name": "XL Weight", "shape": (8192, 8192), "is_weight": True, "block": 128}, + {"name": "Small Activation", "shape": (8, 64, 512), "is_weight": False, "block": 128}, + {"name": "Medium Activation", "shape": (16, 128, 2048), "is_weight": False, "block": 128}, + {"name": "Large Activation", "shape": (32, 256, 4096), "is_weight": False, "block": 128}, + {"name": "XL Activation", "shape": (64, 512, 4096), "is_weight": False, "block": 128}, + ] + + print("=" * 60) + print("INT8 BlockWise Quantization/Dequantization") + print("=" * 60) + + results_int8 = [] + + for config in test_configs: + name = config["name"] + shape = config["shape"] + is_weight = config["is_weight"] + block_size = config["block"] + + try: + # Create test tensor + tensor_fp32 = torch.randn(shape, dtype=torch.float32, device=device) + tensor_size_mb = tensor_fp32.numel() * tensor_fp32.element_size() / 1024 / 1024 + + print(f"\n{name}: shape={shape}, size={tensor_size_mb:.2f}MB") + + # Warm up + for _ in range(n_warmup): + qt = QuantizedTensor.from_float(tensor_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=is_weight) + _ = qt.dequantize() + torch.cuda.synchronize() + + # Benchmark quantization + quant_times = [] + for _ in range(n_iters): + start = time.time() + qt = QuantizedTensor.from_float(tensor_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=is_weight) + torch.cuda.synchronize() + quant_times.append((time.time() - start) * 1000) + + # Benchmark dequantization (reuse last quantized tensor) + dequant_times = [] + for _ in range(n_iters): + start = time.time() + _ = qt.dequantize() + torch.cuda.synchronize() + dequant_times.append((time.time() - start) * 1000) + + # Calculate statistics + quant_times = torch.tensor(quant_times) + dequant_times = torch.tensor(dequant_times) + + quant_mean = quant_times.mean().item() + quant_std = quant_times.std().item() + quant_min = quant_times.min().item() + + dequant_mean = dequant_times.mean().item() + dequant_std = dequant_times.std().item() + dequant_min = dequant_times.min().item() + + # Calculate throughput (GB/s) + quant_throughput = (tensor_size_mb / 1024) / (quant_mean / 1000) + dequant_throughput = (tensor_size_mb / 1024) / (dequant_mean / 1000) + + print(f" Quantization: {quant_mean:.3f}±{quant_std:.3f} ms (min: {quant_min:.3f} ms) [{quant_throughput:.2f} GB/s]") + print(f" Dequantization: {dequant_mean:.3f}±{dequant_std:.3f} ms (min: {dequant_min:.3f} ms) [{dequant_throughput:.2f} GB/s]") + print(f" Total roundtrip: {quant_mean + dequant_mean:.3f} ms") + + # Calculate memory savings + qt_memory = qt._qdata.element_size() * qt._qdata.numel() + qt_memory += qt._layout_params['scale'].element_size() * qt._layout_params['scale'].numel() + fp32_memory = tensor_fp32.element_size() * tensor_fp32.numel() + reduction = fp32_memory / qt_memory + + print(f" Memory: FP32 {fp32_memory/1024/1024:.2f}MB -> INT8 {qt_memory/1024/1024:.2f}MB ({reduction:.2f}x reduction)") + + results_int8.append({ + "name": name, + "shape": shape, + "size_mb": tensor_size_mb, + "quant_mean": quant_mean, + "dequant_mean": dequant_mean, + "quant_throughput": quant_throughput, + "dequant_throughput": dequant_throughput, + "reduction": reduction, + }) + + # Clean up memory after each configuration + del tensor_fp32, qt + if 'quant_times' in locals(): + del quant_times, dequant_times + gc.collect() + torch.cuda.empty_cache() + + except RuntimeError as e: + if "out of memory" in str(e): + print(f"\n{name}: ⚠ OOM - skipping") + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + else: + raise + + # Summary + print() + print("=" * 60) + print("Summary: INT8 Quantization/Dequantization Performance") + print("=" * 60) + for result in results_int8: + print(f"{result['name']:20s}: Quant {result['quant_mean']:.3f}ms, " + f"Dequant {result['dequant_mean']:.3f}ms, " + f"Total {result['quant_mean'] + result['dequant_mean']:.3f}ms") + + # Assertions for unittest + self.assertGreater(len(results_int8), 0, "Should have collected benchmark results") + + @unittest.skip("perf benchmark only") + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_fp16_vs_int8_real_model_sizes(self): + """Compare FP16 vs INT8 vs FP8 on actual model sizes via torch dispatch""" + device = torch.device('cuda') + torch.manual_seed(42) + + # Check if FP8 is available + try: + test_tensor = torch.randn(16, 16, device='cuda', dtype=torch.float32) + _ = test_tensor.to(torch.float8_e4m3fn) + has_fp8 = True + print("✓ FP8 support detected") + except (RuntimeError, AttributeError): + has_fp8 = False + print("⚠ FP8 not supported on this system - will compare FP16 vs INT8 only") + + # Actual sizes from model dumps + test_configs = [ + # WAN 2.2 5B model sizes + { + "model": "WAN2.2-5B", + "name": "First layer (small batch)", + "input_shape": (2, 1, 3072), + "weight_shape": (18432, 3072), + "block_size": 128, + }, + { + "model": "WAN2.2-5B", + "name": "Attention layer (long seq)", + "input_shape": (2, 27280, 3072), + "weight_shape": (3072, 3072), + "block_size": 128, + }, + { + "model": "WAN2.2-5B", + "name": "MLP down projection (long seq)", + "input_shape": (2, 27280, 14336), + "weight_shape": (3072, 14336), + "block_size": 128, + }, + { + "model": "WAN2.2-5B", + "name": "MLP up projection (long seq)", + "input_shape": (2, 27280, 3072), + "weight_shape": (14336, 3072), + "block_size": 128, + }, + { + "model": "WAN2.2-5B", + "name": "Attention layer (medium seq)", + "input_shape": (2, 512, 3072), + "weight_shape": (3072, 3072), + "block_size": 128, + }, + # WAN 2.2 14B model sizes + { + "model": "WAN2.2-14B", + "name": "First layer (small batch)", + "input_shape": (2, 1, 5120), + "weight_shape": (30720, 5120), + "block_size": 128, + }, + { + "model": "WAN2.2-14B", + "name": "Attention layer (long seq)", + "input_shape": (2, 27280, 5120), + "weight_shape": (5120, 5120), + "block_size": 128, + }, + { + "model": "WAN2.2-14B", + "name": "Attention layer (medium seq)", + "input_shape": (2, 512, 5120), + "weight_shape": (5120, 5120), + "block_size": 128, + }, + { + "model": "WAN2.2-14B", + "name": "MLP up projection (long seq)", + "input_shape": (2, 27280, 5120), + "weight_shape": (13824, 5120), + "block_size": 128, + }, + { + "model": "WAN2.2-14B", + "name": "MLP down projection (long seq)", + "input_shape": (2, 27280, 13824), + "weight_shape": (5120, 13824), + "block_size": 128, + }, + ] + + n_warmup = 10 + n_iters = 100 + + print(f"\nWarmup iterations: {n_warmup}") + print(f"Benchmark iterations: {n_iters}\n") + + results = [] + current_model = None + + for config in test_configs: + model = config["model"] + name = config["name"] + input_shape = config["input_shape"] + weight_shape = config["weight_shape"] + block_size = config["block_size"] + + # Print model header when we switch models + if model != current_model: + print("\n" + "=" * 60) + print(f"{model} Model Layers") + print("=" * 60) + current_model = model + + print(f"\n{name}") + print(f" Input: {input_shape}, Weight: {weight_shape}") + + # Calculate FLOPS + batch, seq_len, in_features = input_shape + out_features, _ = weight_shape + m = batch * seq_len + k = in_features + n = out_features + flops = 2 * m * n * k + + try: + # Measure initial VRAM + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + initial_vram = torch.cuda.memory_allocated() / 1024 / 1024 # MB + + # Create test data in FP16 and FP32 + input_fp32 = torch.randn(input_shape, dtype=torch.float32, device=device) + input_fp16 = input_fp32.to(torch.float16) + + weight_fp32 = torch.randn(weight_shape, dtype=torch.float32, device=device) + weight_fp16 = weight_fp32.to(torch.float16) + + bias_fp32 = torch.randn(out_features, dtype=torch.float32, device=device) + bias_fp16 = bias_fp32.to(torch.float16) + + # Measure FP16 VRAM + fp16_vram = torch.cuda.memory_allocated() / 1024 / 1024 - initial_vram + + # Quantize to INT8 + input_int8 = QuantizedTensor.from_float(input_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + weight_int8 = QuantizedTensor.from_float(weight_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=True) + + # Measure INT8 VRAM (after creating quantized tensors, before releasing FP16) + int8_vram_with_fp16 = torch.cuda.memory_allocated() / 1024 / 1024 - initial_vram + + # Quantize to FP8 if available + if has_fp8: + input_fp8 = QuantizedTensor.from_float(input_fp32, "TensorCoreFP8Layout", dtype=torch.float8_e4m3fn) + weight_fp8 = QuantizedTensor.from_float(weight_fp32, "TensorCoreFP8Layout", dtype=torch.float8_e4m3fn) + fp8_vram_with_others = torch.cuda.memory_allocated() / 1024 / 1024 - initial_vram + + # Calculate memory usage + fp16_input_mem = input_fp16.element_size() * input_fp16.numel() + fp16_weight_mem = weight_fp16.element_size() * weight_fp16.numel() + fp16_total_mem = fp16_input_mem + fp16_weight_mem + + int8_input_mem = input_int8._qdata.element_size() * input_int8._qdata.numel() + int8_input_mem += input_int8._layout_params['scale'].element_size() * input_int8._layout_params['scale'].numel() + + int8_weight_mem = weight_int8._qdata.element_size() * weight_int8._qdata.numel() + int8_weight_mem += weight_int8._layout_params['scale'].element_size() * weight_int8._layout_params['scale'].numel() + + int8_total_mem = int8_input_mem + int8_weight_mem + mem_reduction = fp16_total_mem / int8_total_mem + + print(f" Tensor Memory: FP16 {fp16_total_mem/1024/1024:.2f}MB -> INT8 {int8_total_mem/1024/1024:.2f}MB ({mem_reduction:.2f}x reduction)") + print(f" VRAM Usage: FP16 {fp16_vram:.2f}MB, INT8 {int8_vram_with_fp16:.2f}MB (incl. FP16 tensors)") + + # Warm up + for _ in range(n_warmup): + _ = torch.nn.functional.linear(input_fp16, weight_fp16, bias_fp16) + _ = torch.nn.functional.linear(input_int8, weight_int8, bias_fp32) + if has_fp8: + out_fp8 = torch.nn.functional.linear(input_fp8, weight_fp8, None) + if bias_fp32 is not None: + _ = out_fp8 + bias_fp32 + torch.cuda.synchronize() + + # Clear any warmup artifacts + torch.cuda.empty_cache() + + # Benchmark FP16 + fp16_times = [] + for _ in range(n_iters): + start = time.time() + output_fp16 = torch.nn.functional.linear(input_fp16, weight_fp16, bias_fp16) + torch.cuda.synchronize() + fp16_times.append((time.time() - start) * 1000) + + # Benchmark INT8 (quantized output) + int8_times = [] + for _ in range(n_iters): + start = time.time() + output_int8 = torch.nn.functional.linear(input_int8, weight_int8, bias_fp32) + torch.cuda.synchronize() + int8_times.append((time.time() - start) * 1000) + + # Benchmark INT8 with dequantization + int8_dequant_times = [] + for _ in range(n_iters): + start = time.time() + output_int8 = torch.nn.functional.linear(input_int8, weight_int8, bias_fp32) + if isinstance(output_int8, QuantizedTensor): + output_int8 = output_int8.dequantize() + torch.cuda.synchronize() + int8_dequant_times.append((time.time() - start) * 1000) + + # Benchmark FP8 if available + if has_fp8: + fp8_times = [] + for _ in range(n_iters): + start = time.time() + out_fp8 = torch.nn.functional.linear(input_fp8, weight_fp8, None) + if bias_fp32 is not None: + out_fp8 = out_fp8 + bias_fp32 + # Dequantize if needed + if isinstance(out_fp8, QuantizedTensor): + out_fp8 = out_fp8.dequantize() + torch.cuda.synchronize() + fp8_times.append((time.time() - start) * 1000) + + # Clear benchmark outputs to free memory + if 'output_fp16' in locals(): + del output_fp16 + if 'output_int8' in locals(): + del output_int8 + if has_fp8 and 'out_fp8' in locals(): + del out_fp8 + torch.cuda.empty_cache() + + # Calculate statistics + fp16_times = torch.tensor(fp16_times) + int8_times = torch.tensor(int8_times) + int8_dequant_times = torch.tensor(int8_dequant_times) + + fp16_mean = fp16_times.mean().item() + fp16_std = fp16_times.std().item() + fp16_min = fp16_times.min().item() + + int8_mean = int8_times.mean().item() + int8_std = int8_times.std().item() + int8_min = int8_times.min().item() + + int8_dequant_mean = int8_dequant_times.mean().item() + int8_dequant_std = int8_dequant_times.std().item() + int8_dequant_min = int8_dequant_times.min().item() + + speedup_int8 = fp16_mean / int8_mean + speedup_int8_dequant = fp16_mean / int8_dequant_mean + + print(f" FP16: {fp16_mean:.3f}±{fp16_std:.3f} ms (min: {fp16_min:.3f} ms) [{flops/fp16_mean/1e9:.2f} GFLOPS]") + print(f" INT8 (quantized): {int8_mean:.3f}±{int8_std:.3f} ms (min: {int8_min:.3f} ms) [{flops/int8_mean/1e9:.2f} GFLOPS]") + print(f" INT8 (dequantized): {int8_dequant_mean:.3f}±{int8_dequant_std:.3f} ms (min: {int8_dequant_min:.3f} ms) [{flops/int8_dequant_mean/1e9:.2f} GFLOPS]") + print(f" Speedup vs FP16: {speedup_int8:.2f}x (quantized), {speedup_int8_dequant:.2f}x (dequantized)") + + if has_fp8: + fp8_times = torch.tensor(fp8_times) + fp8_mean = fp8_times.mean().item() + fp8_std = fp8_times.std().item() + fp8_min = fp8_times.min().item() + speedup_fp8 = fp16_mean / fp8_mean + + print(f" FP8 (dequantized): {fp8_mean:.3f}±{fp8_std:.3f} ms (min: {fp8_min:.3f} ms) [{flops/fp8_mean/1e9:.2f} GFLOPS]") + print(f" Speedup vs FP16: {speedup_fp8:.2f}x") + else: + fp8_mean = None + speedup_fp8 = None + + # Precision check + output_fp16_check = torch.nn.functional.linear(input_fp16, weight_fp16, bias_fp16) + output_int8_check = torch.nn.functional.linear(input_int8, weight_int8, bias_fp32) + if isinstance(output_int8_check, QuantizedTensor): + output_int8_check = output_int8_check.dequantize() + + # Convert FP16 output to FP32 for comparison + output_fp16_check_fp32 = output_fp16_check.to(torch.float32) + + # Compare INT8 vs FP16 (both in FP32 for fair comparison) + error_int8 = ((output_int8_check - output_fp16_check_fp32).abs() / (output_fp16_check_fp32.abs() + 1e-6)).mean() + print(f" Precision: INT8 vs FP16 mean relative error: {error_int8:.6f}") + + if has_fp8: + output_fp8_check = torch.nn.functional.linear(input_fp8, weight_fp8, None) + if bias_fp32 is not None: + output_fp8_check = output_fp8_check + bias_fp32 + if isinstance(output_fp8_check, QuantizedTensor): + output_fp8_check = output_fp8_check.dequantize() + + error_fp8 = ((output_fp8_check - output_fp16_check_fp32).abs() / (output_fp16_check_fp32.abs() + 1e-6)).mean() + print(f" Precision: FP8 vs FP16 mean relative error: {error_fp8:.6f}") + else: + error_fp8 = None + + results.append({ + "model": model, + "name": name, + "input_shape": input_shape, + "weight_shape": weight_shape, + "fp16_mean": fp16_mean, + "int8_mean": int8_mean, + "int8_dequant_mean": int8_dequant_mean, + "fp8_mean": fp8_mean, + "speedup_int8": speedup_int8, + "speedup_int8_dequant": speedup_int8_dequant, + "speedup_fp8": speedup_fp8, + "mem_reduction": mem_reduction, + "error_int8": error_int8.item(), + "error_fp8": error_fp8.item() if error_fp8 is not None else None, + "fp16_vram": fp16_vram, + "int8_vram": int8_vram_with_fp16, + }) + + # Aggressive memory cleanup after each configuration to avoid OOM + # Delete input/weight tensors + del input_fp32, input_fp16, weight_fp32, weight_fp16, bias_fp32, bias_fp16 + del input_int8, weight_int8 + if has_fp8: + del input_fp8, weight_fp8 + + # Delete precision check outputs + if 'output_fp16_check' in locals(): + del output_fp16_check, output_fp16_check_fp32, output_int8_check + if has_fp8 and 'output_fp8_check' in locals(): + del output_fp8_check + + # Delete timing tensors + if 'fp16_times' in locals(): + del fp16_times, int8_times, int8_dequant_times + if has_fp8 and 'fp8_times' in locals(): + del fp8_times + + # Force Python garbage collection + gc.collect() + # Clear CUDA cache + torch.cuda.empty_cache() + # Synchronize to ensure cleanup is complete + torch.cuda.synchronize() + + except RuntimeError as e: + if "out of memory" in str(e): + print(f" ⚠ OOM - skipping this configuration") + # Ultra-aggressive cleanup on OOM + # Delete any lingering tensors from failed iteration + for var_name in list(locals().keys()): + if 'tensor' in var_name.lower() or var_name.endswith(('_fp16', '_fp32', '_int8', '_fp8')): + try: + del locals()[var_name] + except: + pass + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + else: + raise + + # Summary table + print("\n" + "=" * 80) + if has_fp8: + print("Summary: FP16 vs INT8 vs FP8 Performance") + else: + print("Summary: FP16 vs INT8 Performance") + print("=" * 80) + + if results: + # Group results by model + models = {} + for result in results: + model = result["model"] + if model not in models: + models[model] = [] + models[model].append(result) + + # Print results grouped by model + for model_name, model_results in models.items(): + print(f"\n{model_name}:") + if has_fp8: + print(f"{'Layer':<25s} {'FP16':<10s} {'INT8':<10s} {'FP8':<10s} {'Speedup':<20s} {'Mem':<8s}") + else: + print(f"{'Layer':<30s} {'FP16 (ms)':<12s} {'INT8 (ms)':<12s} {'Speedup':<10s} {'Memory':<10s}") + print("-" * 80) + + for result in model_results: + layer_name = result["name"][:23] if has_fp8 else result["name"][:28] + if has_fp8 and result['fp8_mean'] is not None: + print(f"{layer_name:<25s} {result['fp16_mean']:>8.3f}ms {result['int8_dequant_mean']:>8.3f}ms {result['fp8_mean']:>8.3f}ms " + f"INT8:{result['speedup_int8_dequant']:>5.2f}x FP8:{result['speedup_fp8']:>5.2f}x {result['mem_reduction']:>6.2f}x") + else: + print(f"{layer_name:<30s} {result['fp16_mean']:>10.3f} {result['int8_dequant_mean']:>10.3f} {result['speedup_int8_dequant']:>8.2f}x {result['mem_reduction']:>8.2f}x") + + # Calculate per-model total + model_fp16_time = sum(r["fp16_mean"] for r in model_results) + model_int8_time = sum(r["int8_dequant_mean"] for r in model_results) + model_speedup_int8 = model_fp16_time / model_int8_time if model_int8_time > 0 else 0 + + print("-" * 80) + if has_fp8 and any(r['fp8_mean'] is not None for r in model_results): + model_fp8_time = sum(r["fp8_mean"] for r in model_results if r["fp8_mean"] is not None) + model_speedup_fp8 = model_fp16_time / model_fp8_time if model_fp8_time > 0 else 0 + print(f"{'SUBTOTAL':<25s} {model_fp16_time:>8.3f}ms {model_int8_time:>8.3f}ms {model_fp8_time:>8.3f}ms " + f"INT8:{model_speedup_int8:>5.2f}x FP8:{model_speedup_fp8:>5.2f}x") + else: + print(f"{'SUBTOTAL':<30s} {model_fp16_time:>10.3f} {model_int8_time:>10.3f} {model_speedup_int8:>8.2f}x") + + print(f" {model_name} avg memory reduction: {sum(r['mem_reduction'] for r in model_results) / len(model_results):.2f}x") + print(f" {model_name} avg INT8 precision error: {sum(r['error_int8'] for r in model_results) / len(model_results):.6f}") + if has_fp8 and any(r['error_fp8'] is not None for r in model_results): + fp8_errors = [r['error_fp8'] for r in model_results if r['error_fp8'] is not None] + if fp8_errors: + print(f" {model_name} avg FP8 precision error: {sum(fp8_errors) / len(fp8_errors):.6f}") + + # VRAM analysis + total_fp16_vram = sum(r['fp16_vram'] for r in model_results) + total_int8_vram = sum(r['int8_vram'] for r in model_results) + print(f" {model_name} VRAM usage: FP16 {total_fp16_vram:.2f}MB, INT8 {total_int8_vram:.2f}MB (during inference with both)") + + # Calculate overall totals + total_fp16_time = sum(r["fp16_mean"] for r in results) + total_int8_time = sum(r["int8_dequant_mean"] for r in results) + overall_speedup_int8 = total_fp16_time / total_int8_time if total_int8_time > 0 else 0 + + print("\n" + "=" * 80) + if has_fp8 and any(r['fp8_mean'] is not None for r in results): + total_fp8_time = sum(r["fp8_mean"] for r in results if r["fp8_mean"] is not None) + overall_speedup_fp8 = total_fp16_time / total_fp8_time if total_fp8_time > 0 else 0 + print(f"{'GRAND TOTAL':<25s} {total_fp16_time:>8.3f}ms {total_int8_time:>8.3f}ms {total_fp8_time:>8.3f}ms " + f"INT8:{overall_speedup_int8:>5.2f}x FP8:{overall_speedup_fp8:>5.2f}x") + else: + print(f"{'GRAND TOTAL':<30s} {total_fp16_time:>10.3f} {total_int8_time:>10.3f} {overall_speedup_int8:>8.2f}x") + print("=" * 80) + + print(f"\n✓ Overall INT8 speedup: {overall_speedup_int8:.2f}x faster than FP16") + if has_fp8 and any(r['fp8_mean'] is not None for r in results): + print(f"✓ Overall FP8 speedup: {overall_speedup_fp8:.2f}x faster than FP16") + print(f"✓ Average memory reduction: {sum(r['mem_reduction'] for r in results) / len(results):.2f}x") + print(f"✓ Average INT8 precision error: {sum(r['error_int8'] for r in results) / len(results):.6f}") + if has_fp8: + fp8_errors = [r['error_fp8'] for r in results if r['error_fp8'] is not None] + if fp8_errors: + print(f"✓ Average FP8 precision error: {sum(fp8_errors) / len(fp8_errors):.6f}") + + # Total VRAM + total_fp16_vram = sum(r['fp16_vram'] for r in results) + total_int8_vram = sum(r['int8_vram'] for r in results) + print(f"✓ Total VRAM: FP16 {total_fp16_vram:.2f}MB, INT8 {total_int8_vram:.2f}MB") + + # Assertions for unittest + self.assertGreater(len(results), 0, "Should have collected benchmark results") + self.assertGreater(overall_speedup_int8, 0.5, "INT8 should have reasonable performance") + + @unittest.skip("perf benchmark only") + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_systematic_benchmark(self): + """Comprehensive systematic benchmark across multiple dimensions""" + device = torch.device('cuda') + torch.manual_seed(42) + + n_warmup = 10 + n_iters = 100 + + print(f"\nWarmup iterations: {n_warmup}") + print(f"Benchmark iterations: {n_iters}\n") + + # Test 1: Varying batch size (typical transformer forward pass) + print("=" * 60) + print("Dimension 1: Varying Batch Size") + print("=" * 60) + batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128] + seq_len = 64 + in_features = 1024 + out_features = 1024 + block_size = 128 + + for batch_size in batch_sizes: + try: + input_fp32 = torch.randn(batch_size, seq_len, in_features, dtype=torch.float32, device=device) + weight_fp32 = torch.randn(out_features, in_features, dtype=torch.float32, device=device) + bias = torch.randn(out_features, dtype=torch.float32, device=device) + + input_int8 = QuantizedTensor.from_float(input_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + weight_int8 = QuantizedTensor.from_float(weight_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=True) + + # Warm up + for _ in range(n_warmup): + _ = torch.nn.functional.linear(input_int8, weight_int8, bias) + _ = torch.nn.functional.linear(input_fp32, weight_fp32, bias) + torch.cuda.synchronize() + + # Benchmark + int8_times = [] + for _ in range(n_iters): + start = time.time() + _ = torch.nn.functional.linear(input_int8, weight_int8, bias) + torch.cuda.synchronize() + int8_times.append((time.time() - start) * 1000) + + fp32_times = [] + for _ in range(n_iters): + start = time.time() + _ = torch.nn.functional.linear(input_fp32, weight_fp32, bias) + torch.cuda.synchronize() + fp32_times.append((time.time() - start) * 1000) + + int8_mean = torch.tensor(int8_times).mean().item() + fp32_mean = torch.tensor(fp32_times).mean().item() + speedup = fp32_mean / int8_mean + + m = batch_size * seq_len + k = in_features + n = out_features + flops = 2 * m * n * k + + print(f"Batch={batch_size:3d}: INT8 {int8_mean:.3f}ms, FP32 {fp32_mean:.3f}ms, Speedup: {speedup:.2f}x, [{flops/int8_mean/1e9:.2f} GFLOPS]") + + # Clean up after each test + del input_fp32, weight_fp32, bias, input_int8, weight_int8 + gc.collect() + torch.cuda.empty_cache() + + except RuntimeError as e: + if "out of memory" in str(e): + print(f"Batch={batch_size:3d}: ⚠ OOM") + import gc + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + break + else: + raise + + print() + + # Test 2: Varying sequence length + print("=" * 60) + print("Dimension 2: Varying Sequence Length") + print("=" * 60) + seq_lengths = [8, 16, 32, 64, 128, 256, 512, 1024, 2048] + batch_size = 8 + in_features = 1024 + out_features = 1024 + block_size = 128 + + for seq_len in seq_lengths: + try: + input_fp32 = torch.randn(batch_size, seq_len, in_features, dtype=torch.float32, device=device) + weight_fp32 = torch.randn(out_features, in_features, dtype=torch.float32, device=device) + bias = torch.randn(out_features, dtype=torch.float32, device=device) + + input_int8 = QuantizedTensor.from_float(input_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + weight_int8 = QuantizedTensor.from_float(weight_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=True) + + # Warm up + for _ in range(n_warmup): + _ = torch.nn.functional.linear(input_int8, weight_int8, bias) + _ = torch.nn.functional.linear(input_fp32, weight_fp32, bias) + torch.cuda.synchronize() + + # Benchmark + int8_times = [] + for _ in range(n_iters): + start = time.time() + _ = torch.nn.functional.linear(input_int8, weight_int8, bias) + torch.cuda.synchronize() + int8_times.append((time.time() - start) * 1000) + + fp32_times = [] + for _ in range(n_iters): + start = time.time() + _ = torch.nn.functional.linear(input_fp32, weight_fp32, bias) + torch.cuda.synchronize() + fp32_times.append((time.time() - start) * 1000) + + int8_mean = torch.tensor(int8_times).mean().item() + fp32_mean = torch.tensor(fp32_times).mean().item() + speedup = fp32_mean / int8_mean + + m = batch_size * seq_len + k = in_features + n = out_features + flops = 2 * m * n * k + + print(f"SeqLen={seq_len:4d}: INT8 {int8_mean:.3f}ms, FP32 {fp32_mean:.3f}ms, Speedup: {speedup:.2f}x, [{flops/int8_mean/1e9:.2f} GFLOPS]") + + # Clean up after each test + del input_fp32, weight_fp32, bias, input_int8, weight_int8 + gc.collect() + torch.cuda.empty_cache() + + except RuntimeError as e: + if "out of memory" in str(e): + print(f"SeqLen={seq_len:4d}: ⚠ OOM") + import gc + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + break + else: + raise + + print() + + # Test 3: Varying hidden dimensions + print("=" * 60) + print("Dimension 3: Varying Hidden Dimensions") + print("=" * 60) + hidden_dims = [256, 512, 768, 1024, 1536, 2048, 3072, 4096, 8192] + batch_size = 8 + seq_len = 64 + block_size = 128 + + for hidden_dim in hidden_dims: + try: + input_fp32 = torch.randn(batch_size, seq_len, hidden_dim, dtype=torch.float32, device=device) + weight_fp32 = torch.randn(hidden_dim, hidden_dim, dtype=torch.float32, device=device) + bias = torch.randn(hidden_dim, dtype=torch.float32, device=device) + + input_int8 = QuantizedTensor.from_float(input_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + weight_int8 = QuantizedTensor.from_float(weight_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=True) + + # Warm up + for _ in range(n_warmup): + _ = torch.nn.functional.linear(input_int8, weight_int8, bias) + _ = torch.nn.functional.linear(input_fp32, weight_fp32, bias) + torch.cuda.synchronize() + + # Benchmark + int8_times = [] + for _ in range(n_iters): + start = time.time() + _ = torch.nn.functional.linear(input_int8, weight_int8, bias) + torch.cuda.synchronize() + int8_times.append((time.time() - start) * 1000) + + fp32_times = [] + for _ in range(n_iters): + start = time.time() + _ = torch.nn.functional.linear(input_fp32, weight_fp32, bias) + torch.cuda.synchronize() + fp32_times.append((time.time() - start) * 1000) + + int8_mean = torch.tensor(int8_times).mean().item() + fp32_mean = torch.tensor(fp32_times).mean().item() + speedup = fp32_mean / int8_mean + + m = batch_size * seq_len + k = hidden_dim + n = hidden_dim + flops = 2 * m * n * k + + print(f"Hidden={hidden_dim:4d}: INT8 {int8_mean:.3f}ms, FP32 {fp32_mean:.3f}ms, Speedup: {speedup:.2f}x, [{flops/int8_mean/1e9:.2f} GFLOPS]") + + # Clean up after each test + del input_fp32, weight_fp32, bias, input_int8, weight_int8 + gc.collect() + torch.cuda.empty_cache() + + except RuntimeError as e: + if "out of memory" in str(e): + print(f"Hidden={hidden_dim:4d}: ⚠ OOM") + import gc + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + break + else: + raise + + print() + + # Test 4: Varying block size + print("=" * 60) + print("Dimension 4: Varying Block Size") + print("=" * 60) + block_sizes = [32, 64, 128, 256, 512] + batch_size = 8 + seq_len = 64 + in_features = 1024 + out_features = 1024 + + for block_size in block_sizes: + try: + input_fp32 = torch.randn(batch_size, seq_len, in_features, dtype=torch.float32, device=device) + weight_fp32 = torch.randn(out_features, in_features, dtype=torch.float32, device=device) + bias = torch.randn(out_features, dtype=torch.float32, device=device) + + input_int8 = QuantizedTensor.from_float(input_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + weight_int8 = QuantizedTensor.from_float(weight_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=True) + + # Warm up + for _ in range(n_warmup): + _ = torch.nn.functional.linear(input_int8, weight_int8, bias) + torch.cuda.synchronize() + + # Benchmark + int8_times = [] + for _ in range(n_iters): + start = time.time() + _ = torch.nn.functional.linear(input_int8, weight_int8, bias) + torch.cuda.synchronize() + int8_times.append((time.time() - start) * 1000) + + int8_mean = torch.tensor(int8_times).mean().item() + int8_std = torch.tensor(int8_times).std().item() + + m = batch_size * seq_len + k = in_features + n = out_features + flops = 2 * m * n * k + + print(f"Block={block_size:3d}: INT8 {int8_mean:.3f}±{int8_std:.3f}ms, [{flops/int8_mean/1e9:.2f} GFLOPS]") + + # Clean up after each test + del input_fp32, weight_fp32, bias, input_int8, weight_int8 + gc.collect() + torch.cuda.empty_cache() + + except RuntimeError as e: + if "out of memory" in str(e): + print(f"Block={block_size:3d}: ⚠ OOM") + import gc + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + break + else: + raise + + print() + print("✓ Systematic benchmark completed!") + + @unittest.skip("perf benchmark only") + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_gelu_benchmark(self): + """Benchmark INT8 GELU vs FP16 GELU""" + # See test_int8_gelu.py::benchmark_int8_gelu for full implementation + self.skipTest("Benchmark test - run separately") + + @unittest.skip("perf benchmark only") + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_gelu_systematic_benchmark(self): + """Systematic GELU benchmark across different dimensions""" + # See test_int8_gelu.py::benchmark_int8_gelu_systematic for full implementation + self.skipTest("Benchmark test - run separately") + + @unittest.skip("perf benchmark only") + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_gelu_real_model_sizes(self): + """Test FP16 vs INT8 GELU on actual model sizes""" + # See test_int8_gelu.py::test_fp16_vs_int8_real_model_sizes for full implementation + self.skipTest("Benchmark test - run separately") + + @unittest.skip("perf benchmark only") + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_quant_fusion_performance(self): + """Compare performance of fused vs separate quantization""" + # See test_int8_quant_fusion.py::test_performance_comparison for full implementation + self.skipTest("Benchmark test - run separately") + + if __name__ == "__main__": unittest.main()