diff --git a/QUANTIZATION.md b/QUANTIZATION.md index 1693e13f32e2..177651a20235 100644 --- a/QUANTIZATION.md +++ b/QUANTIZATION.md @@ -124,6 +124,10 @@ We define 4 possible scaling parameters that should cover most recipes in the ne | Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale | |--------|---------------|--------------|----------------|-----------------|-------------| | float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) | +| int8_blockwise | int8 | float32 (per-block) | - | - | - | + +For int8_blockwise with block_size=128 and weight shape (N, K): +- weight_scale shape: (N//128, K//128) You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS). @@ -131,7 +135,9 @@ You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS). The metadata stored alongside the checkpoint contains: - **format_version**: String to define a version of the standard -- **layers**: A dictionary mapping layer names to their quantization format. The format string maps to the definitions found in `QUANT_ALGOS`. +- **layers**: A dictionary mapping layer names to their quantization configuration. Each layer's config is a dictionary with: + - **format**: Quantization format string that maps to the definitions found in `QUANT_ALGOS` + - **group_size** (optional): Block size for block-wise quantization schemes (e.g., int8_blockwise) Example: ```json @@ -139,9 +145,9 @@ Example: "_quantization_metadata": { "format_version": "1.0", "layers": { - "model.layers.0.mlp.up_proj": "float8_e4m3fn", - "model.layers.0.mlp.down_proj": "float8_e4m3fn", - "model.layers.1.mlp.up_proj": "float8_e4m3fn" + "model.layers.0.mlp.up_proj": {"format": "float8_e4m3fn"}, + "model.layers.0.mlp.down_proj": {"format": "int8_blockwise", "group_size": 128}, + "model.layers.1.mlp.up_proj": {"format": "int8_blockwise", "group_size": 256} } } } diff --git a/comfy/float.py b/comfy/float.py index 521316fd2fac..3caa22909d45 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -54,6 +54,8 @@ def stochastic_rounding(value, dtype, seed=0): return value.to(dtype=torch.float16) if dtype == torch.bfloat16: return value.to(dtype=torch.bfloat16) + if dtype == torch.int8: + return value.to(dtype=torch.int8) if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2: generator = torch.Generator(device=value.device) generator.manual_seed(seed) diff --git a/comfy/int8_kernels.py b/comfy/int8_kernels.py new file mode 100644 index 000000000000..8573bd5de4ee --- /dev/null +++ b/comfy/int8_kernels.py @@ -0,0 +1,1194 @@ +import torch +import triton +import triton.language as tl +from triton import Config +from typing import Tuple + + +""" +simplified explanation of the scaled int8 matmul algorithm +adopted from deepseek scaled FP8 matmul and jetfire paper +https://arxiv.org/abs/2403.12422 +https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py + + N dimension → + INT8 weights scaler per block + ┌-----┬-----┬─────┬─────┐ ┌-----┬-----┬─────┬─────┐ + : b00 : b01 : b02 | b03 | : : : | | + ├-----┼-----┼─────┼─────┤ :b_s00:b_s10:b_s20|b_s30| + K : b10 : b11 : b12 | b13 | : : : | | + dim ├-----┼-----┼─────┼─────┤ ├-----┼-----┼─────┼─────┤ + ↓ | b20 | b21 | b22 | b23 | | | | | | + ├─────┼─────┼─────┼─────┤ |b_s01|b_s11|b_s21|b_s31| + | b30 | b31 | b32 | b33 | | | | | | + └─────┴─────┴─────┴─────┘ └─────┴─────┴─────┴─────┘ + ┌-----┬-----┐ + : b00 : b01 : + ├─── blk ───┤ ├-----┼-----┤ + : b10 : b11 : + K dimension → └-----┴-----┘ + INT8 activations + ┌-----┬-----┬─────┬─────┐ 
┌-----┬-----┐ ┌-----┬-----┐ ┌-----------┐ ┌-----┬-----┐ ┌-----┬-----┐ + : a00 : a01 : a02 | a03 | : a00 : a01 : : @ : @ : : a_s00 : : : : :acc00:acc01: + ├-----┼-----┼─────┼─────┤ ├-----┼-----┤ ├-----┼-----┤ * ├-----------┤ * :b_s00:b_s10: = ├-----┼-----┤ + M : a10 : a11 : a12 | a13 | : a10 : a11 : : @ : @ : : a_s10 : : : : :acc10:acc11: +dim ├-----┼-----┼─────┼─────┤ └-----┴-----┘ └-----┴-----┘ └-----------┘ └-----┴-----┘ └-----┴-----┘ + ↓ | a20 | a21 | a22 | a23 | INT8 matmul acc in INT32 rescale the FP32 intermediate accumulate + ├─────┼─────┼─────┼─────┤ then cast to FP32 "rank 1" hadamard scaler intermediate + | a30 | a31 | a32 | a33 | + └─────┴─────┴─────┴─────┘ + scaler per block + ┌-----------┬───────────┐ + : a_s00 : a_s01 | + ├-----------┼───────────┤ + : a_s10 : a_s11 | + ├-----------┼───────────┤ + | a_s20 | a_s21 | + ├───────────┼───────────┤ + | a_s30 | a_s31 | + └───────────┴───────────┘ +""" + + +@triton.jit +def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): + """ + Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`. + + Args: + x_ptr (triton.Pointer): Pointer to the input tensor. + y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored. + s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored. + BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance. + + Returns: + None + """ + pid = tl.program_id(axis=0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + offs).to(tl.float32) + amax = tl.max(tl.abs(x)) # reduction + # amax = tl.maximum(amax, 1e-4) # clamp to 1e-4 + s = amax / 127.0 + y = x / s + y = y.to(y_ptr.dtype.element_ty) + tl.store(y_ptr + offs, y) + tl.store(s_ptr + pid, s) + + +def act_quant( + x: torch.Tensor, block_size: int = 128 +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantizes the input tensor `x` using block-wise quantization. + + Args: + x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`. + block_size (int, optional): The size of the blocks to be used for quantization. Default is 128. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The quantized tensor with dtype `torch.int8`. + - A tensor of scaling factors with dtype `torch.float32`. + """ + assert x.is_contiguous(), "Input tensor must be contiguous" + assert ( + x.size(-1) % block_size == 0 + ), f"Last dimension size must be divisible by block_size (block_size={block_size})" + y = torch.empty_like(x, dtype=torch.int8) + s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32) + # Grid size should match number of scale elements (one program per block) + # Each program processes block_size elements and writes one scale value + num_programs = s.numel() # Number of blocks = number of scale elements + grid = lambda meta: (num_programs,) + act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size) + return y, s + + +@triton.jit +def act_dequant_kernel(x_ptr, s_ptr, y_ptr, BLOCK_SIZE: tl.constexpr): + """ + Dequantizes the input tensor `x_ptr` using scaling factors from `s_ptr`. + + Args: + x_ptr (triton.Pointer): Pointer to the quantized input tensor. + s_ptr (triton.Pointer): Pointer to the scaling factors. + y_ptr (triton.Pointer): Pointer to the output tensor where dequantized values will be stored. 
+ BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance. + + Returns: + None + """ + pid = tl.program_id(axis=0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + offs).to(tl.float32) + s = tl.load(s_ptr + pid) + y = x * s + y = y.to(y_ptr.dtype.element_ty) + tl.store(y_ptr + offs, y) + + +def act_dequant( + x: torch.Tensor, s: torch.Tensor, block_size: int = 128, output_dtype: torch.dtype = None +) -> torch.Tensor: + """ + Dequantizes the activation tensor `x` using the provided scale tensor. + + Args: + x (torch.Tensor): The quantized activation tensor. Must be contiguous and its last dimension size must be divisible by `block_size`. + s (torch.Tensor): The scale tensor with shape (*batch_dims, last_dim // block_size). + block_size (int, optional): The block size to use for dequantization. Defaults to 128. + output_dtype (torch.dtype, optional): Target dtype for output. Defaults to torch.get_default_dtype(). + + Returns: + torch.Tensor: The dequantized activation tensor of the same shape as `x`. + """ + assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous" + assert ( + x.size(-1) % block_size == 0 + ), f"Last dimension size must be divisible by block_size (block_size={block_size})" + + if output_dtype is None: + output_dtype = torch.get_default_dtype() + + y = torch.empty_like(x, dtype=output_dtype) + # Grid size should match number of scale elements (one program per block) + num_programs = s.numel() # Number of blocks = number of scale elements + grid = lambda meta: (num_programs,) + act_dequant_kernel[grid](x, s, y, BLOCK_SIZE=block_size) + return y + + +@triton.jit +def weight_quant_kernel(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr): + """ + Quantizes weights using block-wise quantization. + + Args: + x_ptr (tl.pointer): Pointer to the input weights. + y_ptr (tl.pointer): Pointer to the output buffer for quantized weights. + s_ptr (tl.pointer): Pointer to the output buffer for scaling factors. + M (int): Number of rows in the weight matrix. + N (int): Number of columns in the weight matrix. + BLOCK_SIZE (tl.constexpr): Size of the block for quantization. + + Returns: + None + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + n = tl.cdiv(N, BLOCK_SIZE) + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs = offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + x = tl.load(x_ptr + offs, mask=mask, other=0.0).to(tl.float32) + + # Compute per-block absolute maximum + amax = tl.max(tl.abs(x)) + s = amax / 127.0 + #s = tl.maximum(s, 1e-8) # Prevent division by zero + + # Quantize + y = x / s + #y = tl.maximum(tl.minimum(y, 127.0), -127.0) # Clamp + y = y.to(y_ptr.dtype.element_ty) + + tl.store(y_ptr + offs, y, mask=mask) + tl.store(s_ptr + pid_m * n + pid_n, s) + + +def weight_quant( + x: torch.Tensor, block_size: int = 128 +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantizes the weight tensor using block-wise quantization. + + Args: + x (torch.Tensor): The weight tensor of shape (M, N). + block_size (int, optional): The block size to use for quantization. Defaults to 128. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The quantized tensor with dtype `torch.int8`. + - A tensor of scaling factors with shape (M//block_size, N//block_size) and dtype `torch.float32`. 
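+
+    Example (illustrative sketch; assumes a 256x256 CUDA tensor and the default block_size=128):
+        >>> w = torch.randn(256, 256, device="cuda")
+        >>> w_q, w_s = weight_quant(w, block_size=128)
+        >>> w_q.dtype, w_s.shape
+        (torch.int8, torch.Size([2, 2]))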
+ + Raises: + AssertionError: If `x` is not contiguous or if its dimensions are not 2. + """ + assert x.is_contiguous(), "Input tensor must be contiguous" + assert x.dim() == 2, "Input tensor must have 2 dimensions" + M, N = x.size() + assert M % block_size == 0 and N % block_size == 0, \ + f"Dimensions must be divisible by block_size={block_size}, got shape {x.shape}" + + y = torch.empty_like(x, dtype=torch.int8) + s = x.new_empty(M // block_size, N // block_size, dtype=torch.float32) + + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_SIZE"]), + triton.cdiv(N, meta["BLOCK_SIZE"]), + ) + weight_quant_kernel[grid](x, y, s, M, N, BLOCK_SIZE=block_size) + return y, s + + +@triton.jit +def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): + """ + Dequantizes weights using the provided scaling factors and stores the result. + + Args: + x_ptr (tl.pointer): Pointer to the quantized weights. + s_ptr (tl.pointer): Pointer to the scaling factors. + y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights. + M (int): Number of rows in the weight matrix. + N (int): Number of columns in the weight matrix. + BLOCK_SIZE (tl.constexpr): Size of the block for tiling. + + Returns: + None + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + n = tl.cdiv(N, BLOCK_SIZE) + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs = offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) + s = tl.load(s_ptr + pid_m * n + pid_n) + y = x * s + tl.store(y_ptr + offs, y, mask=mask) + + +def weight_dequant( + x: torch.Tensor, s: torch.Tensor, block_size: int = 128, output_dtype: torch.dtype = None +) -> torch.Tensor: + """ + Dequantizes the given weight tensor using the provided scale tensor. + + Args: + x (torch.Tensor): The quantized weight tensor of shape (M, N). + s (torch.Tensor): The scale tensor of shape (M//block_size, N//block_size). + block_size (int, optional): The block size to use for dequantization. Defaults to 128. + output_dtype (torch.dtype, optional): Target dtype for output. Defaults to torch.get_default_dtype(). + + Returns: + torch.Tensor: The dequantized weight tensor of the same shape as `x`. + + Raises: + AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2. 
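+
+    Example (illustrative sketch; assumes CUDA tensors produced by `weight_quant`):
+        >>> w_q, w_s = weight_quant(torch.randn(256, 256, device="cuda"), block_size=128)
+        >>> w = weight_dequant(w_q, w_s, block_size=128, output_dtype=torch.float16)
+        >>> w.shape, w.dtype
+        (torch.Size([256, 256]), torch.float16)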
+ """ + assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous" + assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions" + M, N = x.size() + + if output_dtype is None: + output_dtype = torch.get_default_dtype() + + y = torch.empty_like(x, dtype=output_dtype) + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_SIZE"]), + triton.cdiv(N, meta["BLOCK_SIZE"]), + ) + weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) + return y + + +# matmul intermediate block size is hardcoded to 128 +int8_gemm_configs = [ + Config( + {"BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n, "BLOCK_SIZE_K": 128}, + num_stages=num_stages, + num_warps=8, + ) + for block_m in [128, 256] # >= 128 for consistency with out_block_size + for block_n in [128, 256] # >= 128 required for out_block_size compatibility + for num_stages in [3, 4, 5] +] + + +#@triton.autotune(configs=int8_gemm_configs, key=["N", "K"]) +@triton.jit +def int8_gemm_kernel( + a_ptr, + b_ptr, + c_ptr, + a_s_ptr, + b_s_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + """ + Performs a matrix multiplication operation on INT8 matrices with scaling factors. + + Args: + a_ptr (tl.tensor): Pointer to the first input matrix A. + b_ptr (tl.tensor): Pointer to the second input matrix B. + c_ptr (tl.tensor): Pointer to the output matrix C. + a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A. + b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B. + M (int): Number of rows in matrix A and C. + N (tl.constexpr): Number of columns in matrix B and C. + K (tl.constexpr): Number of columns in matrix A and rows in matrix B. + BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension. + BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension. + BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension. 
+ + Returns: + None + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + k = tl.cdiv(K, BLOCK_SIZE_K) + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] + b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] + a_s_ptrs = a_s_ptr + offs_m * k + + # FIXED: Weight scale indexing for 2D scale array (N_blocks, K_blocks) + # b_s has shape (N//BLOCK_SIZE_K, K//BLOCK_SIZE_K) stored in row-major + # For N tile pid_n, we need scales[pid_n, :] across K iterations + # Address calculation: scale[pid_n, i] = base + pid_n * stride + i + k_blocks = k # Number of K blocks for clarity + b_s_base = b_s_ptr + pid_n * k_blocks + + # Create accumulators outside the loop for better performance + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for i in range(k_blocks): + # Load int8 data - use other=0 (int) not 0.0 (float) to preserve int8 type + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0) + a_s = tl.load(a_s_ptrs) + # FIXED: Load single scalar weight scale for (pid_n, i) block pair + b_s = tl.load(b_s_base + i) + # INT8 matmul → INT32 acc, then cast to FP32 and apply per-block scaling + dot_prod = tl.dot(a, b, out_dtype=tl.int32) # int8 × int8 → int32 + accumulator += dot_prod.to(tl.float32) * a_s[:, None] * b_s + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + a_s_ptrs += 1 + c = accumulator.to(c_ptr.dtype.element_ty) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + tl.store(c_ptrs, c, mask=mask) + + +#@triton.autotune(configs=int8_gemm_configs, key=["N", "K"]) +@triton.jit +def int8_gemm_addmm_kernel( + a_ptr, + b_ptr, + c_ptr, + bias_ptr, + a_s_ptr, + b_s_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + HAS_BIAS: tl.constexpr, +): + """ + Fused INT8 matrix multiplication with bias addition (addmm). + Computes: C = A @ B + bias + + This kernel fuses the bias addition into the matmul, avoiding an extra memory write/read cycle. + + Args: + a_ptr (tl.tensor): Pointer to the first input matrix A (INT8). + b_ptr (tl.tensor): Pointer to the second input matrix B (INT8). + c_ptr (tl.tensor): Pointer to the output matrix C. + bias_ptr (tl.tensor): Pointer to the bias vector (1D, length N). + a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A. + b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B. + M (int): Number of rows in matrix A and C. + N (tl.constexpr): Number of columns in matrix B and C. + K (tl.constexpr): Number of columns in matrix A and rows in matrix B. + BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension. + BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension. + BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension. + HAS_BIAS (tl.constexpr): Whether bias is provided. 
+ + Returns: + None + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + k = tl.cdiv(K, BLOCK_SIZE_K) + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] + b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] + a_s_ptrs = a_s_ptr + offs_m * k + + # FIXED: Weight scale indexing for 2D scale array (N_blocks, K_blocks) + # b_s has shape (N//BLOCK_SIZE_K, K//BLOCK_SIZE_K) stored in row-major + # For N tile pid_n, we need scales[pid_n, :] across K iterations + # Address calculation: scale[pid_n, i] = base + pid_n * stride + i + k_blocks = k # Number of K blocks for clarity + b_s_base = b_s_ptr + pid_n * k_blocks + + # Accumulate matmul result + # Create accumulators outside the loop for better performance + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for i in range(k_blocks): + # Load int8 data - use other=0 (int) not 0.0 (float) to preserve int8 type + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0) + a_s = tl.load(a_s_ptrs) + # FIXED: Load single scalar weight scale for (pid_n, i) block pair + b_s = tl.load(b_s_base + i) + # INT8 matmul → INT32 acc, then cast to FP32 and apply per-block scaling + dot_prod = tl.dot(a, b, out_dtype=tl.int32) # int8 × int8 → int32 + accumulator += dot_prod.to(tl.float32) * a_s[:, None] * b_s + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + a_s_ptrs += 1 + + # Add bias if provided (fused operation) + if HAS_BIAS: + bias_ptrs = bias_ptr + offs_n[None, :] + bias = tl.load(bias_ptrs, mask=offs_n[None, :] < N, other=0.0) + accumulator += bias # Broadcast bias across M dimension + + # Store result + c = accumulator.to(c_ptr.dtype.element_ty) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + tl.store(c_ptrs, c, mask=mask) + + +def int8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor): + """ + Perform a matrix multiplication using INT8 precision. + + Expected tensor shapes: + - a: [..., K] where ... can be any batch dimensions + - b: [N, K] (weight matrix in standard format, kernel transposes internally) + - a_s: [..., K//block_size] + - b_s: [N//block_size, K//block_size] + + Args: + a (torch.Tensor): The first input matrix, must be contiguous. + a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous. + b (torch.Tensor): The second input matrix [N, K], must be contiguous. + b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous. + + Returns: + torch.Tensor: The result of the matrix multiplication. 
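+
+    Example (illustrative sketch; shapes assume block_size=128 and CUDA tensors):
+        >>> x_q, x_s = act_quant(torch.randn(4, 1024, device="cuda"))
+        >>> w_q, w_s = weight_quant(torch.randn(512, 1024, device="cuda"))
+        >>> out = int8_gemm(x_q, x_s, w_q, w_s)  # float16 output with shape (4, 512)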
+ """ + assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous" + assert ( + a_s.is_contiguous() and b_s.is_contiguous() + ), "Scaling factor tensors must be contiguous" + assert b.dim() == 2, f"Expected b to be 2D, got shape {b.shape}" + + K = a.size(-1) + M = a.numel() // K + # b has shape [N, K], extract N from first dimension + N = b.shape[0] + + # Validate shapes + assert b.size(1) == K, f"Shape mismatch: b.shape={b.shape}, expected [..., {K}]" + + # Output tensor (same batch shape as input, last dim = N) + # let's use float16 as output dtype + c = a.new_empty(*a.size()[:-1], N, dtype=torch.float16) + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]), + triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + int8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K, BLOCK_SIZE_M=128, BLOCK_SIZE_N=128, BLOCK_SIZE_K=128) + return c + + +def int8_addmm( + a: torch.Tensor, + a_s: torch.Tensor, + b: torch.Tensor, + b_s: torch.Tensor, + bias: torch.Tensor = None +): + """ + Fused INT8 matrix multiplication with bias addition (addmm). + Computes: output = (a @ b) + bias + + Expected tensor shapes: + - a: [..., K] where ... can be any batch dimensions + - b: [N, K] (weight matrix in standard format, kernel transposes internally) + - a_s: [..., K//block_size] + - b_s: [N//block_size, K//block_size] + - bias: [N] (optional) + + This is more efficient than separate matmul + bias add operations as it: + 1. Avoids an extra memory write/read cycle + 2. Fuses the bias addition into the matmul kernel + 3. Better utilizes GPU memory bandwidth + + Args: + a (torch.Tensor): The first input matrix (INT8), must be contiguous. + a_s (torch.Tensor): The scaling factors for the first input matrix, must be contiguous. + b (torch.Tensor): The second input matrix (INT8) [N, K], must be contiguous. + b_s (torch.Tensor): The scaling factors for the second input matrix, must be contiguous. + bias (torch.Tensor, optional): The bias vector (1D, length N). If None, only matmul is performed. + + Returns: + torch.Tensor: The result of the fused matrix multiplication and bias addition. 
+ + Example: + >>> a_int8, a_scale = act_quant(input_tensor, block_size=128) + >>> b_int8, b_scale = weight_quant(weight_tensor, block_size=128) + >>> bias = torch.randn(output_features) + >>> output = int8_addmm(a_int8, a_scale, b_int8, b_scale, bias) + """ + assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous" + assert ( + a_s.is_contiguous() and b_s.is_contiguous() + ), "Scaling factor tensors must be contiguous" + assert b.dim() == 2, f"Expected b to be 2D, got shape {b.shape}" + + K = a.size(-1) + M = a.numel() // K + # b has shape [N, K], extract N from first dimension + N = b.shape[0] + + # Validate shapes + assert b.size(1) == K, f"Shape mismatch: b.shape={b.shape}, expected [..., {K}]" + + # Output tensor (same batch shape as input, last dim = N) + # let's use float16 as output dtype + c = a.new_empty(*a.size()[:-1], N, dtype=torch.float16) + + # Handle bias + has_bias = bias is not None + if has_bias: + assert bias.is_contiguous(), "Bias tensor must be contiguous" + assert bias.dim() == 1 and bias.size(0) == N, \ + f"Bias must be 1D with length {N}, got shape {bias.shape}" + bias_ptr = bias + else: + # Create a dummy pointer (won't be used due to HAS_BIAS=False) + bias_ptr = c + + # Launch kernel + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]), + triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + int8_gemm_addmm_kernel[grid]( + a, b, c, bias_ptr, a_s, b_s, M, N, K, HAS_BIAS=has_bias, BLOCK_SIZE_M=128, BLOCK_SIZE_N=128, BLOCK_SIZE_K=128 + ) + return c + + +# ============================================================================== +# Fused INT8 GEMM + Quantization Kernels +# ============================================================================== +# +# Architecture Overview: +# ---------------------- +# 1. Kernels compute matmul and quantize PER-ROW for activation format +# - Each row gets its own scale for the N-range of the tile +# - Kernel output: c_scale shape is (M, N/BLOCK_SIZE_N) +# - BLOCK_SIZE_M, BLOCK_SIZE_N are tile sizes from autotuner (e.g., 16-64, 32-128) +# - This matches activation quantization: per-row, block-wise along N +# +# 2. Wrapper functions convert to final activation format +# - Kernel output: (M, N/BLOCK_SIZE_N) +# - Target format: (*batch_dims, N/out_block_size) +# - If BLOCK_SIZE_N == out_block_size: already correct, just reshape +# - If BLOCK_SIZE_N != out_block_size: replicate or merge scales +# +# 3. Benefits: +# - Accurate: per-row scales match activation quantization format +# - Efficient: single max reduction per row per tile +# - Compatible: direct output in activation format +# - Better precision: each row has independent scales +# +# ============================================================================== + +#@triton.autotune(configs=int8_gemm_configs, key=["N", "K"]) +@triton.heuristics({ + 'NUM_BLOCKS': lambda args: args["BLOCK_SIZE_N"] // args["out_block_size"], +}) +@triton.jit +def int8_gemm_quant_kernel( + a_ptr, + b_ptr, + c_ptr, + c_s_ptr, + a_s_ptr, + b_s_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + out_block_size: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + NUM_BLOCKS: tl.constexpr, +): + """ + Fused INT8 matrix multiplication with output quantization. + Computes: C_int8, C_scale = quantize(A @ B) + + This kernel fuses matmul and block-wise quantization in a single pass. + Quantizes at out_block_size granularity (like act_quant_kernel). 
+ + Args: + a_ptr: Pointer to INT8 activations + b_ptr: Pointer to INT8 weights + c_ptr: Pointer to INT8 output + c_s_ptr: Pointer to output scales (shape: M x N/out_block_size) + a_s_ptr: Pointer to activation scales + b_s_ptr: Pointer to weight scales + M: Number of rows in A and C + N: Number of columns in B and C + K: Inner dimension (columns in A, rows in B) + out_block_size: Block size for output quantization + BLOCK_SIZE_M/N/K: Tile sizes for matmul + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + k = tl.cdiv(K, BLOCK_SIZE_K) + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] + b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] + a_s_ptrs = a_s_ptr + offs_m * k + + # FIXED: Weight scale indexing for 2D scale array (N_blocks, K_blocks) + # b_s has shape (N//BLOCK_SIZE_K, K//BLOCK_SIZE_K) stored in row-major + # For N tile pid_n, we need scales[pid_n, :] across K iterations + k_blocks = k # Number of K blocks for clarity + b_s_base = b_s_ptr + pid_n * k_blocks + + # Accumulate matmul result + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for i in range(k_blocks): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0) + a_s = tl.load(a_s_ptrs) + # FIXED: Load single scalar weight scale for (pid_n, i) block pair + b_s = tl.load(b_s_base + i) + dot_prod = tl.dot(a, b, out_dtype=tl.int32) + accumulator += dot_prod.to(tl.float32) * a_s[:, None] * b_s + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + a_s_ptrs += 1 + + # Quantize in activation format: per-row, block-wise at out_block_size granularity + # Reshape accumulator to separate blocks: (BLOCK_SIZE_M, BLOCK_SIZE_N) -> (BLOCK_SIZE_M, NUM_BLOCKS, out_block_size) + accumulator_reshaped = tl.reshape(accumulator, (BLOCK_SIZE_M, NUM_BLOCKS, out_block_size)) + + # Compute max per block: reduce over out_block_size dimension + # Shape: (BLOCK_SIZE_M, NUM_BLOCKS) + block_max = tl.max(tl.abs(accumulator_reshaped), axis=2) + block_scale = tl.maximum(block_max / 127.0, 1e-8) + + # Reshape scales for broadcasting: (BLOCK_SIZE_M, NUM_BLOCKS) -> (BLOCK_SIZE_M, NUM_BLOCKS, 1) + block_scale_broadcast = tl.reshape(block_scale, (BLOCK_SIZE_M, NUM_BLOCKS, 1)) + + # Quantize: accumulator -> int8 + quantized = accumulator_reshaped / block_scale_broadcast + quantized = tl.maximum(tl.minimum(quantized, 127.0), -127.0) + quantized_int8 = quantized.to(c_ptr.dtype.element_ty) + + # Reshape back to 2D: (BLOCK_SIZE_M, NUM_BLOCKS, out_block_size) -> (BLOCK_SIZE_M, BLOCK_SIZE_N) + quantized_int8 = tl.reshape(quantized_int8, (BLOCK_SIZE_M, BLOCK_SIZE_N)) + + # Store quantized output + offs_m_actual = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n_actual = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask = (offs_m_actual[:, None] < M) & (offs_n_actual[None, :] < N) + c_ptrs = c_ptr + offs_m_actual[:, None] * N + offs_n_actual[None, :] + tl.store(c_ptrs, quantized_int8, mask=mask) + + # Store scales: (BLOCK_SIZE_M, NUM_BLOCKS) scales for this tile + # Scale layout: (M, N//out_block_size) - matches activation format directly! 
+ # This tile covers M range [pid_m*BLOCK_SIZE_M : (pid_m+1)*BLOCK_SIZE_M] + # N range [pid_n*BLOCK_SIZE_N : (pid_n+1)*BLOCK_SIZE_N] + # N block indices: [pid_n * NUM_BLOCKS : (pid_n+1) * NUM_BLOCKS] + n_scale_stride = N // out_block_size # Total number of N blocks + offs_m_scale = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n_scale = pid_n * NUM_BLOCKS + tl.arange(0, NUM_BLOCKS) + scale_ptrs = c_s_ptr + offs_m_scale[:, None] * n_scale_stride + offs_n_scale[None, :] + scale_mask = (offs_m_scale[:, None] < M) & (offs_n_scale[None, :] < n_scale_stride) + tl.store(scale_ptrs, block_scale, mask=scale_mask) + + +#@triton.autotune(configs=int8_gemm_configs, key=["N", "K"]) +@triton.heuristics({ + 'NUM_BLOCKS': lambda args: args["BLOCK_SIZE_N"] // args["out_block_size"], +}) +@triton.jit +def int8_gemm_addmm_quant_kernel( + a_ptr, + b_ptr, + c_ptr, + c_s_ptr, + bias_ptr, + a_s_ptr, + b_s_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + out_block_size: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + NUM_BLOCKS: tl.constexpr, + HAS_BIAS: tl.constexpr, +): + """ + Fused INT8 matrix multiplication with bias addition and output quantization. + Computes: C_int8, C_scale = quantize(A @ B + bias) + + This kernel fuses matmul, bias addition, and block-wise quantization. + Quantizes at out_block_size granularity (like act_quant_kernel). + + Args: + a_ptr: Pointer to INT8 activations + b_ptr: Pointer to INT8 weights + c_ptr: Pointer to INT8 output + c_s_ptr: Pointer to output scales (shape: M x N/out_block_size) + bias_ptr: Pointer to bias vector + a_s_ptr: Pointer to activation scales + b_s_ptr: Pointer to weight scales + M: Number of rows in A and C + N: Number of columns in B and C + K: Inner dimension + out_block_size: Block size for output quantization + BLOCK_SIZE_M/N/K: Tile sizes for matmul + HAS_BIAS: Whether bias is provided + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + k = tl.cdiv(K, BLOCK_SIZE_K) + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] + b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] + a_s_ptrs = a_s_ptr + offs_m * k + + # FIXED: Weight scale indexing for 2D scale array (N_blocks, K_blocks) + # b_s has shape (N//BLOCK_SIZE_K, K//BLOCK_SIZE_K) stored in row-major + # For N tile pid_n, we need scales[pid_n, :] across K iterations + k_blocks = k # Number of K blocks for clarity + b_s_base = b_s_ptr + pid_n * k_blocks + + # Accumulate matmul result + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for i in range(k_blocks): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0) + a_s = tl.load(a_s_ptrs) + # FIXED: Load single scalar weight scale for (pid_n, i) block pair + b_s = tl.load(b_s_base + i) + dot_prod = tl.dot(a, b, out_dtype=tl.int32) + accumulator += dot_prod.to(tl.float32) * a_s[:, None] * b_s + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + a_s_ptrs += 1 + + # Add bias if provided + if HAS_BIAS: + bias_ptrs = bias_ptr + offs_n[None, :] + bias = tl.load(bias_ptrs, mask=offs_n[None, :] < N, other=0.0) + accumulator += bias + + # Quantize in activation format: per-row, block-wise at out_block_size granularity + # Reshape accumulator to separate blocks: (BLOCK_SIZE_M, 
BLOCK_SIZE_N) -> (BLOCK_SIZE_M, NUM_BLOCKS, out_block_size) + accumulator_reshaped = tl.reshape(accumulator, (BLOCK_SIZE_M, NUM_BLOCKS, out_block_size)) + + # Compute max per block: reduce over out_block_size dimension + # Shape: (BLOCK_SIZE_M, NUM_BLOCKS) + block_max = tl.max(tl.abs(accumulator_reshaped), axis=2) + block_scale = tl.maximum(block_max / 127.0, 1e-8) + + # Reshape scales for broadcasting: (BLOCK_SIZE_M, NUM_BLOCKS) -> (BLOCK_SIZE_M, NUM_BLOCKS, 1) + block_scale_broadcast = tl.reshape(block_scale, (BLOCK_SIZE_M, NUM_BLOCKS, 1)) + + # Quantize: accumulator -> int8 + quantized = accumulator_reshaped / block_scale_broadcast + quantized = tl.maximum(tl.minimum(quantized, 127.0), -127.0) + quantized_int8 = quantized.to(c_ptr.dtype.element_ty) + + # Reshape back to 2D: (BLOCK_SIZE_M, NUM_BLOCKS, out_block_size) -> (BLOCK_SIZE_M, BLOCK_SIZE_N) + quantized_int8 = tl.reshape(quantized_int8, (BLOCK_SIZE_M, BLOCK_SIZE_N)) + + # Store quantized output + offs_m_actual = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n_actual = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask = (offs_m_actual[:, None] < M) & (offs_n_actual[None, :] < N) + c_ptrs = c_ptr + offs_m_actual[:, None] * N + offs_n_actual[None, :] + tl.store(c_ptrs, quantized_int8, mask=mask) + + # Store scales: (BLOCK_SIZE_M, NUM_BLOCKS) scales for this tile + # Scale layout: (M, N//out_block_size) - matches activation format directly! + # This tile covers M range [pid_m*BLOCK_SIZE_M : (pid_m+1)*BLOCK_SIZE_M] + # N range [pid_n*BLOCK_SIZE_N : (pid_n+1)*BLOCK_SIZE_N] + # N block indices: [pid_n * NUM_BLOCKS : (pid_n+1) * NUM_BLOCKS] + n_scale_stride = N // out_block_size # Total number of N blocks + offs_m_scale = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n_scale = pid_n * NUM_BLOCKS + tl.arange(0, NUM_BLOCKS) + scale_ptrs = c_s_ptr + offs_m_scale[:, None] * n_scale_stride + offs_n_scale[None, :] + scale_mask = (offs_m_scale[:, None] < M) & (offs_n_scale[None, :] < n_scale_stride) + tl.store(scale_ptrs, block_scale, mask=scale_mask) + + +def int8_gemm_quant( + a: torch.Tensor, + a_s: torch.Tensor, + b: torch.Tensor, + b_s: torch.Tensor, + out_block_size: int = 128 +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Fused INT8 GEMM with output quantization. + Computes: C_int8, C_scale = quantize(A @ B) + + This avoids materializing the full-precision intermediate result. + + The kernel produces scales in activation format directly: (*batch_dims, N/out_block_size). 
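+
+    Example (illustrative sketch; assumes `x_q`/`x_s` from act_quant and `w_q`/`w_s` from weight_quant with block_size=128):
+        >>> y_q, y_s = int8_gemm_quant(x_q, x_s, w_q, w_s, out_block_size=128)
+        >>> # y_q: int8 with shape [..., N]; y_s: float32 with shape [..., N // 128]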
+ + Args: + a: INT8 activations [..., K] + a_s: Activation scales [..., K//block_size] + b: INT8 weights [N, K] + b_s: Weight scales [N//block_size, K//block_size] + out_block_size: Block size for output quantization (default: 128) + + Returns: + Tuple of (quantized output INT8, output scales in activation format) + """ + assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous" + assert a_s.is_contiguous() and b_s.is_contiguous(), "Scaling tensors must be contiguous" + assert b.dim() == 2, f"Expected b to be 2D, got shape {b.shape}" + + K = a.size(-1) + M = a.numel() // K + N = b.shape[0] + batch_shape = a.size()[:-1] + + assert b.size(1) == K, f"Shape mismatch: b.shape={b.shape}, expected [..., {K}]" + assert N % out_block_size == 0, f"N={N} must be divisible by out_block_size={out_block_size}" + + # Allocate output tensors + c = a.new_empty(*batch_shape, N, dtype=torch.int8) + + # Allocate scales in activation format directly: (M, N//out_block_size) + n_blocks = N // out_block_size + c_s = a.new_empty(M, n_blocks, dtype=torch.float32) + + # Launch kernel + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]), + triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + + int8_gemm_quant_kernel[grid]( + a, b, c, c_s, a_s, b_s, M, N, K, out_block_size, BLOCK_SIZE_M=128, BLOCK_SIZE_N=128, BLOCK_SIZE_K=128 + ) + + # Reshape scales to match batch dimensions: (M, n_blocks) -> (*batch_dims, n_blocks) + if len(batch_shape) > 0: + c_s = c_s.reshape(*batch_shape, n_blocks) + + return c, c_s + + +def int8_addmm_quant( + a: torch.Tensor, + a_s: torch.Tensor, + b: torch.Tensor, + b_s: torch.Tensor, + bias: torch.Tensor = None, + out_block_size: int = 128 +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Fused INT8 addmm with output quantization. + Computes: C_int8, C_scale = quantize(A @ B + bias) + + This fuses matmul, bias addition, and quantization in a single kernel pass. + + The kernel produces scales in activation format directly: (*batch_dims, N/out_block_size). 
+ + Args: + a: INT8 activations [..., K] + a_s: Activation scales [..., K//block_size] + b: INT8 weights [N, K] + b_s: Weight scales [N//block_size, K//block_size] + bias: Optional bias vector [N] + out_block_size: Block size for output quantization (default: 128) + + Returns: + Tuple of (quantized output INT8, output scales in activation format) + """ + assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous" + assert a_s.is_contiguous() and b_s.is_contiguous(), "Scaling tensors must be contiguous" + assert b.dim() == 2, f"Expected b to be 2D, got shape {b.shape}" + + K = a.size(-1) + M = a.numel() // K + N = b.shape[0] + batch_shape = a.size()[:-1] + + assert b.size(1) == K, f"Shape mismatch: b.shape={b.shape}, expected [..., {K}]" + assert N % out_block_size == 0, f"N={N} must be divisible by out_block_size={out_block_size}" + + # Allocate output tensors + c = a.new_empty(*batch_shape, N, dtype=torch.int8) + + # Allocate scales in activation format directly: (M, N//out_block_size) + n_blocks = N // out_block_size + c_s = a.new_empty(M, n_blocks, dtype=torch.float32) + + # Handle bias + has_bias = bias is not None + if has_bias: + assert bias.is_contiguous(), "Bias tensor must be contiguous" + assert bias.dim() == 1 and bias.size(0) == N, \ + f"Bias must be 1D with length {N}, got shape {bias.shape}" + bias_ptr = bias + else: + bias_ptr = c # Dummy pointer + + # Launch kernel + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]), + triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + + int8_gemm_addmm_quant_kernel[grid]( + a, b, c, c_s, bias_ptr, a_s, b_s, M, N, K, out_block_size, HAS_BIAS=has_bias, BLOCK_SIZE_M=128, BLOCK_SIZE_N=128, BLOCK_SIZE_K=128 + ) + + # Reshape scales to match batch dimensions: (M, n_blocks) -> (*batch_dims, n_blocks) + if len(batch_shape) > 0: + c_s = c_s.reshape(*batch_shape, n_blocks) + + return c, c_s + + +# ============================================================================== +# INT8 GELU Kernel +# ============================================================================== + +# Autotuning configs for GELU kernel +# Note: BLOCK_N must be >= quantization block_size (typically 128) and divisible by it +# BLOCK_M can be any size since we don't block in M dimension for activations +int8_gelu_configs = [ + Config( + {"BLOCK_M": block_m, "BLOCK_N": block_n}, + num_stages=num_stages, + num_warps=num_warps, + ) + for block_m in [64, 128, 256] + for block_n in [128, 256] # Must be >= block_size and divisible by it + for num_stages in [2, 3, 4] + for num_warps in [4, 8] +] + + +#@triton.autotune(configs=int8_gelu_configs, key=["M", "N"]) +@triton.heuristics({ + 'BLOCK_SM': lambda args: args["BLOCK_M"], # For activations, no blocking in M dimension + 'BLOCK_SN': lambda args: args["BLOCK_N"] // args["BLOCK_SIZE"], +}) +@triton.jit +def int8_gelu_kernel( + output_ptr, + output_scale_ptr, + input_ptr, + input_scale_ptr, + M, + N: tl.constexpr, + SM, + SN: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SM: tl.constexpr, + BLOCK_SN: tl.constexpr, +): + """ + Fused INT8 GELU with block-wise quantization. + + Computes: output_int8, output_scale = quantize(gelu(dequantize(input_int8, input_scale))) + + For activation quantization, we only block along the last dimension (N). + Each row gets its own set of scales along N. 
+ + Scale tensor layout: + - Input scales: (M, N // BLOCK_SIZE) - one scale per row per block in N + - Within each tile (BLOCK_M x BLOCK_N), we load (BLOCK_M, BLOCK_N // BLOCK_SIZE) scales + + This kernel: + 1. Loads INT8 input and its block-wise scales + 2. Dequantizes to float + 3. Applies GELU activation + 4. Quantizes output back to INT8 with new block-wise scales + + Args: + output_ptr: Pointer to INT8 output tensor + output_scale_ptr: Pointer to output scales + input_ptr: Pointer to INT8 input tensor + input_scale_ptr: Pointer to input scales + M: Number of rows + N: Number of columns + SM: Number of rows in scale tensor (= M for activations) + SN: Number of scale blocks in N dimension (= N // BLOCK_SIZE) + BLOCK_SIZE: Quantization block size (e.g., 128) + BLOCK_M: Tile size in M dimension + BLOCK_N: Tile size in N dimension + BLOCK_SM: Number of rows per tile (= BLOCK_M for activations) + BLOCK_SN: Number of scale blocks per tile in N dimension (= BLOCK_N // BLOCK_SIZE) + """ + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_m = pid // NUM_BLOCK_N + pid_n = pid % NUM_BLOCK_N + + # Offsets for data + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # Load input data + input_ptrs = input_ptr + offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + input_data = tl.load(input_ptrs, mask=mask, other=0).to(tl.int32) + + # Load input scales + # Scale dimensions: (SM, SN) where SM = M, SN = N // BLOCK_SIZE + # For this tile: load (BLOCK_M, BLOCK_N // BLOCK_SIZE) scales + offs_sm = pid_m * BLOCK_SM + tl.arange(0, BLOCK_SM) + offs_sn = pid_n * BLOCK_SN + tl.arange(0, BLOCK_SN) + scale_ptrs = input_scale_ptr + offs_sm[:, None] * SN + offs_sn[None, :] + scale_mask = (offs_sm[:, None] < SM) & (offs_sn[None, :] < SN) + input_scales = tl.load(scale_ptrs, mask=scale_mask, other=1.0) + + # Reshape for broadcasting + # Data: (BLOCK_M, BLOCK_N) -> (BLOCK_M, BLOCK_SN, BLOCK_SIZE) + # Scales: (BLOCK_M, BLOCK_SN) -> (BLOCK_M, BLOCK_SN, 1) + input_data = tl.reshape(input_data, (BLOCK_M, BLOCK_SN, BLOCK_SIZE)) + input_scales = tl.reshape(input_scales, (BLOCK_M, BLOCK_SN, 1)) + + # Dequantize + input_fp32 = input_data.to(tl.float32) * input_scales + + # Apply GELU: 0.5 * x * (1 + erf(x / sqrt(2))) + sqrt_2 = 1.41421356237 + erf_input = input_fp32 / sqrt_2 + erf_val = tl.math.erf(erf_input) + gelu_output = input_fp32 * 0.5 * (1.0 + erf_val) + + # Compute output scales per block + # Shape: (BLOCK_M, BLOCK_SN, BLOCK_SIZE) -> (BLOCK_M, BLOCK_SN) + abs_output = tl.abs(gelu_output) + max_val = tl.max(abs_output, axis=2) # Reduce over BLOCK_SIZE dimension + output_scales = tl.maximum(max_val / 127.0, 1e-8) + + # Reshape scales for broadcasting: (BLOCK_M, BLOCK_SN) -> (BLOCK_M, BLOCK_SN, 1) + output_scales_broadcast = tl.reshape(output_scales, (BLOCK_M, BLOCK_SN, 1)) + + # Quantize output + quantized = gelu_output / output_scales_broadcast + quantized = tl.maximum(tl.minimum(quantized, 127.0), -127.0) + quantized_int8 = quantized.to(tl.int8) + + # Reshape back to 2D + quantized_int8 = tl.reshape(quantized_int8, (BLOCK_M, BLOCK_N)) + + # Store quantized output + output_ptrs = output_ptr + offs_m[:, None] * N + offs_n[None, :] + tl.store(output_ptrs, quantized_int8, mask=mask) + + # Store output scales + output_scale_ptrs = output_scale_ptr + offs_sm[:, None] * SN + offs_sn[None, :] + tl.store(output_scale_ptrs, output_scales, mask=scale_mask) + + +def int8_gelu( + x: torch.Tensor, + s_x: 
torch.Tensor, + block_size: int = 128 +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Fused INT8 GELU activation with block-wise quantization. + + Computes: y_int8, y_scale = quantize(gelu(dequantize(x, s_x))) + + This avoids materializing the full-precision intermediate result. + + Args: + x: INT8 input tensor of any shape + s_x: Input scales with shape (*batch_dims, last_dim // block_size) + block_size: Quantization block size (default: 128) + + Returns: + Tuple of (quantized output INT8, output scales) + + Note: + The kernel requires tile sizes >= block_size. This is automatically + handled by the autotuner, which uses BLOCK_M, BLOCK_N >= 128. + """ + assert x.is_contiguous(), "Input tensor must be contiguous" + assert s_x.is_contiguous(), "Scale tensor must be contiguous" + assert x.size(-1) % block_size == 0, \ + f"Last dimension must be divisible by block_size={block_size}" + assert block_size == 128, \ + f"Only block_size=128 is currently supported in autotuner configs (got {block_size})" + + # Handle multi-dimensional tensors by reshaping to 2D + original_shape = x.shape + batch_shape = original_shape[:-1] + N = original_shape[-1] + + if x.dim() > 2: + x = x.reshape(-1, N) + s_x = s_x.reshape(-1, s_x.size(-1)) + + M = x.size(0) + SM = M # For activations, we don't block in M dimension + SN = N // block_size + + # Allocate output tensors + y = torch.empty_like(x, dtype=torch.int8) + s_y = torch.empty_like(s_x, dtype=torch.float32) + + # Launch kernel + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), + ) + + int8_gelu_kernel[grid]( + y, s_y, x, s_x, + M, N, SM, SN, + BLOCK_SIZE=block_size, BLOCK_M=128, BLOCK_N=128, BLOCK_SM=128 + ) + + # Reshape back to original batch dimensions + if len(batch_shape) > 0: + y = y.reshape(*batch_shape, N) + s_y = s_y.reshape(*batch_shape, SN) + + return y, s_y diff --git a/comfy/ops.py b/comfy/ops.py index a0ff4e8f1710..9543334e6b85 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -599,11 +599,17 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, self.layout_type = qconfig["comfy_tensor_layout"] weight_scale_key = f"{prefix}weight_scale" + # Check for per-layer group_size override, otherwise use default from QUANT_ALGOS + layer_config = MixedPrecisionOps._layer_quant_config[layer_name] + group_size = layer_config.get("group_size", qconfig.get("group_size", None)) + layout_params = { 'scale': state_dict.pop(weight_scale_key, None), 'orig_dtype': MixedPrecisionOps._compute_dtype, - 'block_size': qconfig.get("group_size", None), + 'block_size': group_size, } + if qconfig.get("asymmetric_layout", False): + layout_params['is_weight'] = True if layout_params['scale'] is not None: manually_loaded_keys.append(weight_scale_key) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index d2f3e7397839..ff07c76da931 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -6,6 +6,23 @@ _LAYOUT_REGISTRY = {} _GENERIC_UTILS = {} +# Try to import Triton-based INT8 kernels +try: + from .int8_kernels import ( + act_quant as triton_act_quant, + act_dequant as triton_act_dequant, + weight_quant as triton_weight_quant, + weight_dequant as triton_weight_dequant, + int8_gemm as triton_int8_gemm, + int8_addmm as triton_int8_addmm, + int8_gemm_quant as triton_int8_gemm_quant, + int8_addmm_quant as triton_int8_addmm_quant + ) + _HAS_TRITON_INT8 = True +except ImportError: + _HAS_TRITON_INT8 = False + logging.warning("Triton INT8 kernels not available, using PyTorch fallback") + def register_layout_op(torch_op, 
layout_type): """ @@ -212,9 +229,21 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): return handler(func, args, kwargs) # Step 3: Fallback to dequantization - if isinstance(args[0] if args else None, QuantizedTensor): - logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}") - return cls._dequant_and_fallback(func, args, kwargs) + #if isinstance(args[0] if args else None, QuantizedTensor): + #logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}, args={args}") + + to_return = cls._dequant_and_fallback(func, args, kwargs) + + return to_return + + def data_ptr(self): + return self._qdata.data_ptr() + + def is_pinned(self): + return self._qdata.is_pinned() + + def is_contiguous(self): + return self._qdata.is_contiguous() @classmethod def _dequant_and_fallback(cls, func, args, kwargs): @@ -229,14 +258,6 @@ def dequant_arg(arg): new_kwargs = dequant_arg(kwargs) return func(*new_args, **new_kwargs) - def data_ptr(self): - return self._qdata.data_ptr() - - def is_pinned(self): - return self._qdata.is_pinned() - - def is_contiguous(self): - return self._qdata.is_contiguous() # ============================================================================== # Generic Utilities (Layout-Agnostic Operations) @@ -328,6 +349,17 @@ def generic_to_dtype_layout(func, args, kwargs): ) return func(*args, **kwargs) +@register_generic_util(torch.ops.aten.to.dtype) +def generic_to_dtype(func, args, kwargs): + """Handle .to(dtype) calls - dtype conversion only.""" + src = args[0] + if isinstance(src, QuantizedTensor): + # For dtype-only conversion, just change the orig_dtype, no real cast is needed + target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype') + src._layout_params["orig_dtype"] = target_dtype + return src + return func(*args, **kwargs) + @register_generic_util(torch.ops.aten.copy_.default) def generic_copy_(func, args, kwargs): @@ -347,18 +379,6 @@ def generic_copy_(func, args, kwargs): return func(*args, **kwargs) -@register_generic_util(torch.ops.aten.to.dtype) -def generic_to_dtype(func, args, kwargs): - """Handle .to(dtype) calls - dtype conversion only.""" - src = args[0] - if isinstance(src, QuantizedTensor): - # For dtype-only conversion, just change the orig_dtype, no real cast is needed - target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype') - src._layout_params["orig_dtype"] = target_dtype - return src - return func(*args, **kwargs) - - @register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default) def generic_has_compatible_shallow_copy_type(func, args, kwargs): return True @@ -431,16 +451,261 @@ def dequantize(qdata, scale, orig_dtype, **kwargs): def get_plain_tensors(cls, qtensor): return qtensor._qdata, qtensor._layout_params['scale'] + +# ============================================================================== +# Block-Wise INT8 Layout + Operation Handlers +# ============================================================================== +class BlockWiseINT8Layout(QuantizedLayout): + """ + Block-wise INT8 quantization layout. 
+ + Storage format: + - qdata: INT8 tensor (torch.int8) + - scale: Per-block scaling factors (float32) + - block_size: Size of quantization blocks (default 128) + - orig_dtype: Original dtype before quantization (for casting back) + - is_weight: Whether this is a weight tensor (affects blocking dimension) + + Asymmetric blocking: + - Weights: blocks partition along first dimension (M) and second dimension (N) + scale shape: (M//block_size, N//block_size) + - Activations: blocks partition along last dimension (K) + scale shape: (*batch_dims, K//block_size) + """ + + @classmethod + def quantize(cls, tensor, scale=None, block_size=128, is_weight=False, **kwargs): + """ + Quantize a tensor to INT8 with block-wise scaling. + + Args: + tensor: Input tensor to quantize + scale: Optional pre-computed scaling factors + block_size: Size of quantization blocks (default 128) + is_weight: If True, block along both dimensions (for weights) + If False, block along last dimension only (for activations) + + Returns: + Tuple of (quantized_data, layout_params) + """ + orig_dtype = tensor.dtype + + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + + if is_weight: + # Weight quantization: block-wise along both M and N dimensions + # Expected shape: (M, N) + assert tensor.dim() == 2, f"Weight tensor must be 2D, got shape {tensor.shape}" + M, N = tensor.shape + assert M % block_size == 0 and N % block_size == 0, \ + f"Dimensions must be divisible by block_size={block_size}, got shape {tensor.shape}" + + # Use Triton kernel if available AND tensor is on CUDA + if _HAS_TRITON_INT8 and scale is None and tensor.is_cuda: + try: + qdata, scale = triton_weight_quant(tensor, block_size=block_size) + except Exception as e: + # don't fall back, raise, for easier debugging + logging.warning(f"Triton weight_quant failed: {e}, falling back to PyTorch") + raise e + # qdata, scale = cls._weight_quantize_pytorch(tensor, block_size) + else: + qdata, scale = cls._weight_quantize_pytorch(tensor, block_size, scale) + + else: + # Activation quantization: block-wise along last dimension (K) + # Can handle any shape: (*batch_dims, K) + K = tensor.shape[-1] + assert K % block_size == 0, \ + f"Last dimension must be divisible by block_size={block_size}, got {K}" + + # Use Triton kernel if available AND tensor is on CUDA + # ignore input scale for now + # TODO: why do we need input scale? 
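+            # Note: a caller-supplied `scale` is currently only honored by the PyTorch
+            # fallback below; the Triton path recomputes per-block scales from the data.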
+ if _HAS_TRITON_INT8 and tensor.is_cuda: + try: + qdata, scale = triton_act_quant(tensor, block_size=block_size) + except Exception as e: + logging.warning(f"Triton act_quant failed: {e}, falling back to PyTorch") + qdata, scale = cls._activation_quantize_pytorch(tensor, block_size) + else: + qdata, scale = cls._activation_quantize_pytorch(tensor, block_size, scale) + + layout_params = { + 'scale': scale.to(torch.float32), + 'block_size': block_size, + 'is_weight': is_weight, + 'orig_dtype': orig_dtype + } + + return qdata, layout_params + + @staticmethod + def _weight_quantize_pytorch(tensor, block_size, scale=None): + """PyTorch fallback for weight quantization""" + M, N = tensor.shape + # Reshape to (M//block_size, block_size, N//block_size, block_size) + tensor_blocked = tensor.reshape(M // block_size, block_size, N // block_size, block_size) + # Permute to (M//block_size, N//block_size, block_size, block_size) + tensor_blocked = tensor_blocked.permute(0, 2, 1, 3) + + if scale is None: + # Compute per-block absolute maximum + amax = tensor_blocked.abs().amax(dim=(-2, -1)) + scale = amax / 127.0 + scale = torch.maximum(scale, torch.tensor(1e-8, device=scale.device, dtype=scale.dtype)) + + # Broadcast scale for division: (M//block_size, N//block_size, 1, 1) + scale_broadcast = scale.unsqueeze(-1).unsqueeze(-1) + tensor_scaled = tensor_blocked / scale_broadcast + + # Clamp and convert to int8 + tensor_scaled = torch.clamp(tensor_scaled, -127.0, 127.0) + qdata = tensor_scaled.to(torch.int8) + + # Reshape back to original shape + qdata = qdata.permute(0, 2, 1, 3).reshape(M, N) + return qdata, scale + + @staticmethod + def _activation_quantize_pytorch(tensor, block_size, scale=None): + """PyTorch fallback for activation quantization""" + K = tensor.shape[-1] + batch_shape = tensor.shape[:-1] + tensor_blocked = tensor.reshape(*batch_shape, K // block_size, block_size) + + if scale is None: + # Compute per-block absolute maximum + amax = tensor_blocked.abs().amax(dim=-1) + scale = amax / 127.0 + scale = torch.maximum(scale, torch.tensor(1e-8, device=scale.device, dtype=scale.dtype)) + + # Broadcast scale for division + scale_broadcast = scale.unsqueeze(-1) + tensor_scaled = tensor_blocked / scale_broadcast + + # Clamp and convert to int8 + tensor_scaled = torch.clamp(tensor_scaled, -127.0, 127.0) + qdata = tensor_scaled.to(torch.int8) + + # Reshape back to original shape + qdata = qdata.reshape(tensor.shape) + return qdata, scale + + @staticmethod + def dequantize(qdata, scale, block_size, is_weight=False, orig_dtype=None, output_dtype=None, **kwargs): + """ + Dequantize INT8 tensor back to original precision. 
+ + Args: + qdata: Quantized INT8 tensor + scale: Per-block scaling factors + block_size: Size of quantization blocks + is_weight: Whether this is a weight tensor + orig_dtype: Target dtype for dequantization + + Returns: + Dequantized tensor in orig_dtype + """ + if not qdata.is_contiguous(): + qdata = qdata.contiguous() + if not scale.is_contiguous(): + scale = scale.contiguous() + + if is_weight: + # Weight dequantization + if _HAS_TRITON_INT8 and qdata.dim() == 2 and qdata.is_cuda: + try: + dequant = triton_weight_dequant(qdata, scale, block_size=block_size, output_dtype=output_dtype if output_dtype is not None else orig_dtype) + return dequant + except Exception as e: + logging.warning(f"Triton weight_dequant failed: {e}, falling back to PyTorch") + raise e + + # PyTorch fallback + M, N = qdata.shape + # Ensure scale has the correct shape for weight dequantization + expected_scale_shape = (M // block_size, N // block_size) + if scale.shape != expected_scale_shape: + expected_numel = (M // block_size) * (N // block_size) + if scale.numel() == expected_numel: + scale = scale.reshape(expected_scale_shape) + else: + raise RuntimeError( + f"Weight dequant scale shape mismatch: scale.shape={scale.shape}, expected {expected_scale_shape}" + ) + qdata_blocked = qdata.reshape(M // block_size, block_size, N // block_size, block_size) + qdata_blocked = qdata_blocked.permute(0, 2, 1, 3) + scale_broadcast = scale.unsqueeze(-1).unsqueeze(-1) + dequant = qdata_blocked.to(orig_dtype) * scale_broadcast + dequant = dequant.permute(0, 2, 1, 3).reshape(M, N) + else: + # Activation dequantization + if _HAS_TRITON_INT8 and qdata.is_cuda: + try: + dequant = triton_act_dequant(qdata, scale, block_size=block_size, output_dtype=output_dtype if output_dtype is not None else orig_dtype) + return dequant + except Exception as e: + logging.warning(f"Triton act_dequant failed: {e}, falling back to PyTorch") + raise e + + # PyTorch fallback + batch_shape = qdata.shape[:-1] + K = qdata.shape[-1] + # Ensure scale has the correct shape for activation dequantization + expected_scale_shape = (*batch_shape, K // block_size) + if scale.shape != expected_scale_shape: + expected_numel = 1 + for dim in expected_scale_shape: + expected_numel *= dim + if scale.numel() == expected_numel: + scale = scale.reshape(expected_scale_shape) + else: + raise RuntimeError( + f"Activation dequant scale shape mismatch: scale.shape={scale.shape}, expected {expected_scale_shape}" + ) + qdata_blocked = qdata.reshape(*batch_shape, K // block_size, block_size) + scale_broadcast = scale.unsqueeze(-1) + dequant = qdata_blocked.to(orig_dtype) * scale_broadcast + dequant = dequant.reshape(qdata.shape) + + return dequant + + @classmethod + def get_plain_tensors(cls, qtensor): + """ + Extract raw tensors for computation. 
+ + Returns: + Tuple of (qdata, scale, block_size, is_weight) + """ + return ( + qtensor._qdata, + qtensor._layout_params['scale'], + qtensor._layout_params['block_size'], + qtensor._layout_params['is_weight'] + ) + + QUANT_ALGOS = { "float8_e4m3fn": { "storage_t": torch.float8_e4m3fn, "parameters": {"weight_scale", "input_scale"}, "comfy_tensor_layout": "TensorCoreFP8Layout", }, + "int8_blockwise": { + "storage_t": torch.int8, + "parameters": {"weight_scale", "input_scale"}, + "comfy_tensor_layout": "BlockWiseINT8Layout", + "group_size": 128, # Default block size, + "asymmetric_layout": True, + }, } LAYOUTS = { "TensorCoreFP8Layout": TensorCoreFP8Layout, + "BlockWiseINT8Layout": BlockWiseINT8Layout, } @@ -570,3 +835,671 @@ def fp8_func(func, args, kwargs): ar[0] = plain_input return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params) return func(*args, **kwargs) + + +# ============================================================================== +# Block-Wise INT8 Operation Handlers +# ============================================================================== + +def _int8_gemm_pytorch_fallback(a_int8, a_scale, b_int8, b_scale, block_size, bias=None): + """ + PyTorch fallback for INT8 matrix multiplication: dequantize and use standard matmul. + + Args: + a_int8: INT8 activations, shape (*batch, K) + a_scale: Activation scales, shape (*batch, K//block_size) + b_int8: INT8 weights, shape (N, K) - standard PyTorch weight format + b_scale: Weight scales, shape (N//block_size, K//block_size) + block_size: Block size for quantization + bias: Optional bias vector, shape (N,) + + Returns: + Output in float32, shape (*batch, N) + """ + K = a_int8.shape[-1] + batch_shape = a_int8.shape[:-1] + N = b_int8.shape[0] + + # Dequantize activations + # Ensure a_scale has the correct shape - it should be (*batch_shape, K // block_size) + expected_scale_shape = (*batch_shape, K // block_size) + if a_scale.shape != expected_scale_shape: + # Try to reshape if the number of elements matches + expected_numel = 1 + for dim in expected_scale_shape: + expected_numel *= dim + if a_scale.numel() == expected_numel: + a_scale = a_scale.reshape(expected_scale_shape) + else: + raise RuntimeError( + f"Scale shape mismatch: a_scale.shape={a_scale.shape}, expected {expected_scale_shape}. " + + f"a_int8.shape={a_int8.shape}, K={K}, block_size={block_size}" + ) + + a_blocked = a_int8.reshape(*batch_shape, K // block_size, block_size) + a_scale_broadcast = a_scale.unsqueeze(-1) + a_fp32 = a_blocked.to(torch.float32) * a_scale_broadcast + a_fp32 = a_fp32.reshape(*batch_shape, K) + + # Dequantize weights + # b_int8 is in (N, K) format (standard weight format), b_scale is in (N//block_size, K//block_size) format + expected_weight_scale_shape = (N // block_size, K // block_size) + if b_scale.shape != expected_weight_scale_shape: + # Try to reshape if the number of elements matches + expected_weight_numel = (N // block_size) * (K // block_size) + if b_scale.numel() == expected_weight_numel: + b_scale = b_scale.reshape(expected_weight_scale_shape) + else: + raise RuntimeError( + f"Weight scale shape mismatch: b_scale.shape={b_scale.shape}, expected {expected_weight_scale_shape}. 
" + + f"b_int8.shape={b_int8.shape}, N={N}, K={K}, block_size={block_size}" + ) + + # Dequantize weight: (N, K) -> blocks -> dequantize -> (N, K) + b_blocked = b_int8.reshape(N // block_size, block_size, K // block_size, block_size) + b_blocked = b_blocked.permute(0, 2, 1, 3) # (N//bs, K//bs, bs, bs) + b_scale_broadcast = b_scale.unsqueeze(-1).unsqueeze(-1) + b_fp32 = b_blocked.to(torch.float32) * b_scale_broadcast + b_fp32 = b_fp32.permute(0, 2, 1, 3).reshape(N, K) # Back to (N, K) + + output = torch.nn.functional.linear(a_fp32, b_fp32, bias) + return output + + +def _int8_gemm_triton_or_fallback(a_int8, a_scale, b_int8, b_scale, block_size, bias=None, out_quant=False): + """ + INT8 matrix multiplication with optional fused bias using Triton kernels or PyTorch fallback. + + Args: + a_int8: INT8 activations, shape (*batch, K) + a_scale: Activation scales, shape (*batch, K//block_size) + b_int8: INT8 weights, shape (N, K) - standard PyTorch weight format + b_scale: Weight scales, shape (N//block_size, K//block_size) + block_size: Block size for quantization + bias: Optional bias vector, shape (N,) + out_quant: If True, return quantized output (INT8 + scales) instead of float + + Returns: + If out_quant=False: Output in float16/float32, shape (*batch, N) + If out_quant=True: Tuple of (output_int8, output_scale) + """ + K = a_int8.shape[-1] + batch_shape = a_int8.shape[:-1] + # b_int8 is weight in (N, K) format (standard PyTorch weight format) + N = b_int8.shape[0] + assert b_int8.shape[1] == K, f"Weight shape mismatch: expected b_int8.shape[1]={K}, got {b_int8.shape[1]}" + + # Try Triton kernel first (only if tensors are on CUDA) + if _HAS_TRITON_INT8 and a_int8.is_cuda: + try: + # int8_gemm/int8_addmm expects: (a, a_s, b, b_s, [bias]) + # a: (*batch, K), a_s: (*batch, K//block_size) + # b: (N, K), b_s: (N//block_size, K//block_size) + # Triton kernels transpose b internally + + # Reshape activations to 2D for int8_gemm + a_2d = a_int8.reshape(-1, K).contiguous() + a_scale_2d = a_scale.reshape(-1, a_scale.shape[-1]).contiguous() + + # Ensure weight tensors are contiguous + b_int8_c = b_int8.contiguous() + b_scale_c = b_scale.contiguous() + + # Call appropriate Triton kernel based on out_quant flag + if out_quant: + # Use fused matmul + quantization kernels + if bias is not None: + # Fused addmm + quantization + output_2d, output_scale_2d = triton_int8_addmm_quant( + a_2d, a_scale_2d, b_int8_c, b_scale_c, bias, out_block_size=block_size + ) + else: + # Fused gemm + quantization + output_2d, output_scale_2d = triton_int8_gemm_quant( + a_2d, a_scale_2d, b_int8_c, b_scale_c, out_block_size=block_size + ) + + # Reshape back to original batch shape + output = output_2d.reshape(*batch_shape, N) + output_scale = output_scale_2d.reshape(*batch_shape, N // block_size) + return output, output_scale + else: + # Standard float output + if bias is not None: + # Use fused addmm kernel + output_2d = triton_int8_addmm(a_2d, a_scale_2d, b_int8_c, b_scale_c, bias) + else: + # Use standard gemm kernel + output_2d = triton_int8_gemm(a_2d, a_scale_2d, b_int8_c, b_scale_c) + + # Reshape back to original batch shape + output = output_2d.reshape(*batch_shape, N) + return output + except Exception as e: + logging.warning(f"Triton int8_gemm/addmm failed: {e}, falling back to PyTorch") + raise e + + # Use PyTorch fallback + fallback_output = _int8_gemm_pytorch_fallback(a_int8, a_scale, b_int8, b_scale, block_size, bias) + + # If out_quant is requested, quantize the fallback output + if out_quant: + # Use PyTorch 
activation quantization on the output + from .int8_kernels import act_quant + try: + output_int8, output_scale = act_quant(fallback_output, block_size=block_size) + return output_int8, output_scale + except: + # Fallback to CPU quantization if Triton not available + output_int8, output_scale = BlockWiseINT8Layout._activation_quantize_pytorch( + fallback_output, block_size + ) + return output_int8, output_scale + + return fallback_output + + +@register_layout_op(torch.ops.aten.linear.default, "BlockWiseINT8Layout") +def int8_linear(func, args, kwargs): + """ + Block-wise INT8 linear operation handler with fused Triton kernel support. + + Supports: + - Both quantized input and weight (uses Triton int8_addmm with fused bias) + - Mixed precision (quantized weight, float input) + - Optional quantized output via out_dtype and out_quant parameters + """ + input_tensor = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + + # Case 1: Both input and weight are quantized + if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): + + # Extract quantized data + a_int8, a_scale, a_block_size, a_is_weight = BlockWiseINT8Layout.get_plain_tensors(input_tensor) + b_int8, b_scale, b_block_size, b_is_weight = BlockWiseINT8Layout.get_plain_tensors(weight) + + # Verify configurations + assert not a_is_weight, "Input tensor should not be marked as weight" + assert b_is_weight, "Weight tensor should be marked as weight" + assert a_block_size == b_block_size, f"Block sizes must match: {a_block_size} vs {b_block_size}" + + orig_dtype = input_tensor._layout_params['orig_dtype'] + out_dtype = kwargs.get('out_dtype', orig_dtype) + out_quant = kwargs.get('out_quant', False) # Whether to return quantized output + + # Weight is already in (N, K) format (standard PyTorch weight format) + # Pass out_quant to _int8_gemm_triton_or_fallback for fused matmul+quant + result = _int8_gemm_triton_or_fallback( + a_int8, a_scale, b_int8, b_scale, a_block_size, + bias=bias, out_quant=out_quant + ) + + # Handle quantized vs float output + if out_quant: + # Result is (output_int8, output_scale) tuple + output_int8, output_scale = result + + # Wrap in QuantizedTensor + layout_params = { + 'scale': output_scale, + 'block_size': a_block_size, + 'is_weight': False, + 'orig_dtype': out_dtype + } + return QuantizedTensor(output_int8, "BlockWiseINT8Layout", layout_params) + else: + # Result is float tensor + output = result + # Convert to target dtype if needed + if output.dtype != out_dtype: + output = output.to(out_dtype) + return output + + # Case 2: Fallback - dequantize and use standard linear + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() + if isinstance(input_tensor, QuantizedTensor): + input_tensor = input_tensor.dequantize() + + return torch.nn.functional.linear(input_tensor, weight, bias) + + +@register_layout_op(torch.ops.aten.mm.default, "BlockWiseINT8Layout") +def int8_mm(func, args, kwargs): + """Block-wise INT8 matrix multiplication handler with Triton kernel support.""" + input_tensor = args[0] + weight = args[1] + + if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): + a_int8, a_scale, a_block_size, a_is_weight = BlockWiseINT8Layout.get_plain_tensors(input_tensor) + b_int8, b_scale, b_block_size, b_is_weight = BlockWiseINT8Layout.get_plain_tensors(weight) + + assert a_block_size == b_block_size, f"Block sizes must match: {a_block_size} vs {b_block_size}" + + # Note: For mm, we expect both to be 2D + # If input is 
marked as weight (2D blocking), we need different logic + # For simplicity, dequantize if configurations don't match expected pattern + if a_is_weight or not b_is_weight: + logging.warning("INT8 mm: Unexpected tensor configurations, falling back to dequantization") + return func(input_tensor.dequantize(), weight.dequantize()) + + orig_dtype = input_tensor._layout_params['orig_dtype'] + out_dtype = kwargs.get('out_dtype', orig_dtype) + out_quant = kwargs.get('out_quant', False) # Whether to return quantized output (default: True) + + # Check if weight needs to be transposed to (N, K) format + # For mm: input is (M, K), weight should be (N, K) for the kernel + K = a_int8.shape[-1] + if b_int8.shape[0] == K and b_int8.shape[1] != K: + # Weight is in (K, N) format (transposed), transpose back to (N, K) + b_int8 = b_int8.t().contiguous() + b_scale = b_scale.t().contiguous() + + result = _int8_gemm_triton_or_fallback( + a_int8, a_scale, b_int8, b_scale, a_block_size, + bias=None, out_quant=out_quant + ) + + # Handle quantized vs float output + if out_quant: + # Result is (output_int8, output_scale) tuple + output_int8, output_scale = result + + # Wrap in QuantizedTensor + layout_params = { + 'scale': output_scale, + 'block_size': a_block_size, + 'is_weight': False, + 'orig_dtype': out_dtype + } + return QuantizedTensor(output_int8, "BlockWiseINT8Layout", layout_params) + else: + # Result is float tensor + output = result + # Convert to target dtype if needed + if output.dtype != out_dtype: + output = output.to(out_dtype) + return output + + # Fallback + a = list(args) + if isinstance(args[0], QuantizedTensor): + a[0] = args[0].dequantize() + if isinstance(args[1], QuantizedTensor): + a[1] = args[1].dequantize() + return func(*a, **kwargs) + + +@register_layout_op(torch.ops.aten.addmm.default, "BlockWiseINT8Layout") +def int8_addmm(func, args, kwargs): + """ + Block-wise INT8 addmm operation handler with fused Triton kernel support. + addmm: out = beta * input + alpha * (mat1 @ mat2) + + This uses the fused int8_addmm kernel which combines matmul and bias addition + in a single pass for better performance. 
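+    Note that torch.nn.functional.linear dispatches here as addmm(bias, input, weight.t()),
+    so the weight typically arrives transposed as (K, N) and is transposed back to (N, K)
+    before the kernel is called.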
+ + Args: + args[0]: bias tensor + args[1]: mat1 (input) + args[2]: mat2 (weight) + """ + bias = args[0] + input_tensor = args[1] + weight = args[2] + + # Case 1: Both input and weight are quantized + if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): + # Extract quantized data + a_int8, a_scale, a_block_size, a_is_weight = BlockWiseINT8Layout.get_plain_tensors(input_tensor) + b_int8, b_scale, b_block_size, b_is_weight = BlockWiseINT8Layout.get_plain_tensors(weight) + + # Verify configurations + assert a_block_size == b_block_size, f"Block sizes must match: {a_block_size} vs {b_block_size}" + + orig_dtype = input_tensor._layout_params['orig_dtype'] + out_dtype = kwargs.get('out_dtype', orig_dtype) + out_quant = kwargs.get('out_quant', False) # Whether to return quantized output + + # PyTorch's F.linear internally calls addmm(bias, input, weight.t()) + # So weight arrives in (K, N) format (transposed), need to transpose back to (N, K) + # Check if weight is transposed by comparing dimensions with input + K = a_int8.shape[-1] + if b_is_weight and b_int8.shape[0] == K: + # Weight is in (K, N) format (transposed), transpose back to (N, K) + # The transpose handler also transposed the scale, so we need to transpose it back too + b_int8 = b_int8.t().contiguous() + b_scale = b_scale.t().contiguous() + + # Use fused Triton kernel (combines matmul + bias + optional quant) + result = _int8_gemm_triton_or_fallback( + a_int8, a_scale, b_int8, b_scale, a_block_size, + bias=bias, out_quant=out_quant + ) + + # Handle quantized vs float output + if out_quant: + # Result is (output_int8, output_scale) tuple + output_int8, output_scale = result + + # Wrap in QuantizedTensor + layout_params = { + 'scale': output_scale, + 'block_size': a_block_size, + 'is_weight': False, + 'orig_dtype': out_dtype + } + return QuantizedTensor(output_int8, "BlockWiseINT8Layout", layout_params) + else: + # Result is float tensor + output = result + # Convert to target dtype if needed + if output.dtype != out_dtype: + output = output.to(out_dtype) + return output + + # Fallback: dequantize and use standard addmm + a = list(args) + if isinstance(args[0], QuantizedTensor): + a[0] = args[0].dequantize() + if isinstance(args[1], QuantizedTensor): + a[1] = args[1].dequantize() + if isinstance(args[2], QuantizedTensor): + a[2] = args[2].dequantize() + + return func(*a, **kwargs) + + +@register_layout_op(torch.ops.aten.view.default, "BlockWiseINT8Layout") +def int8_view(func, args, kwargs): + """Handle view operations for INT8 tensors.""" + input_tensor = args[0] + if isinstance(input_tensor, QuantizedTensor): + # For view, we need to be careful with block structure + # For safety, we'll allow these ops but note that they might break block alignment + plain_input = input_tensor._qdata + ar = list(args) + ar[0] = plain_input + transformed = func(*ar, **kwargs) + + # Return new QuantizedTensor with same layout params + # Note: This assumes the transformation preserves block structure + return QuantizedTensor(transformed, "BlockWiseINT8Layout", input_tensor._layout_params) + return func(*args, **kwargs) + + +@register_layout_op(torch.ops.aten.t.default, "BlockWiseINT8Layout") +def int8_transpose(func, args, kwargs): + """Handle transpose operations for INT8 tensors.""" + input_tensor = args[0] + if isinstance(input_tensor, QuantizedTensor): + # Transpose the quantized data + plain_input = input_tensor._qdata + ar = list(args) + ar[0] = plain_input + transformed = func(*ar, **kwargs) + + # For weight 
tensors, we need to transpose the scale tensor as well + new_layout_params = input_tensor._layout_params.copy() + if new_layout_params.get('is_weight', False): + # Transpose the scale tensor to match the transposed weight + new_layout_params['scale'] = new_layout_params['scale'].t().contiguous() + + # Return new QuantizedTensor with updated layout params + return QuantizedTensor(transformed, "BlockWiseINT8Layout", new_layout_params) + return func(*args, **kwargs) + + +@register_layout_op(torch.ops.aten.transpose.int, "BlockWiseINT8Layout") +def int8_transpose_int(func, args, kwargs): + """ + Handle general transpose operations for INT8 tensors. + + torch.transpose(input, dim0, dim1) swaps two dimensions. + + For BlockWiseINT8Layout: + - Activations: quantized along last dimension, scale shape is (*batch_dims, K//block_size) + If we swap the last dimension, we need to adjust scale handling + - Weights: quantized in 2D blocks (M, N), scale shape is (M//block_size, N//block_size) + If we swap dimensions on a 2D weight, transpose the scale tensor too + """ + input_tensor = args[0] + dim0 = args[1] if len(args) > 1 else kwargs.get('dim0', 0) + dim1 = args[2] if len(args) > 2 else kwargs.get('dim1', 1) + + if isinstance(input_tensor, QuantizedTensor): + # Transpose the quantized data + plain_input = input_tensor._qdata + ar = list(args) + ar[0] = plain_input + transformed = func(*ar, **kwargs) + + # Copy layout params + new_layout_params = input_tensor._layout_params.copy() + is_weight = new_layout_params.get('is_weight', False) + + # Normalize dimensions to positive indices + ndim = plain_input.ndim + if dim0 < 0: + dim0 = ndim + dim0 + if dim1 < 0: + dim1 = ndim + dim1 + + # Handle scale tensor transposition + if is_weight: + # For weight tensors (2D with block-wise quantization in both dims) + # If we're transposing the two dimensions of a 2D tensor, transpose scales too + if ndim == 2 and set([dim0, dim1]) == {0, 1}: + # Transposing a 2D weight tensor (M, N) -> (N, M) + # Scale goes from (M//block_size, N//block_size) -> (N//block_size, M//block_size) + new_layout_params['scale'] = new_layout_params['scale'].t().contiguous() + else: + # For higher dimensional weight tensors or partial transposes, + # we may need more complex scale handling + # For now, log a warning as this is an uncommon case + logging.warning( + f"Transpose on weight tensor with dims ({dim0}, {dim1}) and shape {plain_input.shape}. " + f"Scale tensor may need adjustment for correct behavior." + ) + else: + # For activation tensors, block-wise quantization is along last dimension + # If we're swapping the last dimension, this changes the quantization structure + last_dim = ndim - 1 + if dim0 == last_dim or dim1 == last_dim: + # The last dimension is being moved, which affects quantization blocks + # This is a complex case - for safety, we could: + # 1. Dequantize, transpose, requantize (safest but slower) + # 2. 
Try to adjust scale tensor (complex, error-prone) + # For now, log a warning and proceed with transposing the scale tensor + # The scale tensor dimensions follow the input dimensions except the last + # which is divided by block_size + + # Determine how to transpose the scale tensor + # Scale shape is (*batch_dims, K//block_size) where K is the last dim of input + # When we transpose input dims, we need to transpose scale dims accordingly + # But the last scale dim always corresponds to the quantization blocks + + # Simple heuristic: if transposing involves last dim and input has 3+ dims, + # we transpose the corresponding scale dimensions + scale = new_layout_params['scale'] + if scale.ndim >= 2: + # Map input dimensions to scale dimensions + # Scale has shape (*batch_dims, K//block_size) + # If input has shape (*batch_dims, K), scale maps batch_dims directly + # and last dim is K//block_size + + # For transpose, if we swap dims d0 and d1 in input: + # - If d1 is last_dim (K), then in scale it's still last (K//block_size) + # - If d0 is last_dim, same applies + # - If neither is last_dim, transpose applies to batch dimensions + + if dim1 == last_dim: + # Swapping some batch dim with the last dim + # In scale, this means swapping that batch dim with last scale dim + scale_dim0 = dim0 # Same batch dimension + scale_dim1 = scale.ndim - 1 # Last dim of scale (K//block_size) + new_layout_params['scale'] = scale.transpose(scale_dim0, scale_dim1).contiguous() + elif dim0 == last_dim: + # Swapping last dim with some batch dim + scale_dim0 = scale.ndim - 1 # Last dim of scale + scale_dim1 = dim1 # Same batch dimension + new_layout_params['scale'] = scale.transpose(scale_dim0, scale_dim1).contiguous() + else: + # Swapping two batch dimensions (not involving last dim) + # Transpose the same dimensions in scale + new_layout_params['scale'] = scale.transpose(dim0, dim1).contiguous() + else: + logging.warning( + f"Transpose involves last dimension but scale tensor has shape {scale.shape}. " + f"Scale tensor may need adjustment." + ) + else: + # Transposing batch dimensions that don't affect the quantized dimension + # Transpose the same dimensions in scale tensor + scale = new_layout_params['scale'] + if scale.ndim > max(dim0, dim1): + new_layout_params['scale'] = scale.transpose(dim0, dim1).contiguous() + + # Return new QuantizedTensor with updated layout params + return QuantizedTensor(transformed, "BlockWiseINT8Layout", new_layout_params) + + return func(*args, **kwargs) + + +@register_layout_op(torch.ops.aten.gelu.default, "BlockWiseINT8Layout") +def int8_gelu(func, args, kwargs): + """ + Block-wise INT8 GELU activation handler with fused Triton kernel support. + + Supports quantized input -> GELU -> quantized output in a single fused kernel. + This avoids materializing full-precision intermediate results. 
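+    If the Triton kernel is unavailable or fails, the handler falls back to
+    dequantize -> GELU -> requantize.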
+ """ + input_tensor = args[0] + + # Case 1: Input is quantized - use fused kernel + if isinstance(input_tensor, QuantizedTensor): + # Extract quantized data + qdata, scale, block_size, is_weight = BlockWiseINT8Layout.get_plain_tensors(input_tensor) + + orig_dtype = input_tensor._layout_params['orig_dtype'] + + # Determine if we should use Triton kernel + if _HAS_TRITON_INT8 and qdata.is_cuda: + try: + # Import the Triton kernel + from .int8_kernels import int8_gelu as triton_int8_gelu + + # Call fused kernel + output_qdata, output_scale = triton_int8_gelu(qdata, scale, block_size=block_size) + + # Wrap result in QuantizedTensor + layout_params = { + 'scale': output_scale.to(torch.float32), + 'block_size': block_size, + 'is_weight': False, # Output is always activation format + 'orig_dtype': orig_dtype + } + return QuantizedTensor(output_qdata, "BlockWiseINT8Layout", layout_params) + + except Exception as e: + logging.warning(f"Triton int8_gelu failed: {e}, falling back to dequantization") + # Fall through to dequantization fallback + + # Fallback: dequantize, apply GELU, quantize + fp_input = input_tensor.dequantize() + fp_output = torch.nn.functional.gelu(fp_input) + + # Quantize output + output_qdata, output_layout_params = BlockWiseINT8Layout.quantize( + fp_output, + block_size=block_size, + is_weight=False + ) + output_layout_params['orig_dtype'] = orig_dtype + + return QuantizedTensor(output_qdata, "BlockWiseINT8Layout", output_layout_params) + + # Case 2: Input is not quantized - use standard GELU + return func(*args, **kwargs) + + +@register_layout_op(torch.ops.aten.add_.Tensor, "BlockWiseINT8Layout") +def int8_add_(func, args, kwargs): + """ + Block-wise INT8 in-place addition handler for LoRA application. + + This operation is typically used when applying LoRA to weight matrices. + Since speed is not critical for this operation: + - If target is a weight: dequantize, add, then requantize as weight + - Otherwise: dequantize and fallback to regular addition + + Args: + args[0]: Target tensor (self) to be modified in-place + args[1]: Tensor to add + """ + target = args[0] + + if isinstance(target, QuantizedTensor): + # Extract quantization parameters + _, _, block_size, is_weight = BlockWiseINT8Layout.get_plain_tensors(target) + + # Only handle the weight case specially + if is_weight: + other = args[1] + orig_dtype = target._layout_params['orig_dtype'] + + # Dequantize target + target_fp = target.dequantize() + + # Dequantize other if it's also quantized + if isinstance(other, QuantizedTensor): + other_fp = other.dequantize() + else: + other_fp = other + + # Perform addition + result_fp = target_fp + other_fp + + # Requantize as weight + result_qdata, result_layout_params = BlockWiseINT8Layout.quantize( + result_fp, + block_size=block_size, + is_weight=True + ) + result_layout_params['orig_dtype'] = orig_dtype + + # Update target in-place by copying the new quantized data + target._qdata.copy_(result_qdata) + target._layout_params['scale'].copy_(result_layout_params['scale']) + return target + + # For non-weight tensors or non-quantized tensors, use standard fallback + return QuantizedTensor._dequant_and_fallback(func, args, kwargs) + + +@register_layout_op(torch.ops.aten.to.dtype, "BlockWiseINT8Layout") +def int8_to_dtype(func, args, kwargs): + """ + Block-wise INT8 dtype conversion handler. + + This operation handles .to(dtype) calls on quantized tensors. 
+ - If converting to torch.int8, do nothing (already in INT8 format) + - Otherwise, dequantize and fallback + + Args: + args[0]: Input tensor + args[1]: Target dtype + """ + input_tensor = args[0] + target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype', None) + + if isinstance(input_tensor, QuantizedTensor): + # If target dtype is int8, the tensor is already in INT8 format + if target_dtype == torch.int8: + # No conversion needed, return as-is + return input_tensor + + # For any other dtype or non-quantized tensors, use standard fallback + return QuantizedTensor._dequant_and_fallback(func, args, kwargs) diff --git a/comfy/weight_adapter/boft.py b/comfy/weight_adapter/boft.py index b2a2f1bd46be..e6389ce6b1c0 100644 --- a/comfy/weight_adapter/boft.py +++ b/comfy/weight_adapter/boft.py @@ -4,6 +4,7 @@ import torch import comfy.model_management from .base import WeightAdapterBase, weight_decompose +from comfy.quant_ops import QuantizedTensor class BOFTAdapter(WeightAdapterBase): @@ -109,7 +110,7 @@ def calculate_weight( if dora_scale is not None: weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) else: - weight += function((strength * lora_diff).type(weight.dtype)) + weight += function((strength * lora_diff).type(weight.dtype if not isinstance(weight, QuantizedTensor) else torch.float32)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight diff --git a/comfy/weight_adapter/glora.py b/comfy/weight_adapter/glora.py index 939abbba5845..60d66a71ec96 100644 --- a/comfy/weight_adapter/glora.py +++ b/comfy/weight_adapter/glora.py @@ -4,6 +4,7 @@ import torch import comfy.model_management from .base import WeightAdapterBase, weight_decompose +from comfy.quant_ops import QuantizedTensor class GLoRAAdapter(WeightAdapterBase): @@ -87,7 +88,7 @@ def calculate_weight( if dora_scale is not None: weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + weight += function(((strength * alpha) * lora_diff).type(weight.dtype if not isinstance(weight, QuantizedTensor) else torch.float32)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight diff --git a/comfy/weight_adapter/loha.py b/comfy/weight_adapter/loha.py index 0abb2d4033fd..7c6a34bbc9bf 100644 --- a/comfy/weight_adapter/loha.py +++ b/comfy/weight_adapter/loha.py @@ -4,7 +4,7 @@ import torch import comfy.model_management from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose - +from comfy.quant_ops import QuantizedTensor class HadaWeight(torch.autograd.Function): @staticmethod @@ -226,7 +226,7 @@ def calculate_weight( if dora_scale is not None: weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + weight += function(((strength * alpha) * lora_diff).type(weight.dtype if not isinstance(weight, QuantizedTensor) else torch.float32)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight diff --git a/comfy/weight_adapter/lokr.py b/comfy/weight_adapter/lokr.py index 9b2aff2d7f86..11f52ac282f2 100644 --- a/comfy/weight_adapter/lokr.py +++ b/comfy/weight_adapter/lokr.py @@ -9,6 +9,7 @@ weight_decompose, factorization, ) +from comfy.quant_ops import QuantizedTensor class 
LokrDiff(WeightAdapterTrainBase): @@ -214,7 +215,7 @@ def calculate_weight( if dora_scale is not None: weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + weight += function(((strength * alpha) * lora_diff).type(weight.dtype if not isinstance(weight, QuantizedTensor) else torch.float32)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py index 3cc60bb1b754..7ba244d56fe6 100644 --- a/comfy/weight_adapter/lora.py +++ b/comfy/weight_adapter/lora.py @@ -10,6 +10,7 @@ pad_tensor_to_shape, tucker_weight_from_conv, ) +from comfy.quant_ops import QuantizedTensor class LoraDiff(WeightAdapterTrainBase): @@ -206,7 +207,7 @@ def calculate_weight( function, ) else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + weight += function(((strength * alpha) * lora_diff).type(weight.dtype if not isinstance(weight, QuantizedTensor) else torch.float32)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight diff --git a/comfy/weight_adapter/oft.py b/comfy/weight_adapter/oft.py index c0aab96353ca..a67e35a6abe3 100644 --- a/comfy/weight_adapter/oft.py +++ b/comfy/weight_adapter/oft.py @@ -4,6 +4,7 @@ import torch import comfy.model_management from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization +from comfy.quant_ops import QuantizedTensor class OFTDiff(WeightAdapterTrainBase): @@ -155,7 +156,7 @@ def calculate_weight( if dora_scale is not None: weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) else: - weight += function((strength * lora_diff).type(weight.dtype)) + weight += function((strength * lora_diff).type(weight.dtype if not isinstance(weight, QuantizedTensor) else torch.float32)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py index 9cb54ede8026..f2689a7408cc 100644 --- a/tests-unit/comfy_quant/test_quant_registry.py +++ b/tests-unit/comfy_quant/test_quant_registry.py @@ -2,6 +2,8 @@ import torch import sys import os +import time +import gc # Add comfy to path sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) @@ -13,8 +15,16 @@ def has_gpu(): if not has_gpu(): args.cpu = True -from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout +from comfy.quant_ops import ( + QuantizedTensor, + TensorCoreFP8Layout, + BlockWiseINT8Layout, + _int8_gemm_pytorch_fallback, + _int8_gemm_triton_or_fallback +) +# set TRITON_SKIP_AUTOTUNING=1 to skip autotuning +os.environ['TRITON_SKIP_AUTOTUNING'] = '1' class TestQuantizedTensor(unittest.TestCase): """Test the QuantizedTensor subclass with FP8 layout""" @@ -186,5 +196,2516 @@ def test_unsupported_op_dequantizes(self): self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large") +class TestBlockWiseINT8Layout(unittest.TestCase): + """Test the BlockWiseINT8Layout implementation""" + + def test_weight_quantize_dequantize(self): + """Test weight quantization and dequantization""" + # Create a weight tensor (M, N) with dimensions divisible by 128 + weight = torch.randn(256, 512, dtype=torch.float32) + block_size = 128 + + # Quantize as weight + qdata, layout_params = 
BlockWiseINT8Layout.quantize( + weight, + block_size=block_size, + is_weight=True + ) + + # Check quantized data + self.assertEqual(qdata.dtype, torch.int8) + self.assertEqual(qdata.shape, weight.shape) + + # Check scale shape: (M//block_size, N//block_size) + expected_scale_shape = (256 // block_size, 512 // block_size) + self.assertEqual(layout_params['scale'].shape, expected_scale_shape) + self.assertEqual(layout_params['block_size'], block_size) + self.assertTrue(layout_params['is_weight']) + self.assertEqual(layout_params['orig_dtype'], torch.float32) + + # Dequantize + dequantized = BlockWiseINT8Layout.dequantize(qdata, **layout_params) + + # Check reconstruction quality + self.assertEqual(dequantized.dtype, torch.float32) + self.assertEqual(dequantized.shape, weight.shape) + + # INT8 has limited precision, so we use a relaxed tolerance + max_error = (dequantized - weight).abs().max() + mean_error = (dequantized - weight).abs().mean() + self.assertLess(mean_error, 0.1) # Mean error should be reasonable for INT8 + + def test_activation_quantize_dequantize(self): + """Test activation quantization and dequantization""" + # Create an activation tensor with batch dimensions + activation = torch.randn(4, 16, 512, dtype=torch.float32) + block_size = 128 + + # Quantize as activation + qdata, layout_params = BlockWiseINT8Layout.quantize( + activation, + block_size=block_size, + is_weight=False + ) + + # Check quantized data + self.assertEqual(qdata.dtype, torch.int8) + self.assertEqual(qdata.shape, activation.shape) + + # Check scale shape: (*batch_dims, K//block_size) + expected_scale_shape = (4, 16, 512 // block_size) + self.assertEqual(layout_params['scale'].shape, expected_scale_shape) + self.assertEqual(layout_params['block_size'], block_size) + self.assertFalse(layout_params['is_weight']) + + # Dequantize + dequantized = BlockWiseINT8Layout.dequantize(qdata, **layout_params) + + # Check reconstruction + self.assertEqual(dequantized.shape, activation.shape) + mean_error = (dequantized - activation).abs().mean() + self.assertLess(mean_error, 0.1) + + def test_quantized_tensor_creation(self): + """Test creating QuantizedTensor with BlockWiseINT8Layout""" + weight = torch.randn(256, 512, dtype=torch.float32) + + qt = QuantizedTensor.from_float( + weight, + "BlockWiseINT8Layout", + block_size=128, + is_weight=True + ) + + self.assertIsInstance(qt, QuantizedTensor) + self.assertEqual(qt.dtype, torch.int8) + self.assertEqual(qt.shape, weight.shape) + self.assertEqual(qt._layout_type, "BlockWiseINT8Layout") + + # Test dequantization + dequantized = qt.dequantize() + self.assertEqual(dequantized.dtype, torch.float32) + mean_error = (dequantized - weight).abs().mean() + self.assertLess(mean_error, 0.1) + + +class TestBlockWiseINT8Operations(unittest.TestCase): + """Test operations with BlockWiseINT8 quantized tensors""" + + def test_linear_operation(self): + """Test linear operation with quantized weight and activation""" + torch.manual_seed(42) + + # Create test data + batch_size = 4 + seq_len = 16 + in_features = 256 + out_features = 512 + block_size = 128 + + # Input activation + input_fp32 = torch.randn(batch_size, seq_len, in_features, dtype=torch.float32) + + # Weight (note: linear expects weight as (out_features, in_features)) + weight_fp32 = torch.randn(out_features, in_features, dtype=torch.float32) + bias = torch.randn(out_features, dtype=torch.float32) + + # Quantize both + input_q = QuantizedTensor.from_float( + input_fp32, + "BlockWiseINT8Layout", + block_size=block_size, + 
is_weight=False + ) + + weight_q = QuantizedTensor.from_float( + weight_fp32, + "BlockWiseINT8Layout", + block_size=block_size, + is_weight=True + ) + + # Compute quantized linear + output_q = torch.nn.functional.linear(input_q, weight_q, bias) + + # Compute reference (full precision) + output_ref = torch.nn.functional.linear(input_fp32, weight_fp32, bias) + + # Compare results + self.assertEqual(output_q.shape, output_ref.shape) + + # INT8 quantization introduces error, but should be reasonable + mean_rel_error = ((output_q - output_ref).abs() / (output_ref.abs() + 1e-6)).mean() + self.assertLess(mean_rel_error, 0.2) # 20% relative error tolerance + + def test_clone_operation(self): + """Test clone operation on INT8 quantized tensor""" + weight = torch.randn(256, 512, dtype=torch.float32) + + qt = QuantizedTensor.from_float( + weight, + "BlockWiseINT8Layout", + block_size=128, + is_weight=True + ) + + qt_cloned = qt.clone() + + self.assertIsInstance(qt_cloned, QuantizedTensor) + self.assertEqual(qt_cloned.shape, qt.shape) + self.assertEqual(qt_cloned._layout_type, "BlockWiseINT8Layout") + self.assertIsNot(qt_cloned._qdata, qt._qdata) + + def test_detach_operation(self): + """Test detach operation on INT8 quantized tensor""" + weight = torch.randn(256, 512, dtype=torch.float32) + + qt = QuantizedTensor.from_float( + weight, + "BlockWiseINT8Layout", + block_size=128, + is_weight=True + ) + + qt_detached = qt.detach() + + self.assertIsInstance(qt_detached, QuantizedTensor) + self.assertEqual(qt_detached.shape, qt.shape) + self.assertEqual(qt_detached._layout_type, "BlockWiseINT8Layout") + + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_device_transfer(self): + """Test moving INT8 quantized tensor to different devices""" + weight = torch.randn(256, 512, dtype=torch.float32) + + qt = QuantizedTensor.from_float( + weight, + "BlockWiseINT8Layout", + block_size=128, + is_weight=True + ) + + # Move to CPU (should be no-op if already on CPU) + qt_cpu = qt.to('cpu') + + self.assertIsInstance(qt_cpu, QuantizedTensor) + self.assertEqual(qt_cpu.device.type, 'cpu') + self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu') + + def test_mixed_precision_fallback(self): + """Test mixed precision: quantized weight with float input""" + torch.manual_seed(42) + + input_fp32 = torch.randn(4, 256, dtype=torch.float32) + weight_fp32 = torch.randn(512, 256, dtype=torch.float32) + + # Only quantize weight + weight_q = QuantizedTensor.from_float( + weight_fp32, + "BlockWiseINT8Layout", + block_size=128, + is_weight=True + ) + + # Linear with float input and quantized weight + output = torch.nn.functional.linear(input_fp32, weight_q) + + # Should work via fallback + output_ref = torch.nn.functional.linear(input_fp32, weight_fp32) + + # With mixed precision fallback (dequantize weight), error should be small + mean_error = (output - output_ref).abs().mean() + self.assertLess(mean_error, 0.3) + + +class TestBlockWiseINT8EdgeCases(unittest.TestCase): + """Test edge cases and error handling for INT8 quantization""" + + def test_dimension_alignment(self): + """Test that dimensions must be divisible by block_size""" + # Try to quantize with misaligned dimensions + weight = torch.randn(200, 300, dtype=torch.float32) # Not divisible by 128 + + with self.assertRaises(AssertionError): + BlockWiseINT8Layout.quantize(weight, block_size=128, is_weight=True) + + def test_weight_must_be_2d(self): + """Test that weight quantization requires 2D tensors""" + weight_3d = torch.randn(4, 256, 512, 
dtype=torch.float32) + + with self.assertRaises(AssertionError): + BlockWiseINT8Layout.quantize(weight_3d, block_size=128, is_weight=True) + + def test_different_block_sizes(self): + """Test quantization with different block sizes""" + for block_size in [64, 128, 256]: + weight = torch.randn(512, 512, dtype=torch.float32) + + qdata, layout_params = BlockWiseINT8Layout.quantize( + weight, + block_size=block_size, + is_weight=True + ) + + expected_scale_shape = (512 // block_size, 512 // block_size) + self.assertEqual(layout_params['scale'].shape, expected_scale_shape) + + # Verify dequantization works + dequantized = BlockWiseINT8Layout.dequantize(qdata, **layout_params) + self.assertEqual(dequantized.shape, weight.shape) + + +class TestBlockWiseINT8Precision(unittest.TestCase): + """Precision tests for BlockWiseINT8Layout operations""" + + def test_weight_quantization_matches_manual_calculation(self): + """Test that weight quantization matches manual PyTorch calculation""" + device = torch.device('cuda' if has_gpu() else 'cpu') + torch.manual_seed(42) + + M, N = 256, 512 + block_size = 128 + weight = torch.randn(M, N, dtype=torch.float32, device=device) + + # Manual PyTorch calculation for weight quantization + # Weight shape: (M, N), blocks: (M//block_size, N//block_size) + weight_reshaped = weight.reshape(M // block_size, block_size, N // block_size, block_size) + weight_blocks = weight_reshaped.permute(0, 2, 1, 3) # (M//bs, N//bs, bs, bs) + + # Calculate scale per block: amax / 127.0 + amax = weight_blocks.abs().amax(dim=(2, 3), keepdim=False) # (M//bs, N//bs) + scale_manual = amax / 127.0 + scale_manual = torch.maximum(scale_manual, torch.tensor(1e-8, device=device, dtype=weight.dtype)) + + # Quantize: divide by scale and clamp to [-127, 127] + weight_blocks_scaled = weight_blocks / scale_manual.unsqueeze(-1).unsqueeze(-1) + int8_manual = torch.clamp(weight_blocks_scaled, -127.0, 127.0).to(torch.int8) + int8_manual = int8_manual.permute(0, 2, 1, 3).reshape(M, N) + + # Use BlockWiseINT8Layout.quantize + qdata, layout_params = BlockWiseINT8Layout.quantize( + weight, + block_size=block_size, + is_weight=True + ) + + # Compare int8 values + self.assertEqual(qdata.shape, int8_manual.shape) + self.assertEqual(qdata.dtype, torch.int8) + matches = (qdata == int8_manual).float().mean().item() + self.assertGreater(matches, 0.95, f"Only {matches*100:.2f}% of int8 values match") + + # Compare scales + self.assertEqual(layout_params['scale'].shape, scale_manual.shape) + scale_diff = (layout_params['scale'] - scale_manual).abs().mean().item() + scale_rel_diff = (scale_diff / (scale_manual.abs().mean().item() + 1e-8)) + self.assertLess(scale_rel_diff, 0.01, f"Scale relative difference too high: {scale_rel_diff}") + + def test_activation_quantization_matches_manual_calculation(self): + """Test that activation quantization matches manual PyTorch calculation""" + device = torch.device('cuda' if has_gpu() else 'cpu') + torch.manual_seed(42) + + batch_size = 4 + seq_len = 16 + K = 512 + block_size = 128 + activation = torch.randn(batch_size, seq_len, K, dtype=torch.float32, device=device) + + # Manual PyTorch calculation for activation quantization + # Activation shape: (*batch_dims, K), scale shape: (*batch_dims, K//block_size) + orig_shape = activation.shape + batch_dims = orig_shape[:-1] + + # Reshape to expose blocks in last dimension + activation_reshaped = activation.reshape(*batch_dims, K // block_size, block_size) + + # Calculate scale per block: amax / 127.0 + amax = 
activation_reshaped.abs().amax(dim=-1, keepdim=False) # (*batch_dims, K//block_size) + scale_manual = amax / 127.0 + scale_manual = torch.maximum(scale_manual, torch.tensor(1e-8, device=device, dtype=activation.dtype)) + + # Quantize: divide by scale and clamp to [-127, 127] + activation_scaled = activation_reshaped / scale_manual.unsqueeze(-1) + int8_manual = torch.clamp(activation_scaled, -127.0, 127.0).to(torch.int8) + int8_manual = int8_manual.reshape(orig_shape) + + # Use BlockWiseINT8Layout.quantize + qdata, layout_params = BlockWiseINT8Layout.quantize( + activation, + block_size=block_size, + is_weight=False + ) + + # Compare int8 values + self.assertEqual(qdata.shape, int8_manual.shape) + self.assertEqual(qdata.dtype, torch.int8) + matches = (qdata == int8_manual).float().mean().item() + self.assertGreater(matches, 0.95, f"Only {matches*100:.2f}% of int8 values match") + + # Compare scales + self.assertEqual(layout_params['scale'].shape, scale_manual.shape) + scale_diff = (layout_params['scale'] - scale_manual).abs().mean().item() + scale_rel_diff = (scale_diff / (scale_manual.abs().mean().item() + 1e-8)) + self.assertLess(scale_rel_diff, 0.01, f"Scale relative difference too high: {scale_rel_diff}") + + def test_dequantization_matches_manual_calculation(self): + """Test that dequantization matches manual PyTorch calculation""" + device = torch.device('cuda' if has_gpu() else 'cpu') + torch.manual_seed(42) + + # Test weight dequantization + M, N = 256, 512 + block_size = 128 + weight = torch.randn(M, N, dtype=torch.float32, device=device) + + # Quantize + qdata, layout_params = BlockWiseINT8Layout.quantize( + weight, + block_size=block_size, + is_weight=True + ) + + # Manual dequantization for weight + scale = layout_params['scale'] # (M//bs, N//bs) + int8_data = qdata # (M, N) + orig_dtype = layout_params['orig_dtype'] + + # Reshape to blocks + int8_reshaped = int8_data.reshape(M // block_size, block_size, N // block_size, block_size) + int8_blocks = int8_reshaped.permute(0, 2, 1, 3) # (M//bs, N//bs, bs, bs) + + # Dequantize: int8 * scale (no division by 127) + fp_blocks = int8_blocks.to(orig_dtype) * scale.unsqueeze(-1).unsqueeze(-1) + dequant_manual = fp_blocks.permute(0, 2, 1, 3).reshape(M, N) + + # Use BlockWiseINT8Layout.dequantize + dequant_layout = BlockWiseINT8Layout.dequantize(qdata, **layout_params) + + # Compare + diff = (dequant_layout - dequant_manual).abs().max().item() + self.assertLess(diff, 1e-5, f"Dequantization differs by {diff}") + + # Test activation dequantization + batch_size = 4 + seq_len = 16 + K = 512 + activation = torch.randn(batch_size, seq_len, K, dtype=torch.float32, device=device) + + qdata_act, layout_params_act = BlockWiseINT8Layout.quantize( + activation, + block_size=block_size, + is_weight=False + ) + + # Manual dequantization for activation + scale_act = layout_params_act['scale'] # (batch_size, seq_len, K//bs) + int8_data_act = qdata_act # (batch_size, seq_len, K) + orig_dtype_act = layout_params_act['orig_dtype'] + + # Reshape + int8_reshaped_act = int8_data_act.reshape(batch_size, seq_len, K // block_size, block_size) + # Dequantize: int8 * scale (no division by 127) + fp_blocks_act = int8_reshaped_act.to(orig_dtype_act) * scale_act.unsqueeze(-1) + dequant_manual_act = fp_blocks_act.reshape(batch_size, seq_len, K) + + # Use BlockWiseINT8Layout.dequantize + dequant_layout_act = BlockWiseINT8Layout.dequantize(qdata_act, **layout_params_act) + + # Compare + diff_act = (dequant_layout_act - dequant_manual_act).abs().max().item() + 
self.assertLess(diff_act, 1e-5, f"Activation dequantization differs by {diff_act}") + + def test_triton_linear_matches_pytorch_fallback(self): + """Test that Triton kernel INT8 GEMM matches PyTorch INT8 GEMM fallback""" + device = torch.device('cuda' if has_gpu() else 'cpu') + torch.manual_seed(42) + + batch_size = 4 + seq_len = 16 + in_features = 512 + out_features = 1024 + block_size = 128 + + # Create original float tensors + input_fp = torch.randn(batch_size, seq_len, in_features, dtype=torch.float32, device=device) + weight_fp = torch.randn(out_features, in_features, dtype=torch.float32, device=device) + bias = torch.randn(out_features, dtype=torch.float32, device=device) + + # Quantize to get int8 data and scales + input_q = QuantizedTensor.from_float( + input_fp, + "BlockWiseINT8Layout", + block_size=block_size, + is_weight=False + ) + + weight_q = QuantizedTensor.from_float( + weight_fp, + "BlockWiseINT8Layout", + block_size=block_size, + is_weight=True + ) + + # Extract int8 data and scales + a_int8, a_scale, _, _ = BlockWiseINT8Layout.get_plain_tensors(input_q) + b_int8, b_scale, _, _ = BlockWiseINT8Layout.get_plain_tensors(weight_q) + + # Call Triton/fallback version (will use Triton on GPU if available) + output_triton = _int8_gemm_triton_or_fallback( + a_int8, a_scale, b_int8, b_scale, block_size, bias=bias, out_quant=False + ) + + # Call PyTorch fallback directly + output_pytorch = _int8_gemm_pytorch_fallback( + a_int8, a_scale, b_int8, b_scale, block_size, bias=bias + ) + + # Convert both to float32 for fair comparison (Triton outputs float16, PyTorch outputs float32) + output_triton_fp32 = output_triton.to(torch.float32) + output_pytorch_fp32 = output_pytorch.to(torch.float32) + + # These should match very closely (same int8 inputs, same computation) + abs_diff = (output_triton_fp32 - output_pytorch_fp32).abs() + mean_abs_diff = abs_diff.mean().item() + max_abs_diff = abs_diff.max().item() + + # Use relative error to account for float16 precision limits + rel_diff = abs_diff / (output_pytorch_fp32.abs() + 1e-6) + mean_rel_diff = rel_diff.mean().item() + + # Since both compute the same INT8 GEMM from same inputs, differences should be tiny + self.assertLess(mean_rel_diff, 1e-3, + f"Triton and PyTorch INT8 GEMM differ too much: mean_rel={mean_rel_diff:.6f}, mean_abs={mean_abs_diff:.6f}, max={max_abs_diff:.6f}") + + def test_triton_linear_from_raw_int8_and_scales(self): + """Test INT8 GEMM from manually created int8 data and scales - compare 3 methods""" + device = torch.device('cuda' if has_gpu() else 'cpu') + if not has_gpu(): + self.skipTest("This test requires GPU (Triton kernels)") + torch.manual_seed(123) + + batch_size = 2 + seq_len = 8 + in_features = 256 + out_features = 512 + block_size = 128 + + # Manually create int8 data and scales for input (activation) + # Input shape: (batch_size, seq_len, in_features) + input_int8 = torch.randint(-127, 127, (batch_size, seq_len, in_features), + dtype=torch.int8, device=device) + input_scale = torch.rand(batch_size, seq_len, in_features // block_size, + dtype=torch.float32, device=device) * 0.1 + + input_layout_params = { + 'scale': input_scale, + 'block_size': block_size, + 'is_weight': False, + 'orig_dtype': torch.float32 + } + input_q = QuantizedTensor(input_int8, "BlockWiseINT8Layout", input_layout_params) + + # Manually create int8 data and scales for weight + # Weight shape: (out_features, in_features) + weight_int8 = torch.randint(-127, 127, (out_features, in_features), + dtype=torch.int8, device=device) + weight_scale = 
torch.rand(out_features // block_size, in_features // block_size, + dtype=torch.float32, device=device) * 0.1 + + weight_layout_params = { + 'scale': weight_scale, + 'block_size': block_size, + 'is_weight': True, + 'orig_dtype': torch.float32 + } + weight_q = QuantizedTensor(weight_int8, "BlockWiseINT8Layout", weight_layout_params) + + # Bias + bias = torch.randn(out_features, dtype=torch.float32, device=device) + + # Method 1: Call INT8 GEMM via Triton/fallback + output_triton = _int8_gemm_triton_or_fallback( + input_int8, input_scale, weight_int8, weight_scale, block_size, bias=bias, out_quant=False + ) + + # Method 2: Call PyTorch INT8 GEMM fallback directly + output_pytorch = _int8_gemm_pytorch_fallback( + input_int8, input_scale, weight_int8, weight_scale, block_size, bias=bias + ) + + # Method 3: Dequantize and use standard torch.nn.functional.linear + input_dequant = input_q.dequantize() + weight_dequant = weight_q.dequantize() + output_dequant = torch.nn.functional.linear(input_dequant, weight_dequant, bias) + + # Convert all to float32 for fair comparison + output_triton_fp32 = output_triton.to(torch.float32) + output_pytorch_fp32 = output_pytorch.to(torch.float32) + output_dequant_fp32 = output_dequant.to(torch.float32) + + # Compare Method 1 vs Method 2: Triton vs PyTorch INT8 GEMM + self.assertEqual(output_triton.shape, output_pytorch.shape) + abs_diff_12 = (output_triton_fp32 - output_pytorch_fp32).abs() + mean_abs_diff_12 = abs_diff_12.mean().item() + max_abs_diff_12 = abs_diff_12.max().item() + + # Use relative error since Triton outputs float16 which has limited precision for large values + rel_diff_12 = abs_diff_12 / (output_pytorch_fp32.abs() + 1e-6) + mean_rel_diff_12 = rel_diff_12.mean().item() + + # Same int8 data → both INT8 GEMMs should produce nearly identical results + # Use 0.1% relative error tolerance to account for float16 precision limits + self.assertLess(mean_rel_diff_12, 1e-3, + f"Triton and PyTorch INT8 GEMM differ: mean_rel={mean_rel_diff_12:.6f}, mean_abs={mean_abs_diff_12:.6f}, max_abs={max_abs_diff_12:.6f}") + + # Compare Method 1 vs Method 3: Triton INT8 GEMM vs Dequant+Float Linear + self.assertEqual(output_triton.shape, output_dequant.shape) + abs_diff_13 = (output_triton_fp32 - output_dequant_fp32).abs() + mean_abs_diff_13 = abs_diff_13.mean().item() + max_abs_diff_13 = abs_diff_13.max().item() + + # Use relative error for float16 precision limits + rel_diff_13 = abs_diff_13 / (output_dequant_fp32.abs() + 1e-6) + mean_rel_diff_13 = rel_diff_13.mean().item() + + # INT8 GEMM should match dequant+float linear (both compute the same thing) + self.assertLess(mean_rel_diff_13, 1e-3, + f"Triton INT8 GEMM and dequant+float differ: mean_rel={mean_rel_diff_13:.6f}, mean_abs={mean_abs_diff_13:.6f}, max_abs={max_abs_diff_13:.6f}") + + # Compare Method 2 vs Method 3: PyTorch INT8 GEMM vs Dequant+Float Linear + abs_diff_23 = (output_pytorch_fp32 - output_dequant_fp32).abs() + mean_abs_diff_23 = abs_diff_23.mean().item() + max_abs_diff_23 = abs_diff_23.max().item() + + # Use relative error + rel_diff_23 = abs_diff_23 / (output_dequant_fp32.abs() + 1e-6) + mean_rel_diff_23 = rel_diff_23.mean().item() + + # PyTorch INT8 GEMM should also match dequant+float linear + self.assertLess(mean_rel_diff_23, 1e-3, + f"PyTorch INT8 GEMM and dequant+float differ: mean_rel={mean_rel_diff_23:.6f}, mean_abs={mean_abs_diff_23:.6f}, max_abs={max_abs_diff_23:.6f}") + + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_triton_vs_pytorch_linear_implementation(self): + 
"""Compare Triton kernel vs PyTorch fallback implementation directly""" + torch.manual_seed(42) + device = torch.device('cuda') + + batch_size = 8 + seq_len = 32 + in_features = 1024 + out_features = 2048 + block_size = 128 + + # Create test data + input_fp = torch.randn(batch_size, seq_len, in_features, dtype=torch.float32, device=device) + weight_fp = torch.randn(out_features, in_features, dtype=torch.float32, device=device) + bias = torch.randn(out_features, dtype=torch.float32, device=device) + + # Quantize + input_q = QuantizedTensor.from_float(input_fp, "BlockWiseINT8Layout", + block_size=block_size, is_weight=False) + weight_q = QuantizedTensor.from_float(weight_fp, "BlockWiseINT8Layout", + block_size=block_size, is_weight=True) + + # Extract quantized data + a_int8, a_scale, a_block_size, _ = BlockWiseINT8Layout.get_plain_tensors(input_q) + b_int8, b_scale, b_block_size, _ = BlockWiseINT8Layout.get_plain_tensors(weight_q) + + # Call Triton version (via _int8_gemm_triton_or_fallback) + # Note: This may still use Triton for quant fusion even with out_quant=False + output_triton = _int8_gemm_triton_or_fallback( + a_int8, a_scale, b_int8, b_scale, block_size, bias=bias, out_quant=False + ) + + # Call PyTorch fallback directly + output_pytorch = _int8_gemm_pytorch_fallback( + a_int8, a_scale, b_int8, b_scale, block_size, bias=bias + ) + + # Compare Triton vs PyTorch fallback implementations + triton_pytorch_diff = (output_triton - output_pytorch).abs().mean().item() + + # These should match very closely since both compute the same operation + self.assertLess(triton_pytorch_diff, 1e-2, + f"Triton and PyTorch implementations differ: {triton_pytorch_diff}") + + # Also test via high-level API (which may return quantized output) + output_api = torch.nn.functional.linear(input_q, weight_q, bias) + if isinstance(output_api, QuantizedTensor): + output_api_dequant = output_api.dequantize() + else: + output_api_dequant = output_api + + # Compare API with PyTorch fallback (more lenient since API might use different path) + api_pytorch_diff = (output_api_dequant - output_pytorch).abs().mean().item() + self.assertLess(api_pytorch_diff, 0.5, + f"API and PyTorch implementations differ: {api_pytorch_diff}") + + def test_int8_gemm_with_block_size_128(self): + """Test INT8 GEMM with block_size=128 (standard size for Triton kernels)""" + device = torch.device('cuda' if has_gpu() else 'cpu') + torch.manual_seed(42) + + batch_size = 4 + seq_len = 16 + in_features = 512 + out_features = 512 + block_size = 128 + + # Create test data + input_fp = torch.randn(batch_size, seq_len, in_features, + dtype=torch.float32, device=device) + weight_fp = torch.randn(out_features, in_features, + dtype=torch.float32, device=device) + bias = torch.randn(out_features, dtype=torch.float32, device=device) + + # Quantize to get int8 data + input_q = QuantizedTensor.from_float( + input_fp, "BlockWiseINT8Layout", + block_size=block_size, is_weight=False + ) + weight_q = QuantizedTensor.from_float( + weight_fp, "BlockWiseINT8Layout", + block_size=block_size, is_weight=True + ) + + # Extract int8 and scales + a_int8, a_scale, _, _ = BlockWiseINT8Layout.get_plain_tensors(input_q) + b_int8, b_scale, _, _ = BlockWiseINT8Layout.get_plain_tensors(weight_q) + + # Run Triton/fallback INT8 GEMM + output_triton = _int8_gemm_triton_or_fallback( + a_int8, a_scale, b_int8, b_scale, block_size, bias=bias, out_quant=False + ) + + # Run PyTorch INT8 GEMM fallback + output_pytorch = _int8_gemm_pytorch_fallback( + a_int8, a_scale, b_int8, b_scale, 
block_size, bias=bias + ) + + # Convert both to float32 for fair comparison (Triton outputs float16, PyTorch outputs float32) + output_triton_fp32 = output_triton.to(torch.float32) + output_pytorch_fp32 = output_pytorch.to(torch.float32) + + # Compare using relative error + abs_diff = (output_triton_fp32 - output_pytorch_fp32).abs() + mean_abs_diff = abs_diff.mean().item() + rel_diff = abs_diff / (output_pytorch_fp32.abs() + 1e-6) + mean_rel_diff = rel_diff.mean().item() + + self.assertLess(mean_rel_diff, 1e-3, + f"Triton and PyTorch INT8 GEMM differ: mean_rel={mean_rel_diff:.6f}, mean_abs={mean_abs_diff:.6f}") + + def test_end_to_end_quantization_accuracy(self): + """Test end-to-end: quantize → INT8 GEMM → output accuracy vs float baseline""" + device = torch.device('cuda' if has_gpu() else 'cpu') + torch.manual_seed(42) + + batch_size = 4 + seq_len = 16 + in_features = 512 + out_features = 1024 + block_size = 128 + + # Create float tensors + input_fp = torch.randn(batch_size, seq_len, in_features, dtype=torch.float32, device=device) + weight_fp = torch.randn(out_features, in_features, dtype=torch.float32, device=device) + bias = torch.randn(out_features, dtype=torch.float32, device=device) + + # Float baseline + output_float = torch.nn.functional.linear(input_fp, weight_fp, bias) + + # Quantize → INT8 GEMM path + input_q = QuantizedTensor.from_float(input_fp, "BlockWiseINT8Layout", + block_size=block_size, is_weight=False) + weight_q = QuantizedTensor.from_float(weight_fp, "BlockWiseINT8Layout", + block_size=block_size, is_weight=True) + + # Get int8 data and scales + a_int8, a_scale, _, _ = BlockWiseINT8Layout.get_plain_tensors(input_q) + b_int8, b_scale, _, _ = BlockWiseINT8Layout.get_plain_tensors(weight_q) + + # Run INT8 GEMM + output_int8 = _int8_gemm_triton_or_fallback( + a_int8, a_scale, b_int8, b_scale, block_size, bias=bias, out_quant=False + ) + + # Convert to float32 for fair comparison (Triton outputs float16) + output_int8_fp32 = output_int8.to(torch.float32) + output_float_fp32 = output_float.to(torch.float32) + + # Compare with float baseline + abs_error = (output_int8_fp32 - output_float_fp32).abs() + mean_abs_error = abs_error.mean().item() + rel_error = abs_error / (output_float_fp32.abs() + 1e-6) + mean_rel_error = rel_error.mean().item() + + # This error is from quantization, not from INT8 GEMM implementation + # INT8 quantization can have ~5-20% relative error depending on data distribution + self.assertLess(mean_rel_error, 0.25, + f"Quantization error too high: {mean_rel_error:.4f}") + + def test_basic_weight_quantization(self): + """Test basic weight quantization precision""" + device = torch.device('cuda' if has_gpu() else 'cpu') + weight = torch.randn(256, 512, dtype=torch.float32, device=device) + + qt = QuantizedTensor.from_float( + weight, + "BlockWiseINT8Layout", + block_size=128, + is_weight=True + ) + + self.assertEqual(qt.shape, weight.shape) + self.assertEqual(qt.dtype, torch.int8) + + dequantized = qt.dequantize() + error = (dequantized - weight).abs().mean() + self.assertLess(error, 0.1, "Mean reconstruction error too high") + + def test_large_activation_quantization(self): + """Test activation quantization with larger tensor""" + device = torch.device('cuda' if has_gpu() else 'cpu') + activation = torch.randn(16, 128, 4096, dtype=torch.float32, device=device) + + qt = QuantizedTensor.from_float( + activation, + "BlockWiseINT8Layout", + block_size=128, + is_weight=False + ) + + self.assertEqual(qt.shape, activation.shape) + self.assertEqual(qt.dtype, 
torch.int8) + + dequantized = qt.dequantize() + error = (dequantized - activation).abs().mean() + self.assertLess(error, 0.1, "Mean reconstruction error too high") + + def test_quantized_linear_precision(self): + """Test quantized linear operation precision""" + torch.manual_seed(42) + device = torch.device('cuda' if has_gpu() else 'cpu') + + batch_size = 16 + seq_len = 128 + in_features = 2048 + out_features = 2048 + block_size = 128 + + input_fp32 = torch.randn(batch_size, seq_len, in_features, dtype=torch.float32, device=device) + weight_fp32 = torch.randn(out_features, in_features, dtype=torch.float32, device=device) + bias = torch.randn(out_features, dtype=torch.float32, device=device) + + # Quantize both + input_q = QuantizedTensor.from_float( + input_fp32, + "BlockWiseINT8Layout", + block_size=block_size, + is_weight=False + ) + + weight_q = QuantizedTensor.from_float( + weight_fp32, + "BlockWiseINT8Layout", + block_size=block_size, + is_weight=True + ) + + # Compute quantized linear (returns QuantizedTensor by default) + output_q = torch.nn.functional.linear(input_q, weight_q, bias) + output_q = QuantizedTensor.from_float(output_q, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + + self.assertIsInstance(output_q, QuantizedTensor, "Default output should be QuantizedTensor") + + # Dequantize for comparison + output_dequant = output_q.dequantize() + + # Compute reference + output_ref = torch.nn.functional.linear(input_fp32, weight_fp32, bias) + + self.assertEqual(output_dequant.shape, output_ref.shape) + + mean_rel_error = ((output_dequant - output_ref).abs() / (output_ref.abs() + 1e-6)).mean() + self.assertLess(mean_rel_error, 0.2, "Mean relative error too high") + + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_triton_vs_pytorch_precision(self): + """Compare Triton kernel vs PyTorch fallback precision""" + # Check if Triton is available + try: + from comfy.int8_kernels import int8_gemm as triton_int8_gemm + has_triton = True + except ImportError: + self.skipTest("Triton kernels not available") + + torch.manual_seed(42) + device = torch.device('cuda') + + batch_size = 4 + seq_len = 16 + in_features = 256 + out_features = 512 + block_size = 128 + + # Create test data + input_fp32 = torch.randn(batch_size, seq_len, in_features, dtype=torch.float32, device=device) + weight_fp32 = torch.randn(out_features, in_features, dtype=torch.float32, device=device) + bias = torch.randn(out_features, dtype=torch.float32, device=device) + + # Quantize + input_q = QuantizedTensor.from_float( + input_fp32, + "BlockWiseINT8Layout", + block_size=block_size, + is_weight=False + ) + + weight_q = QuantizedTensor.from_float( + weight_fp32, + "BlockWiseINT8Layout", + block_size=block_size, + is_weight=True + ) + + # Extract quantized data + a_int8, a_scale, _, _ = BlockWiseINT8Layout.get_plain_tensors(input_q) + b_int8, b_scale, _, _ = BlockWiseINT8Layout.get_plain_tensors(weight_q) + + # Run Triton version (via _int8_gemm_triton_or_fallback) + output_triton = _int8_gemm_triton_or_fallback(a_int8, a_scale, b_int8, b_scale, block_size, bias) + + # Run PyTorch fallback directly + output_pytorch = _int8_gemm_pytorch_fallback(a_int8, a_scale, b_int8, b_scale, block_size, bias) + + # Compute reference + output_ref = torch.nn.functional.linear(input_fp32, weight_fp32, bias) + + # Compare errors + error_triton = ((output_triton - output_ref).abs() / (output_ref.abs() + 1e-6)).mean() + error_pytorch = ((output_pytorch - output_ref).abs() / (output_ref.abs() + 1e-6)).mean() + 
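# error_triton / error_pytorch measure quantization loss against the float reference (loose 0.2 bound below); error_between compares the two INT8 GEMM paths on identical int8 inputs and is held to a much tighter bound. +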
error_between = (output_triton - output_pytorch).abs().mean() + + self.assertLess(error_triton, 0.2, "Triton error too high") + self.assertLess(error_pytorch, 0.2, "PyTorch error too high") + self.assertLess(error_between, 4e-3, "Triton and PyTorch implementations differ") + + # Test via high-level API (torch dispatch) + output_dispatch = torch.nn.functional.linear(input_q, weight_q, bias) + + # Dequantize if needed + if isinstance(output_dispatch, QuantizedTensor): + output_dispatch_fp32 = output_dispatch.dequantize() + else: + output_dispatch_fp32 = output_dispatch + + # Compare with reference + error_dispatch = ((output_dispatch_fp32 - output_ref).abs() / (output_ref.abs() + 1e-6)).mean() + self.assertLess(error_dispatch, 0.2, "Torch dispatch error too high") + + # Compare dispatch output with low-level Triton output + error_dispatch_vs_triton = (output_dispatch_fp32 - output_triton).abs().mean() + self.assertLess(error_dispatch_vs_triton, 0.2, "Dispatch differs from low-level implementation") + + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_int8_vs_fp8_precision(self): + """Compare INT8 vs FP8 precision""" + # Check if FP8 is available + try: + test_tensor = torch.randn(16, 16, device='cuda', dtype=torch.float32) + _ = test_tensor.to(torch.float8_e4m3fn) + except (RuntimeError, AttributeError): + self.skipTest("FP8 dtypes not supported on this system") + + torch.manual_seed(42) + device = torch.device('cuda') + + batch_size = 16 + seq_len = 128 + in_features = 2048 + out_features = 2048 + block_size = 128 + + # Create test data + input_fp32 = torch.randn(batch_size, seq_len, in_features, dtype=torch.float32, device=device) + weight_fp32 = torch.randn(out_features, in_features, dtype=torch.float32, device=device) + bias = torch.randn(out_features, dtype=torch.float32, device=device) + + # Quantize with INT8 + input_int8 = QuantizedTensor.from_float( + input_fp32, + "BlockWiseINT8Layout", + block_size=block_size, + is_weight=False + ) + + weight_int8 = QuantizedTensor.from_float( + weight_fp32, + "BlockWiseINT8Layout", + block_size=block_size, + is_weight=True + ) + + # Quantize with FP8 + input_fp8 = QuantizedTensor.from_float( + input_fp32, + "TensorCoreFP8Layout", + dtype=torch.float8_e4m3fn + ) + + weight_fp8 = QuantizedTensor.from_float( + weight_fp32, + "TensorCoreFP8Layout", + dtype=torch.float8_e4m3fn + ) + + # Compute outputs + output_int8_q = torch.nn.functional.linear(input_int8, weight_int8, bias) + output_int8 = output_int8_q.dequantize() if isinstance(output_int8_q, QuantizedTensor) else output_int8_q + + # FP8 doesn't support fused bias, so add it manually + output_fp8 = torch.nn.functional.linear(input_fp8, weight_fp8, None) + if bias is not None: + output_fp8 = output_fp8 + bias + if isinstance(output_fp8, QuantizedTensor): + output_fp8 = output_fp8.dequantize() + + output_ref = torch.nn.functional.linear(input_fp32, weight_fp32, bias) + + # Compare precision + error_int8 = ((output_int8 - output_ref).abs() / (output_ref.abs() + 1e-6)).mean() + error_fp8 = ((output_fp8 - output_ref).abs() / (output_ref.abs() + 1e-6)).mean() + error_between = (output_int8 - output_fp8).abs().mean() + + self.assertLess(error_int8, 0.2, "INT8 error too high") + self.assertLess(error_fp8, 0.4, "FP8 error too high") + + # Memory usage comparison + int8_memory = input_int8._qdata.element_size() * input_int8._qdata.numel() + \ + weight_int8._qdata.element_size() * weight_int8._qdata.numel() + fp8_memory = input_fp8._qdata.element_size() * input_fp8._qdata.numel() + \ + 
weight_fp8._qdata.element_size() * weight_fp8._qdata.numel() + fp32_memory = input_fp32.element_size() * input_fp32.numel() + \ + weight_fp32.element_size() * weight_fp32.numel() + + self.assertLess(int8_memory, fp32_memory, "INT8 should use less memory than FP32") + self.assertLess(fp8_memory, fp32_memory, "FP8 should use less memory than FP32") + + def test_output_types(self): + """Test output types for all registered operations""" + device = torch.device('cuda' if has_gpu() else 'cpu') + torch.manual_seed(42) + + batch_size = 4 + seq_len = 16 + in_features = 256 + out_features = 512 + block_size = 128 + + # Create test data + input_fp32 = torch.randn(batch_size, seq_len, in_features, dtype=torch.float32, device=device) + weight_fp32 = torch.randn(out_features, in_features, dtype=torch.float32, device=device) + bias = torch.randn(out_features, dtype=torch.float32, device=device) + + # Quantize with INT8 + input_int8 = QuantizedTensor.from_float( + input_fp32, + "BlockWiseINT8Layout", + block_size=block_size, + is_weight=False + ) + weight_int8 = QuantizedTensor.from_float( + weight_fp32, + "BlockWiseINT8Layout", + block_size=block_size, + is_weight=True + ) + + # Test 1: linear with quantized output (default) + output = torch.nn.functional.linear(input_int8, weight_int8, bias) + output = QuantizedTensor.from_float(output, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + self.assertIsInstance(output, QuantizedTensor, "Default output should be QuantizedTensor") + self.assertEqual(output.layout_type, "BlockWiseINT8Layout") + + # Test 2: linear with explicit dequantization + output_q = torch.nn.functional.linear(input_int8, weight_int8, bias) + output_reg = output_q.dequantize() + self.assertNotIsInstance(output_reg, QuantizedTensor, "Dequantized output should be regular tensor") + + # Test 3: mm operation (2D input) - default quantized output + input_2d = input_fp32.reshape(-1, in_features) + input_int8_2d = QuantizedTensor.from_float(input_2d, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + weight_int8_t = weight_int8.t() + + output_mm = torch.mm(input_int8_2d, weight_int8_t) + output_mm = QuantizedTensor.from_float(output_mm, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + self.assertIsInstance(output_mm, QuantizedTensor, "Default mm output should be QuantizedTensor") + self.assertEqual(output_mm.layout_type, "BlockWiseINT8Layout") + + # Test 4: addmm operation - default quantized output + output_addmm = torch.addmm(bias, input_int8_2d, weight_int8_t) + output_addmm = QuantizedTensor.from_float(output_addmm, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + self.assertIsInstance(output_addmm, QuantizedTensor, "Default addmm output should be QuantizedTensor") + self.assertEqual(output_addmm.layout_type, "BlockWiseINT8Layout") + + # Test 5: view operation preserves quantization + view_result = input_int8.view(batch_size * seq_len, in_features) + self.assertIsInstance(view_result, QuantizedTensor, "view should preserve QuantizedTensor") + self.assertEqual(view_result.layout_type, "BlockWiseINT8Layout") + + # Test 6: transpose operation preserves quantization + transpose_result = weight_int8.t() + self.assertIsInstance(transpose_result, QuantizedTensor, "transpose should preserve QuantizedTensor") + self.assertEqual(transpose_result.layout_type, "BlockWiseINT8Layout") + + # Test 7: clone operation preserves quantization + clone_result = input_int8.clone() + self.assertIsInstance(clone_result, QuantizedTensor, "clone should 
preserve QuantizedTensor") + self.assertEqual(clone_result.layout_type, "BlockWiseINT8Layout") + + # Test 8: detach operation preserves quantization + detach_result = input_int8.detach() + self.assertIsInstance(detach_result, QuantizedTensor, "detach should preserve QuantizedTensor") + self.assertEqual(detach_result.layout_type, "BlockWiseINT8Layout") + + +class TestBlockWiseINT8GELU(unittest.TestCase): + """Test INT8 block-wise GELU activation""" + + def test_int8_gelu_basic(self): + """Test basic GELU operation with INT8 quantized tensors""" + device = torch.device('cuda' if has_gpu() else 'cpu') + + batch_size = 2 + seq_len = 512 + hidden_dim = 2048 + block_size = 128 + + # Create random input tensor + x = torch.randn(batch_size, seq_len, hidden_dim, dtype=torch.float16, device=device) + + # Compute reference output (full precision) + with torch.no_grad(): + reference_output = torch.nn.functional.gelu(x) + + # Quantize input + x_quant = QuantizedTensor.from_float(x, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + + # Apply GELU (should use fused kernel) + with torch.no_grad(): + output_quant = torch.nn.functional.gelu(x_quant) + + if isinstance(output_quant, QuantizedTensor): + output_fp = output_quant.dequantize() + else: + output_fp = output_quant + + self.assertEqual(output_fp.shape, reference_output.shape) + + # Compute error metrics + relative_error = (torch.norm(output_fp - reference_output) / torch.norm(reference_output)).item() + self.assertLess(relative_error, 0.1, f"Relative error too high: {relative_error}") + + def test_int8_gelu_2d(self): + """Test GELU with 2D tensors""" + device = torch.device('cuda' if has_gpu() else 'cpu') + + M, N = 256, 2048 + block_size = 128 + + x = torch.randn(M, N, dtype=torch.float16, device=device) + reference_output = torch.nn.functional.gelu(x) + + # Quantize and apply GELU + x_quant = QuantizedTensor.from_float(x, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + + with torch.no_grad(): + output_quant = torch.nn.functional.gelu(x_quant) + + if isinstance(output_quant, QuantizedTensor): + output_fp = output_quant.dequantize() + else: + output_fp = output_quant + + relative_error = (torch.norm(output_fp - reference_output) / torch.norm(reference_output)).item() + self.assertLess(relative_error, 0.1, f"Relative error too high: {relative_error}") + + def test_int8_gelu_different_shapes(self): + """Test GELU with various tensor shapes""" + device = torch.device('cuda' if has_gpu() else 'cpu') + block_size = 128 + + test_shapes = [ + (128, 1024), # 2D + (4, 512, 2048), # 3D + (2, 8, 128, 1024), # 4D + ] + + for shape in test_shapes: + with self.subTest(shape=shape): + x = torch.randn(*shape, dtype=torch.float16, device=device) + reference_output = torch.nn.functional.gelu(x) + + # Quantize and apply GELU + x_quant = QuantizedTensor.from_float(x, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + + with torch.no_grad(): + output_quant = torch.nn.functional.gelu(x_quant) + + if isinstance(output_quant, QuantizedTensor): + output_fp = output_quant.dequantize() + else: + output_fp = output_quant + + relative_error = (torch.norm(output_fp - reference_output) / torch.norm(reference_output)).item() + self.assertLess(relative_error, 0.1, f"Relative error too high for shape {shape}: {relative_error}") + + +class TestBlockWiseINT8QuantFusion(unittest.TestCase): + """Test fused INT8 matmul + quantization kernels""" + + @unittest.skip("out_quant parameter not yet implemented in torch ops") + 
@unittest.skipUnless(has_gpu(), "GPU not available") + def test_int8_linear_with_out_quant(self): + """Test INT8 linear operation with fused output quantization""" + batch_size = 4 + seq_len = 256 + input_dim = 1024 + output_dim = 2048 + block_size = 128 + + # Create input tensor + input_fp = torch.randn(batch_size, seq_len, input_dim, dtype=torch.float16, device='cuda') + weight_fp = torch.randn(output_dim, input_dim, dtype=torch.float16, device='cuda') + bias = torch.randn(output_dim, dtype=torch.float16, device='cuda') + + # Quantize input and weight + input_q = QuantizedTensor.from_float( + input_fp, + "BlockWiseINT8Layout", + block_size=block_size, + is_weight=False + ) + + weight_q = QuantizedTensor.from_float( + weight_fp, + "BlockWiseINT8Layout", + block_size=block_size, + is_weight=True + ) + + # Test 1: Regular linear (float output) + output_float = torch.ops.aten.linear.default(input_q, weight_q, bias) + self.assertIsNotNone(output_float) + self.assertEqual(output_float.shape, (batch_size, seq_len, output_dim)) + + # Test 2: Linear with fused output quantization (out_quant=True) + output_quant = torch.ops.aten.linear.default( + input_q, weight_q, bias + ) + + self.assertIsInstance(output_quant, QuantizedTensor, "Output should be QuantizedTensor when out_quant=True") + self.assertEqual(output_quant._layout_type, "BlockWiseINT8Layout") + + # Verify scale shape matches activation format + expected_scale_shape = (batch_size, seq_len, output_dim // block_size) + actual_scale_shape = output_quant._layout_params['scale'].shape + self.assertEqual(actual_scale_shape, expected_scale_shape, "Scale shape should match activation format") + + # Dequantize and compare + output_dequant = output_quant.dequantize() + self.assertEqual(output_dequant.shape, (batch_size, seq_len, output_dim)) + + # Compare with float output + diff = (output_float - output_dequant).abs() + max_diff = diff.max().item() + mean_diff = diff.mean().item() + relative_error = (diff / (output_float.abs() + 1e-6)).mean().item() + + self.assertLess(relative_error, 0.15, f"Relative error too high: {relative_error}") + + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_int8_addmm_with_out_quant(self): + """Test INT8 addmm operation with fused output quantization""" + M, K, N = 512, 1024, 2048 + block_size = 128 + + # Create tensors + input_fp = torch.randn(M, K, dtype=torch.float16, device='cuda') + weight_fp = torch.randn(N, K, dtype=torch.float16, device='cuda') + bias = torch.randn(N, dtype=torch.float16, device='cuda') + + # Quantize + input_q = QuantizedTensor.from_float( + input_fp, "BlockWiseINT8Layout", + block_size=block_size, is_weight=False + ) + weight_q = QuantizedTensor.from_float( + weight_fp, "BlockWiseINT8Layout", + block_size=block_size, is_weight=True + ) + + # Test with out_quant=True + output_quant = torch.ops.aten.addmm.default( + bias, input_q, weight_q.t() + ) + + output_quant = QuantizedTensor.from_float(output_quant, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + self.assertIsInstance(output_quant, QuantizedTensor, "Output should be QuantizedTensor when out_quant=True") + self.assertEqual(output_quant.shape, (M, N)) + self.assertEqual(output_quant._layout_type, "BlockWiseINT8Layout") + + # Verify it can be dequantized + output_dequant = output_quant.dequantize() + self.assertEqual(output_dequant.shape, (M, N)) + self.assertEqual(output_dequant.dtype, torch.float16) + + +# Benchmark tests (skipped by default) +class TestBlockWiseINT8Benchmarks(unittest.TestCase): + 
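# Every benchmark below is skipped by default via @unittest.skip("perf benchmark only"); remove that decorator to run it locally on a CUDA machine. +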
"""Performance benchmark tests for BlockWiseINT8Layout""" + + @unittest.skip("perf benchmark only") + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_runtime_comparison(self): + """Benchmark INT8 quantized ops via torch dispatch (high-level API)""" + device = torch.device('cuda') + torch.manual_seed(42) + + # More comprehensive test configurations + test_configs = [ + {"name": "Tiny", "batch": 2, "seq": 8, "in_feat": 128, "out_feat": 256, "block": 64}, + {"name": "Small", "batch": 4, "seq": 16, "in_feat": 256, "out_feat": 512, "block": 128}, + {"name": "Medium", "batch": 8, "seq": 32, "in_feat": 512, "out_feat": 1024, "block": 128}, + {"name": "Large", "batch": 16, "seq": 64, "in_feat": 1024, "out_feat": 2048, "block": 128}, + {"name": "XL", "batch": 32, "seq": 128, "in_feat": 2048, "out_feat": 4096, "block": 128}, + {"name": "XXL", "batch": 64, "seq": 256, "in_feat": 4096, "out_feat": 4096, "block": 128}, + ] + + n_warmup = 10 + n_iters = 200 # More iterations for better averaging + + print(f"\nWarmup iterations: {n_warmup}") + print(f"Benchmark iterations: {n_iters}\n") + + # Check if Triton is available + try: + from comfy.int8_kernels import int8_gemm as triton_int8_gemm + print("✓ Using Triton INT8 kernels (optimized path)\n") + except ImportError: + print("⚠ Using PyTorch fallback (Triton not available)\n") + + results = [] + + for config in test_configs: + name = config["name"] + batch_size = config["batch"] + seq_len = config["seq"] + in_features = config["in_feat"] + out_features = config["out_feat"] + block_size = config["block"] + + print(f"{name}: batch={batch_size}, seq={seq_len}, in={in_features}, out={out_features}, block={block_size}") + + # Calculate FLOPS for this configuration + m = batch_size * seq_len + k = in_features + n = out_features + flops = 2 * m * n * k # 2 for multiply-add + + try: + # Create test data + input_fp32 = torch.randn(batch_size, seq_len, in_features, dtype=torch.float32, device=device) + weight_fp32 = torch.randn(out_features, in_features, dtype=torch.float32, device=device) + bias = torch.randn(out_features, dtype=torch.float32, device=device) + + # Quantize using high-level API + input_int8 = QuantizedTensor.from_float(input_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + weight_int8 = QuantizedTensor.from_float(weight_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=True) + + # Warm up - test full dispatch path + for _ in range(n_warmup): + _ = torch.nn.functional.linear(input_int8, weight_int8, bias) + _ = torch.nn.functional.linear(input_fp32, weight_fp32, bias) + torch.cuda.synchronize() + torch.cuda.empty_cache() + + # Benchmark INT8 via torch dispatch (includes dispatch overhead + quantized output) + int8_times = [] + for _ in range(n_iters): + start = time.time() + output = torch.nn.functional.linear(input_int8, weight_int8, bias) + torch.cuda.synchronize() + int8_times.append((time.time() - start) * 1000) + + # Also benchmark with dequantization to FP32 output (more realistic for some use cases) + int8_dequant_times = [] + for _ in range(n_iters): + start = time.time() + output = torch.nn.functional.linear(input_int8, weight_int8, bias) + if isinstance(output, QuantizedTensor): + output = output.dequantize() + torch.cuda.synchronize() + int8_dequant_times.append((time.time() - start) * 1000) + + # Benchmark FP32 reference + fp32_times = [] + for _ in range(n_iters): + start = time.time() + _ = torch.nn.functional.linear(input_fp32, weight_fp32, bias) + torch.cuda.synchronize() + 
fp32_times.append((time.time() - start) * 1000) + + # Convert to torch tensors for statistics + int8_times = torch.tensor(int8_times) + int8_dequant_times = torch.tensor(int8_dequant_times) + fp32_times = torch.tensor(fp32_times) + + # Calculate statistics + int8_mean = int8_times.mean().item() + int8_std = int8_times.std().item() + int8_min = int8_times.min().item() + + int8_dequant_mean = int8_dequant_times.mean().item() + int8_dequant_std = int8_dequant_times.std().item() + int8_dequant_min = int8_dequant_times.min().item() + + fp32_mean = fp32_times.mean().item() + fp32_std = fp32_times.std().item() + fp32_min = fp32_times.min().item() + + speedup_int8 = fp32_mean / int8_mean + speedup_int8_dequant = fp32_mean / int8_dequant_mean + + print(f" INT8 (quantized out): {int8_mean:.3f}±{int8_std:.3f} ms (min: {int8_min:.3f} ms) [{flops/int8_mean/1e9:.2f} GFLOPS]") + print(f" INT8 (dequant out): {int8_dequant_mean:.3f}±{int8_dequant_std:.3f} ms (min: {int8_dequant_min:.3f} ms) [{flops/int8_dequant_mean/1e9:.2f} GFLOPS]") + print(f" FP32 reference: {fp32_mean:.3f}±{fp32_std:.3f} ms (min: {fp32_min:.3f} ms) [{flops/fp32_mean/1e9:.2f} GFLOPS]") + print(f" Speedup (INT8 quantized/FP32): {speedup_int8:.2f}x") + print(f" Speedup (INT8 dequant/FP32): {speedup_int8_dequant:.2f}x") + print(f" Dequant overhead: {((int8_dequant_mean - int8_mean) / int8_mean * 100):.1f}%\n") + + results.append({ + "name": name, + "int8_mean": int8_mean, + "int8_dequant_mean": int8_dequant_mean, + "fp32_mean": fp32_mean, + "speedup_int8": speedup_int8, + "speedup_int8_dequant": speedup_int8_dequant, + "flops": flops, + }) + + # Clean up memory after each configuration + del input_fp32, weight_fp32, bias, input_int8, weight_int8 + if 'int8_times' in locals(): + del int8_times, int8_dequant_times, fp32_times + gc.collect() + torch.cuda.empty_cache() + + except RuntimeError as e: + if "out of memory" in str(e): + print(f" ⚠ OOM - skipping this configuration\n") + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + else: + raise + + # Print summary + print("\n" + "=" * 60) + print("Summary:") + print("=" * 60) + for result in results: + print(f"{result['name']:8s}: INT8 {result['int8_mean']:.3f}ms, " + f"INT8+dequant {result['int8_dequant_mean']:.3f}ms, " + f"FP32 {result['fp32_mean']:.3f}ms, " + f"Speedup: {result['speedup_int8']:.2f}x (quantized), {result['speedup_int8_dequant']:.2f}x (dequant)") + + # Assertions for unittest + self.assertGreater(len(results), 0, "Should have collected benchmark results") + + @unittest.skip("perf benchmark only") + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_int8_vs_fp8_runtime(self): + """Benchmark INT8 vs FP8 runtime with comprehensive configs""" + # Check if FP8 is available + try: + test_tensor = torch.randn(16, 16, device='cuda', dtype=torch.float32) + _ = test_tensor.to(torch.float8_e4m3fn) + has_fp8 = True + except (RuntimeError, AttributeError): + has_fp8 = False + + if not has_fp8: + print("⚠ FP8 dtypes not supported on this system, skipping comparison") + self.skipTest("FP8 not supported") + return + + device = torch.device('cuda') + torch.manual_seed(42) + + # More comprehensive test configurations + test_configs = [ + {"name": "Tiny", "batch": 2, "seq": 8, "in_feat": 128, "out_feat": 256, "block": 64}, + {"name": "Small", "batch": 4, "seq": 16, "in_feat": 256, "out_feat": 512, "block": 128}, + {"name": "Medium", "batch": 8, "seq": 32, "in_feat": 512, "out_feat": 1024, "block": 128}, + {"name": "Large", "batch": 16, "seq": 64, "in_feat": 1024, 
"out_feat": 2048, "block": 128}, + {"name": "XL", "batch": 32, "seq": 128, "in_feat": 2048, "out_feat": 4096, "block": 128}, + {"name": "XXL", "batch": 64, "seq": 256, "in_feat": 4096, "out_feat": 4096, "block": 128}, + {"name": "XXXL", "batch": 128, "seq": 512, "in_feat": 4096, "out_feat": 4096, "block": 128}, + ] + + n_warmup = 10 + n_iters = 200 # More iterations for better averaging + + print(f"\nWarmup iterations: {n_warmup}") + print(f"Benchmark iterations: {n_iters}") + print("Note: INT8 uses fused bias, FP8 adds bias separately\n") + + results = [] + + for config in test_configs: + name = config["name"] + batch_size = config["batch"] + seq_len = config["seq"] + in_features = config["in_feat"] + out_features = config["out_feat"] + block_size = config["block"] + + print(f"{name}: batch={batch_size}, seq={seq_len}, in={in_features}, out={out_features}, block={block_size}") + + # Calculate FLOPS for this configuration + m = batch_size * seq_len + k = in_features + n = out_features + flops = 2 * m * n * k # 2 for multiply-add + + try: + # Create test data + input_fp32 = torch.randn(batch_size, seq_len, in_features, dtype=torch.float32, device=device) + weight_fp32 = torch.randn(out_features, in_features, dtype=torch.float32, device=device) + bias = torch.randn(out_features, dtype=torch.float32, device=device) + + # Quantize with INT8 + input_int8 = QuantizedTensor.from_float(input_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + weight_int8 = QuantizedTensor.from_float(weight_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=True) + + # Quantize with FP8 + input_fp8 = QuantizedTensor.from_float(input_fp32, "TensorCoreFP8Layout", dtype=torch.float8_e4m3fn) + weight_fp8 = QuantizedTensor.from_float(weight_fp32, "TensorCoreFP8Layout", dtype=torch.float8_e4m3fn) + + # Warm up + for _ in range(n_warmup): + _ = torch.nn.functional.linear(input_int8, weight_int8, bias) + out_fp8 = torch.nn.functional.linear(input_fp8, weight_fp8, None) + if bias is not None: + _ = out_fp8 + bias + _ = torch.nn.functional.linear(input_fp32, weight_fp32, bias) + torch.cuda.synchronize() + + # Benchmark INT8 (with fused bias) - collect all times + int8_times = [] + for _ in range(n_iters): + start = time.time() + _ = torch.nn.functional.linear(input_int8, weight_int8, bias) + torch.cuda.synchronize() + int8_times.append((time.time() - start) * 1000) + + # Benchmark FP8 (bias added separately) + fp8_times = [] + for _ in range(n_iters): + start = time.time() + out_fp8 = torch.nn.functional.linear(input_fp8, weight_fp8, None) + if bias is not None: + _ = out_fp8 + bias + torch.cuda.synchronize() + fp8_times.append((time.time() - start) * 1000) + + # Benchmark FP32 reference + fp32_times = [] + for _ in range(n_iters): + start = time.time() + _ = torch.nn.functional.linear(input_fp32, weight_fp32, bias) + torch.cuda.synchronize() + fp32_times.append((time.time() - start) * 1000) + + # Convert to torch tensors for statistics + int8_times = torch.tensor(int8_times) + fp8_times = torch.tensor(fp8_times) + fp32_times = torch.tensor(fp32_times) + + # Calculate statistics + int8_mean = int8_times.mean().item() + int8_std = int8_times.std().item() + int8_min = int8_times.min().item() + + fp8_mean = fp8_times.mean().item() + fp8_std = fp8_times.std().item() + fp8_min = fp8_times.min().item() + + fp32_mean = fp32_times.mean().item() + fp32_std = fp32_times.std().item() + fp32_min = fp32_times.min().item() + + speedup_int8 = fp32_mean / int8_mean + speedup_fp8 = fp32_mean / fp8_mean + 
int8_vs_fp8 = fp8_mean / int8_mean + + print(f" INT8 (fused bias): {int8_mean:.3f}±{int8_std:.3f} ms (min: {int8_min:.3f} ms) [{flops/int8_mean/1e9:.2f} GFLOPS]") + print(f" FP8 (sep. bias): {fp8_mean:.3f}±{fp8_std:.3f} ms (min: {fp8_min:.3f} ms) [{flops/fp8_mean/1e9:.2f} GFLOPS]") + print(f" FP32 (fused bias): {fp32_mean:.3f}±{fp32_std:.3f} ms (min: {fp32_min:.3f} ms) [{flops/fp32_mean/1e9:.2f} GFLOPS]") + print(f" Speedup (INT8/FP32): {speedup_int8:.2f}x") + print(f" Speedup (FP8/FP32): {speedup_fp8:.2f}x") + + if int8_mean < fp8_mean: + print(f" ✓ INT8 is {int8_vs_fp8:.2f}x faster than FP8\n") + else: + print(f" ✓ FP8 is {1/int8_vs_fp8:.2f}x faster than INT8\n") + + results.append({ + "name": name, + "int8_mean": int8_mean, + "fp8_mean": fp8_mean, + "fp32_mean": fp32_mean, + "speedup_int8": speedup_int8, + "speedup_fp8": speedup_fp8, + "int8_vs_fp8": int8_vs_fp8, + "flops": flops, + }) + + # Clean up memory after each configuration + del input_fp32, weight_fp32, bias, input_int8, weight_int8 + if has_fp8: + del input_fp8, weight_fp8 + if 'int8_times' in locals(): + del int8_times, fp8_times, fp32_times + gc.collect() + torch.cuda.empty_cache() + + except RuntimeError as e: + if "out of memory" in str(e): + print(f" ⚠ OOM - skipping this configuration\n") + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + else: + raise + + # Print summary + print("\n" + "=" * 60) + print("Summary:") + print("=" * 60) + for result in results: + print(f"{result['name']:8s}: INT8 {result['int8_mean']:.3f}ms, " + f"FP8 {result['fp8_mean']:.3f}ms, " + f"FP32 {result['fp32_mean']:.3f}ms, " + f"Speedup (INT8/FP32): {result['speedup_int8']:.2f}x, " + f"(FP8/FP32): {result['speedup_fp8']:.2f}x") + + # Assertions for unittest + self.assertGreater(len(results), 0, "Should have collected benchmark results") + + @unittest.skip("perf benchmark only") + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_quantization_dequantization_runtime(self): + """Benchmark quantization and dequantization operations""" + device = torch.device('cuda') + torch.manual_seed(42) + + n_warmup = 5 + n_iters = 100 + + print(f"\nWarmup iterations: {n_warmup}") + print(f"Benchmark iterations: {n_iters}\n") + + # Test configurations - various tensor sizes + test_configs = [ + {"name": "Small Weight", "shape": (512, 512), "is_weight": True, "block": 128}, + {"name": "Medium Weight", "shape": (2048, 2048), "is_weight": True, "block": 128}, + {"name": "Large Weight", "shape": (4096, 4096), "is_weight": True, "block": 128}, + {"name": "XL Weight", "shape": (8192, 8192), "is_weight": True, "block": 128}, + {"name": "Small Activation", "shape": (8, 64, 512), "is_weight": False, "block": 128}, + {"name": "Medium Activation", "shape": (16, 128, 2048), "is_weight": False, "block": 128}, + {"name": "Large Activation", "shape": (32, 256, 4096), "is_weight": False, "block": 128}, + {"name": "XL Activation", "shape": (64, 512, 4096), "is_weight": False, "block": 128}, + ] + + print("=" * 60) + print("INT8 BlockWise Quantization/Dequantization") + print("=" * 60) + + results_int8 = [] + + for config in test_configs: + name = config["name"] + shape = config["shape"] + is_weight = config["is_weight"] + block_size = config["block"] + + try: + # Create test tensor + tensor_fp32 = torch.randn(shape, dtype=torch.float32, device=device) + tensor_size_mb = tensor_fp32.numel() * tensor_fp32.element_size() / 1024 / 1024 + + print(f"\n{name}: shape={shape}, size={tensor_size_mb:.2f}MB") + + # Warm up + for _ in range(n_warmup): + qt = 
QuantizedTensor.from_float(tensor_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=is_weight) + _ = qt.dequantize() + torch.cuda.synchronize() + + # Benchmark quantization + quant_times = [] + for _ in range(n_iters): + start = time.time() + qt = QuantizedTensor.from_float(tensor_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=is_weight) + torch.cuda.synchronize() + quant_times.append((time.time() - start) * 1000) + + # Benchmark dequantization (reuse last quantized tensor) + dequant_times = [] + for _ in range(n_iters): + start = time.time() + _ = qt.dequantize() + torch.cuda.synchronize() + dequant_times.append((time.time() - start) * 1000) + + # Calculate statistics + quant_times = torch.tensor(quant_times) + dequant_times = torch.tensor(dequant_times) + + quant_mean = quant_times.mean().item() + quant_std = quant_times.std().item() + quant_min = quant_times.min().item() + + dequant_mean = dequant_times.mean().item() + dequant_std = dequant_times.std().item() + dequant_min = dequant_times.min().item() + + # Calculate throughput (GB/s) + quant_throughput = (tensor_size_mb / 1024) / (quant_mean / 1000) + dequant_throughput = (tensor_size_mb / 1024) / (dequant_mean / 1000) + + print(f" Quantization: {quant_mean:.3f}±{quant_std:.3f} ms (min: {quant_min:.3f} ms) [{quant_throughput:.2f} GB/s]") + print(f" Dequantization: {dequant_mean:.3f}±{dequant_std:.3f} ms (min: {dequant_min:.3f} ms) [{dequant_throughput:.2f} GB/s]") + print(f" Total roundtrip: {quant_mean + dequant_mean:.3f} ms") + + # Calculate memory savings + qt_memory = qt._qdata.element_size() * qt._qdata.numel() + qt_memory += qt._layout_params['scale'].element_size() * qt._layout_params['scale'].numel() + fp32_memory = tensor_fp32.element_size() * tensor_fp32.numel() + reduction = fp32_memory / qt_memory + + print(f" Memory: FP32 {fp32_memory/1024/1024:.2f}MB -> INT8 {qt_memory/1024/1024:.2f}MB ({reduction:.2f}x reduction)") + + results_int8.append({ + "name": name, + "shape": shape, + "size_mb": tensor_size_mb, + "quant_mean": quant_mean, + "dequant_mean": dequant_mean, + "quant_throughput": quant_throughput, + "dequant_throughput": dequant_throughput, + "reduction": reduction, + }) + + # Clean up memory after each configuration + del tensor_fp32, qt + if 'quant_times' in locals(): + del quant_times, dequant_times + gc.collect() + torch.cuda.empty_cache() + + except RuntimeError as e: + if "out of memory" in str(e): + print(f"\n{name}: ⚠ OOM - skipping") + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + else: + raise + + # Summary + print() + print("=" * 60) + print("Summary: INT8 Quantization/Dequantization Performance") + print("=" * 60) + for result in results_int8: + print(f"{result['name']:20s}: Quant {result['quant_mean']:.3f}ms, " + f"Dequant {result['dequant_mean']:.3f}ms, " + f"Total {result['quant_mean'] + result['dequant_mean']:.3f}ms") + + # Assertions for unittest + self.assertGreater(len(results_int8), 0, "Should have collected benchmark results") + + @unittest.skip("perf benchmark only") + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_fp16_vs_int8_real_model_sizes(self): + """Compare FP16 vs INT8 vs FP8 on actual model sizes via torch dispatch""" + device = torch.device('cuda') + torch.manual_seed(42) + + # Check if FP8 is available + try: + test_tensor = torch.randn(16, 16, device='cuda', dtype=torch.float32) + _ = test_tensor.to(torch.float8_e4m3fn) + has_fp8 = True + print("✓ FP8 support detected") + except (RuntimeError, AttributeError): + 
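# AttributeError covers PyTorch builds without torch.float8_e4m3fn; RuntimeError covers devices where the cast itself fails +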
has_fp8 = False + print("⚠ FP8 not supported on this system - will compare FP16 vs INT8 only") + + # Actual sizes from model dumps + test_configs = [ + # WAN 2.2 5B model sizes + { + "model": "WAN2.2-5B", + "name": "First layer (small batch)", + "input_shape": (2, 1, 3072), + "weight_shape": (18432, 3072), + "block_size": 128, + }, + { + "model": "WAN2.2-5B", + "name": "Attention layer (long seq)", + "input_shape": (2, 27280, 3072), + "weight_shape": (3072, 3072), + "block_size": 128, + }, + { + "model": "WAN2.2-5B", + "name": "MLP down projection (long seq)", + "input_shape": (2, 27280, 14336), + "weight_shape": (3072, 14336), + "block_size": 128, + }, + { + "model": "WAN2.2-5B", + "name": "MLP up projection (long seq)", + "input_shape": (2, 27280, 3072), + "weight_shape": (14336, 3072), + "block_size": 128, + }, + { + "model": "WAN2.2-5B", + "name": "Attention layer (medium seq)", + "input_shape": (2, 512, 3072), + "weight_shape": (3072, 3072), + "block_size": 128, + }, + # WAN 2.2 14B model sizes + { + "model": "WAN2.2-14B", + "name": "First layer (small batch)", + "input_shape": (2, 1, 5120), + "weight_shape": (30720, 5120), + "block_size": 128, + }, + { + "model": "WAN2.2-14B", + "name": "Attention layer (long seq)", + "input_shape": (2, 27280, 5120), + "weight_shape": (5120, 5120), + "block_size": 128, + }, + { + "model": "WAN2.2-14B", + "name": "Attention layer (medium seq)", + "input_shape": (2, 512, 5120), + "weight_shape": (5120, 5120), + "block_size": 128, + }, + { + "model": "WAN2.2-14B", + "name": "MLP up projection (long seq)", + "input_shape": (2, 27280, 5120), + "weight_shape": (13824, 5120), + "block_size": 128, + }, + { + "model": "WAN2.2-14B", + "name": "MLP down projection (long seq)", + "input_shape": (2, 27280, 13824), + "weight_shape": (5120, 13824), + "block_size": 128, + }, + ] + + n_warmup = 10 + n_iters = 100 + + print(f"\nWarmup iterations: {n_warmup}") + print(f"Benchmark iterations: {n_iters}\n") + + results = [] + current_model = None + + for config in test_configs: + model = config["model"] + name = config["name"] + input_shape = config["input_shape"] + weight_shape = config["weight_shape"] + block_size = config["block_size"] + + # Print model header when we switch models + if model != current_model: + print("\n" + "=" * 60) + print(f"{model} Model Layers") + print("=" * 60) + current_model = model + + print(f"\n{name}") + print(f" Input: {input_shape}, Weight: {weight_shape}") + + # Calculate FLOPS + batch, seq_len, in_features = input_shape + out_features, _ = weight_shape + m = batch * seq_len + k = in_features + n = out_features + flops = 2 * m * n * k + + try: + # Measure initial VRAM + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + initial_vram = torch.cuda.memory_allocated() / 1024 / 1024 # MB + + # Create test data in FP16 and FP32 + input_fp32 = torch.randn(input_shape, dtype=torch.float32, device=device) + input_fp16 = input_fp32.to(torch.float16) + + weight_fp32 = torch.randn(weight_shape, dtype=torch.float32, device=device) + weight_fp16 = weight_fp32.to(torch.float16) + + bias_fp32 = torch.randn(out_features, dtype=torch.float32, device=device) + bias_fp16 = bias_fp32.to(torch.float16) + + # Measure FP16 VRAM + fp16_vram = torch.cuda.memory_allocated() / 1024 / 1024 - initial_vram + + # Quantize to INT8 + input_int8 = QuantizedTensor.from_float(input_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + weight_int8 = QuantizedTensor.from_float(weight_fp32, "BlockWiseINT8Layout", block_size=block_size, 
is_weight=True) + + # Measure INT8 VRAM (after creating quantized tensors, before releasing FP16) + int8_vram_with_fp16 = torch.cuda.memory_allocated() / 1024 / 1024 - initial_vram + + # Quantize to FP8 if available + if has_fp8: + input_fp8 = QuantizedTensor.from_float(input_fp32, "TensorCoreFP8Layout", dtype=torch.float8_e4m3fn) + weight_fp8 = QuantizedTensor.from_float(weight_fp32, "TensorCoreFP8Layout", dtype=torch.float8_e4m3fn) + fp8_vram_with_others = torch.cuda.memory_allocated() / 1024 / 1024 - initial_vram + + # Calculate memory usage + fp16_input_mem = input_fp16.element_size() * input_fp16.numel() + fp16_weight_mem = weight_fp16.element_size() * weight_fp16.numel() + fp16_total_mem = fp16_input_mem + fp16_weight_mem + + int8_input_mem = input_int8._qdata.element_size() * input_int8._qdata.numel() + int8_input_mem += input_int8._layout_params['scale'].element_size() * input_int8._layout_params['scale'].numel() + + int8_weight_mem = weight_int8._qdata.element_size() * weight_int8._qdata.numel() + int8_weight_mem += weight_int8._layout_params['scale'].element_size() * weight_int8._layout_params['scale'].numel() + + int8_total_mem = int8_input_mem + int8_weight_mem + mem_reduction = fp16_total_mem / int8_total_mem + + print(f" Tensor Memory: FP16 {fp16_total_mem/1024/1024:.2f}MB -> INT8 {int8_total_mem/1024/1024:.2f}MB ({mem_reduction:.2f}x reduction)") + print(f" VRAM Usage: FP16 {fp16_vram:.2f}MB, INT8 {int8_vram_with_fp16:.2f}MB (incl. FP16 tensors)") + + # Warm up + for _ in range(n_warmup): + _ = torch.nn.functional.linear(input_fp16, weight_fp16, bias_fp16) + _ = torch.nn.functional.linear(input_int8, weight_int8, bias_fp32) + if has_fp8: + out_fp8 = torch.nn.functional.linear(input_fp8, weight_fp8, None) + if bias_fp32 is not None: + _ = out_fp8 + bias_fp32 + torch.cuda.synchronize() + + # Clear any warmup artifacts + torch.cuda.empty_cache() + + # Benchmark FP16 + fp16_times = [] + for _ in range(n_iters): + start = time.time() + output_fp16 = torch.nn.functional.linear(input_fp16, weight_fp16, bias_fp16) + torch.cuda.synchronize() + fp16_times.append((time.time() - start) * 1000) + + # Benchmark INT8 (quantized output) + int8_times = [] + for _ in range(n_iters): + start = time.time() + output_int8 = torch.nn.functional.linear(input_int8, weight_int8, bias_fp32) + torch.cuda.synchronize() + int8_times.append((time.time() - start) * 1000) + + # Benchmark INT8 with dequantization + int8_dequant_times = [] + for _ in range(n_iters): + start = time.time() + output_int8 = torch.nn.functional.linear(input_int8, weight_int8, bias_fp32) + if isinstance(output_int8, QuantizedTensor): + output_int8 = output_int8.dequantize() + torch.cuda.synchronize() + int8_dequant_times.append((time.time() - start) * 1000) + + # Benchmark FP8 if available + if has_fp8: + fp8_times = [] + for _ in range(n_iters): + start = time.time() + out_fp8 = torch.nn.functional.linear(input_fp8, weight_fp8, None) + if bias_fp32 is not None: + out_fp8 = out_fp8 + bias_fp32 + # Dequantize if needed + if isinstance(out_fp8, QuantizedTensor): + out_fp8 = out_fp8.dequantize() + torch.cuda.synchronize() + fp8_times.append((time.time() - start) * 1000) + + # Clear benchmark outputs to free memory + if 'output_fp16' in locals(): + del output_fp16 + if 'output_int8' in locals(): + del output_int8 + if has_fp8 and 'out_fp8' in locals(): + del out_fp8 + torch.cuda.empty_cache() + + # Calculate statistics + fp16_times = torch.tensor(fp16_times) + int8_times = torch.tensor(int8_times) + int8_dequant_times = 
torch.tensor(int8_dequant_times) + + fp16_mean = fp16_times.mean().item() + fp16_std = fp16_times.std().item() + fp16_min = fp16_times.min().item() + + int8_mean = int8_times.mean().item() + int8_std = int8_times.std().item() + int8_min = int8_times.min().item() + + int8_dequant_mean = int8_dequant_times.mean().item() + int8_dequant_std = int8_dequant_times.std().item() + int8_dequant_min = int8_dequant_times.min().item() + + speedup_int8 = fp16_mean / int8_mean + speedup_int8_dequant = fp16_mean / int8_dequant_mean + + print(f" FP16: {fp16_mean:.3f}±{fp16_std:.3f} ms (min: {fp16_min:.3f} ms) [{flops/fp16_mean/1e9:.2f} GFLOPS]") + print(f" INT8 (quantized): {int8_mean:.3f}±{int8_std:.3f} ms (min: {int8_min:.3f} ms) [{flops/int8_mean/1e9:.2f} GFLOPS]") + print(f" INT8 (dequantized): {int8_dequant_mean:.3f}±{int8_dequant_std:.3f} ms (min: {int8_dequant_min:.3f} ms) [{flops/int8_dequant_mean/1e9:.2f} GFLOPS]") + print(f" Speedup vs FP16: {speedup_int8:.2f}x (quantized), {speedup_int8_dequant:.2f}x (dequantized)") + + if has_fp8: + fp8_times = torch.tensor(fp8_times) + fp8_mean = fp8_times.mean().item() + fp8_std = fp8_times.std().item() + fp8_min = fp8_times.min().item() + speedup_fp8 = fp16_mean / fp8_mean + + print(f" FP8 (dequantized): {fp8_mean:.3f}±{fp8_std:.3f} ms (min: {fp8_min:.3f} ms) [{flops/fp8_mean/1e9:.2f} GFLOPS]") + print(f" Speedup vs FP16: {speedup_fp8:.2f}x") + else: + fp8_mean = None + speedup_fp8 = None + + # Precision check + output_fp16_check = torch.nn.functional.linear(input_fp16, weight_fp16, bias_fp16) + output_int8_check = torch.nn.functional.linear(input_int8, weight_int8, bias_fp32) + if isinstance(output_int8_check, QuantizedTensor): + output_int8_check = output_int8_check.dequantize() + + # Convert FP16 output to FP32 for comparison + output_fp16_check_fp32 = output_fp16_check.to(torch.float32) + + # Compare INT8 vs FP16 (both in FP32 for fair comparison) + error_int8 = ((output_int8_check - output_fp16_check_fp32).abs() / (output_fp16_check_fp32.abs() + 1e-6)).mean() + print(f" Precision: INT8 vs FP16 mean relative error: {error_int8:.6f}") + + if has_fp8: + output_fp8_check = torch.nn.functional.linear(input_fp8, weight_fp8, None) + if bias_fp32 is not None: + output_fp8_check = output_fp8_check + bias_fp32 + if isinstance(output_fp8_check, QuantizedTensor): + output_fp8_check = output_fp8_check.dequantize() + + error_fp8 = ((output_fp8_check - output_fp16_check_fp32).abs() / (output_fp16_check_fp32.abs() + 1e-6)).mean() + print(f" Precision: FP8 vs FP16 mean relative error: {error_fp8:.6f}") + else: + error_fp8 = None + + results.append({ + "model": model, + "name": name, + "input_shape": input_shape, + "weight_shape": weight_shape, + "fp16_mean": fp16_mean, + "int8_mean": int8_mean, + "int8_dequant_mean": int8_dequant_mean, + "fp8_mean": fp8_mean, + "speedup_int8": speedup_int8, + "speedup_int8_dequant": speedup_int8_dequant, + "speedup_fp8": speedup_fp8, + "mem_reduction": mem_reduction, + "error_int8": error_int8.item(), + "error_fp8": error_fp8.item() if error_fp8 is not None else None, + "fp16_vram": fp16_vram, + "int8_vram": int8_vram_with_fp16, + }) + + # Aggressive memory cleanup after each configuration to avoid OOM + # Delete input/weight tensors + del input_fp32, input_fp16, weight_fp32, weight_fp16, bias_fp32, bias_fp16 + del input_int8, weight_int8 + if has_fp8: + del input_fp8, weight_fp8 + + # Delete precision check outputs + if 'output_fp16_check' in locals(): + del output_fp16_check, output_fp16_check_fp32, output_int8_check + if has_fp8 and 
'output_fp8_check' in locals(): + del output_fp8_check + + # Delete timing tensors + if 'fp16_times' in locals(): + del fp16_times, int8_times, int8_dequant_times + if has_fp8 and 'fp8_times' in locals(): + del fp8_times + + # Force Python garbage collection + gc.collect() + # Clear CUDA cache + torch.cuda.empty_cache() + # Synchronize to ensure cleanup is complete + torch.cuda.synchronize() + + except RuntimeError as e: + if "out of memory" in str(e): + print(f" ⚠ OOM - skipping this configuration") + # Ultra-aggressive cleanup on OOM + # Delete any lingering tensors from failed iteration + for var_name in list(locals().keys()): + if 'tensor' in var_name.lower() or var_name.endswith(('_fp16', '_fp32', '_int8', '_fp8')): + try: + del locals()[var_name] + except: + pass + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + else: + raise + + # Summary table + print("\n" + "=" * 80) + if has_fp8: + print("Summary: FP16 vs INT8 vs FP8 Performance") + else: + print("Summary: FP16 vs INT8 Performance") + print("=" * 80) + + if results: + # Group results by model + models = {} + for result in results: + model = result["model"] + if model not in models: + models[model] = [] + models[model].append(result) + + # Print results grouped by model + for model_name, model_results in models.items(): + print(f"\n{model_name}:") + if has_fp8: + print(f"{'Layer':<25s} {'FP16':<10s} {'INT8':<10s} {'FP8':<10s} {'Speedup':<20s} {'Mem':<8s}") + else: + print(f"{'Layer':<30s} {'FP16 (ms)':<12s} {'INT8 (ms)':<12s} {'Speedup':<10s} {'Memory':<10s}") + print("-" * 80) + + for result in model_results: + layer_name = result["name"][:23] if has_fp8 else result["name"][:28] + if has_fp8 and result['fp8_mean'] is not None: + print(f"{layer_name:<25s} {result['fp16_mean']:>8.3f}ms {result['int8_dequant_mean']:>8.3f}ms {result['fp8_mean']:>8.3f}ms " + f"INT8:{result['speedup_int8_dequant']:>5.2f}x FP8:{result['speedup_fp8']:>5.2f}x {result['mem_reduction']:>6.2f}x") + else: + print(f"{layer_name:<30s} {result['fp16_mean']:>10.3f} {result['int8_dequant_mean']:>10.3f} {result['speedup_int8_dequant']:>8.2f}x {result['mem_reduction']:>8.2f}x") + + # Calculate per-model total + model_fp16_time = sum(r["fp16_mean"] for r in model_results) + model_int8_time = sum(r["int8_dequant_mean"] for r in model_results) + model_speedup_int8 = model_fp16_time / model_int8_time if model_int8_time > 0 else 0 + + print("-" * 80) + if has_fp8 and any(r['fp8_mean'] is not None for r in model_results): + model_fp8_time = sum(r["fp8_mean"] for r in model_results if r["fp8_mean"] is not None) + model_speedup_fp8 = model_fp16_time / model_fp8_time if model_fp8_time > 0 else 0 + print(f"{'SUBTOTAL':<25s} {model_fp16_time:>8.3f}ms {model_int8_time:>8.3f}ms {model_fp8_time:>8.3f}ms " + f"INT8:{model_speedup_int8:>5.2f}x FP8:{model_speedup_fp8:>5.2f}x") + else: + print(f"{'SUBTOTAL':<30s} {model_fp16_time:>10.3f} {model_int8_time:>10.3f} {model_speedup_int8:>8.2f}x") + + print(f" {model_name} avg memory reduction: {sum(r['mem_reduction'] for r in model_results) / len(model_results):.2f}x") + print(f" {model_name} avg INT8 precision error: {sum(r['error_int8'] for r in model_results) / len(model_results):.6f}") + if has_fp8 and any(r['error_fp8'] is not None for r in model_results): + fp8_errors = [r['error_fp8'] for r in model_results if r['error_fp8'] is not None] + if fp8_errors: + print(f" {model_name} avg FP8 precision error: {sum(fp8_errors) / len(fp8_errors):.6f}") + + # VRAM analysis + total_fp16_vram = sum(r['fp16_vram'] for r in 
model_results) + total_int8_vram = sum(r['int8_vram'] for r in model_results) + print(f" {model_name} VRAM usage: FP16 {total_fp16_vram:.2f}MB, INT8 {total_int8_vram:.2f}MB (during inference with both)") + + # Calculate overall totals + total_fp16_time = sum(r["fp16_mean"] for r in results) + total_int8_time = sum(r["int8_dequant_mean"] for r in results) + overall_speedup_int8 = total_fp16_time / total_int8_time if total_int8_time > 0 else 0 + + print("\n" + "=" * 80) + if has_fp8 and any(r['fp8_mean'] is not None for r in results): + total_fp8_time = sum(r["fp8_mean"] for r in results if r["fp8_mean"] is not None) + overall_speedup_fp8 = total_fp16_time / total_fp8_time if total_fp8_time > 0 else 0 + print(f"{'GRAND TOTAL':<25s} {total_fp16_time:>8.3f}ms {total_int8_time:>8.3f}ms {total_fp8_time:>8.3f}ms " + f"INT8:{overall_speedup_int8:>5.2f}x FP8:{overall_speedup_fp8:>5.2f}x") + else: + print(f"{'GRAND TOTAL':<30s} {total_fp16_time:>10.3f} {total_int8_time:>10.3f} {overall_speedup_int8:>8.2f}x") + print("=" * 80) + + print(f"\n✓ Overall INT8 speedup: {overall_speedup_int8:.2f}x faster than FP16") + if has_fp8 and any(r['fp8_mean'] is not None for r in results): + print(f"✓ Overall FP8 speedup: {overall_speedup_fp8:.2f}x faster than FP16") + print(f"✓ Average memory reduction: {sum(r['mem_reduction'] for r in results) / len(results):.2f}x") + print(f"✓ Average INT8 precision error: {sum(r['error_int8'] for r in results) / len(results):.6f}") + if has_fp8: + fp8_errors = [r['error_fp8'] for r in results if r['error_fp8'] is not None] + if fp8_errors: + print(f"✓ Average FP8 precision error: {sum(fp8_errors) / len(fp8_errors):.6f}") + + # Total VRAM + total_fp16_vram = sum(r['fp16_vram'] for r in results) + total_int8_vram = sum(r['int8_vram'] for r in results) + print(f"✓ Total VRAM: FP16 {total_fp16_vram:.2f}MB, INT8 {total_int8_vram:.2f}MB") + + # Assertions for unittest + self.assertGreater(len(results), 0, "Should have collected benchmark results") + self.assertGreater(overall_speedup_int8, 0.5, "INT8 should have reasonable performance") + + @unittest.skip("perf benchmark only") + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_systematic_benchmark(self): + """Comprehensive systematic benchmark across multiple dimensions""" + device = torch.device('cuda') + torch.manual_seed(42) + + n_warmup = 10 + n_iters = 100 + + print(f"\nWarmup iterations: {n_warmup}") + print(f"Benchmark iterations: {n_iters}\n") + + # Test 1: Varying batch size (typical transformer forward pass) + print("=" * 60) + print("Dimension 1: Varying Batch Size") + print("=" * 60) + batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128] + seq_len = 64 + in_features = 1024 + out_features = 1024 + block_size = 128 + + for batch_size in batch_sizes: + try: + input_fp32 = torch.randn(batch_size, seq_len, in_features, dtype=torch.float32, device=device) + weight_fp32 = torch.randn(out_features, in_features, dtype=torch.float32, device=device) + bias = torch.randn(out_features, dtype=torch.float32, device=device) + + input_int8 = QuantizedTensor.from_float(input_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + weight_int8 = QuantizedTensor.from_float(weight_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=True) + + # Warm up + for _ in range(n_warmup): + _ = torch.nn.functional.linear(input_int8, weight_int8, bias) + _ = torch.nn.functional.linear(input_fp32, weight_fp32, bias) + torch.cuda.synchronize() + + # Benchmark + int8_times = [] + for _ in range(n_iters): + start = time.time() 
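+ # timed region: the full torch-dispatch path from QuantizedTensor inputs down to the INT8 GEMM, i.e. the cost a model layer would actually pay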
+ _ = torch.nn.functional.linear(input_int8, weight_int8, bias) + torch.cuda.synchronize() + int8_times.append((time.time() - start) * 1000) + + fp32_times = [] + for _ in range(n_iters): + start = time.time() + _ = torch.nn.functional.linear(input_fp32, weight_fp32, bias) + torch.cuda.synchronize() + fp32_times.append((time.time() - start) * 1000) + + int8_mean = torch.tensor(int8_times).mean().item() + fp32_mean = torch.tensor(fp32_times).mean().item() + speedup = fp32_mean / int8_mean + + m = batch_size * seq_len + k = in_features + n = out_features + flops = 2 * m * n * k + + print(f"Batch={batch_size:3d}: INT8 {int8_mean:.3f}ms, FP32 {fp32_mean:.3f}ms, Speedup: {speedup:.2f}x, [{flops/int8_mean/1e9:.2f} GFLOPS]") + + # Clean up after each test + del input_fp32, weight_fp32, bias, input_int8, weight_int8 + gc.collect() + torch.cuda.empty_cache() + + except RuntimeError as e: + if "out of memory" in str(e): + print(f"Batch={batch_size:3d}: ⚠ OOM") + import gc + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + break + else: + raise + + print() + + # Test 2: Varying sequence length + print("=" * 60) + print("Dimension 2: Varying Sequence Length") + print("=" * 60) + seq_lengths = [8, 16, 32, 64, 128, 256, 512, 1024, 2048] + batch_size = 8 + in_features = 1024 + out_features = 1024 + block_size = 128 + + for seq_len in seq_lengths: + try: + input_fp32 = torch.randn(batch_size, seq_len, in_features, dtype=torch.float32, device=device) + weight_fp32 = torch.randn(out_features, in_features, dtype=torch.float32, device=device) + bias = torch.randn(out_features, dtype=torch.float32, device=device) + + input_int8 = QuantizedTensor.from_float(input_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=False) + weight_int8 = QuantizedTensor.from_float(weight_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=True) + + # Warm up + for _ in range(n_warmup): + _ = torch.nn.functional.linear(input_int8, weight_int8, bias) + _ = torch.nn.functional.linear(input_fp32, weight_fp32, bias) + torch.cuda.synchronize() + + # Benchmark + int8_times = [] + for _ in range(n_iters): + start = time.time() + _ = torch.nn.functional.linear(input_int8, weight_int8, bias) + torch.cuda.synchronize() + int8_times.append((time.time() - start) * 1000) + + fp32_times = [] + for _ in range(n_iters): + start = time.time() + _ = torch.nn.functional.linear(input_fp32, weight_fp32, bias) + torch.cuda.synchronize() + fp32_times.append((time.time() - start) * 1000) + + int8_mean = torch.tensor(int8_times).mean().item() + fp32_mean = torch.tensor(fp32_times).mean().item() + speedup = fp32_mean / int8_mean + + m = batch_size * seq_len + k = in_features + n = out_features + flops = 2 * m * n * k + + print(f"SeqLen={seq_len:4d}: INT8 {int8_mean:.3f}ms, FP32 {fp32_mean:.3f}ms, Speedup: {speedup:.2f}x, [{flops/int8_mean/1e9:.2f} GFLOPS]") + + # Clean up after each test + del input_fp32, weight_fp32, bias, input_int8, weight_int8 + gc.collect() + torch.cuda.empty_cache() + + except RuntimeError as e: + if "out of memory" in str(e): + print(f"SeqLen={seq_len:4d}: ⚠ OOM") + import gc + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + break + else: + raise + + print() + + # Test 3: Varying hidden dimensions + print("=" * 60) + print("Dimension 3: Varying Hidden Dimensions") + print("=" * 60) + hidden_dims = [256, 512, 768, 1024, 1536, 2048, 3072, 4096, 8192] + batch_size = 8 + seq_len = 64 + block_size = 128 + + for hidden_dim in hidden_dims: + try: + input_fp32 = 
+                input_fp32 = torch.randn(batch_size, seq_len, hidden_dim, dtype=torch.float32, device=device)
+                weight_fp32 = torch.randn(hidden_dim, hidden_dim, dtype=torch.float32, device=device)
+                bias = torch.randn(hidden_dim, dtype=torch.float32, device=device)
+
+                input_int8 = QuantizedTensor.from_float(input_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=False)
+                weight_int8 = QuantizedTensor.from_float(weight_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=True)
+
+                # Warm up
+                for _ in range(n_warmup):
+                    _ = torch.nn.functional.linear(input_int8, weight_int8, bias)
+                    _ = torch.nn.functional.linear(input_fp32, weight_fp32, bias)
+                torch.cuda.synchronize()
+
+                # Benchmark
+                int8_times = []
+                for _ in range(n_iters):
+                    start = time.time()
+                    _ = torch.nn.functional.linear(input_int8, weight_int8, bias)
+                    torch.cuda.synchronize()
+                    int8_times.append((time.time() - start) * 1000)
+
+                fp32_times = []
+                for _ in range(n_iters):
+                    start = time.time()
+                    _ = torch.nn.functional.linear(input_fp32, weight_fp32, bias)
+                    torch.cuda.synchronize()
+                    fp32_times.append((time.time() - start) * 1000)
+
+                int8_mean = torch.tensor(int8_times).mean().item()
+                fp32_mean = torch.tensor(fp32_times).mean().item()
+                speedup = fp32_mean / int8_mean
+
+                m = batch_size * seq_len
+                k = hidden_dim
+                n = hidden_dim
+                flops = 2 * m * n * k
+
+                print(f"Hidden={hidden_dim:4d}: INT8 {int8_mean:.3f}ms, FP32 {fp32_mean:.3f}ms, Speedup: {speedup:.2f}x, [{flops/int8_mean/1e9:.2f} TFLOPS]")
+
+                # Clean up after each test
+                del input_fp32, weight_fp32, bias, input_int8, weight_int8
+                gc.collect()
+                torch.cuda.empty_cache()
+
+            except RuntimeError as e:
+                if "out of memory" in str(e):
+                    print(f"Hidden={hidden_dim:4d}: ⚠ OOM")
+                    import gc
+                    gc.collect()
+                    torch.cuda.empty_cache()
+                    torch.cuda.synchronize()
+                    break
+                else:
+                    raise
+
+        print()
+
+        # Test 4: Varying block size
+        print("=" * 60)
+        print("Dimension 4: Varying Block Size")
+        print("=" * 60)
+        block_sizes = [32, 64, 128, 256, 512]
+        batch_size = 8
+        seq_len = 64
+        in_features = 1024
+        out_features = 1024
+
+        for block_size in block_sizes:
+            try:
+                input_fp32 = torch.randn(batch_size, seq_len, in_features, dtype=torch.float32, device=device)
+                weight_fp32 = torch.randn(out_features, in_features, dtype=torch.float32, device=device)
+                bias = torch.randn(out_features, dtype=torch.float32, device=device)
+
+                input_int8 = QuantizedTensor.from_float(input_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=False)
+                weight_int8 = QuantizedTensor.from_float(weight_fp32, "BlockWiseINT8Layout", block_size=block_size, is_weight=True)
+
+                # Warm up
+                for _ in range(n_warmup):
+                    _ = torch.nn.functional.linear(input_int8, weight_int8, bias)
+                torch.cuda.synchronize()
+
+                # Benchmark
+                int8_times = []
+                for _ in range(n_iters):
+                    start = time.time()
+                    _ = torch.nn.functional.linear(input_int8, weight_int8, bias)
+                    torch.cuda.synchronize()
+                    int8_times.append((time.time() - start) * 1000)
+
+                int8_mean = torch.tensor(int8_times).mean().item()
+                int8_std = torch.tensor(int8_times).std().item()
+
+                m = batch_size * seq_len
+                k = in_features
+                n = out_features
+                flops = 2 * m * n * k
+
+                print(f"Block={block_size:3d}: INT8 {int8_mean:.3f}±{int8_std:.3f}ms, [{flops/int8_mean/1e9:.2f} TFLOPS]")
+
+                # Clean up after each test
+                del input_fp32, weight_fp32, bias, input_int8, weight_int8
+                gc.collect()
+                torch.cuda.empty_cache()
+
+            except RuntimeError as e:
+                if "out of memory" in str(e):
+                    print(f"Block={block_size:3d}: ⚠ OOM")
+                    import gc
+                    gc.collect()
+                    torch.cuda.empty_cache()
+                    torch.cuda.synchronize()
+                    break
+                else:
+                    raise
+
+        print()
+        print("✓ Systematic benchmark completed!")
+
+    @unittest.skip("perf benchmark only")
+    @unittest.skipUnless(has_gpu(), "GPU not available")
+    def test_gelu_benchmark(self):
+        """Benchmark INT8 GELU vs FP16 GELU"""
+        # See test_int8_gelu.py::benchmark_int8_gelu for full implementation
+        self.skipTest("Benchmark test - run separately")
+
+    @unittest.skip("perf benchmark only")
+    @unittest.skipUnless(has_gpu(), "GPU not available")
+    def test_gelu_systematic_benchmark(self):
+        """Systematic GELU benchmark across different dimensions"""
+        # See test_int8_gelu.py::benchmark_int8_gelu_systematic for full implementation
+        self.skipTest("Benchmark test - run separately")
+
+    @unittest.skip("perf benchmark only")
+    @unittest.skipUnless(has_gpu(), "GPU not available")
+    def test_gelu_real_model_sizes(self):
+        """Test FP16 vs INT8 GELU on actual model sizes"""
+        # See test_int8_gelu.py::test_fp16_vs_int8_real_model_sizes for full implementation
+        self.skipTest("Benchmark test - run separately")
+
+    @unittest.skip("perf benchmark only")
+    @unittest.skipUnless(has_gpu(), "GPU not available")
+    def test_quant_fusion_performance(self):
+        """Compare performance of fused vs separate quantization"""
+        # See test_int8_quant_fusion.py::test_performance_comparison for full implementation
+        self.skipTest("Benchmark test - run separately")
+
+if __name__ == "__main__":
+    unittest.main()
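
A side note on the timing methodology, not part of the diff: the benchmarks above time each call with `time.time()` plus a per-iteration `torch.cuda.synchronize()`, which also counts host-side launch and sync overhead. CUDA events measure only the GPU interval. A minimal sketch using standard PyTorch APIs (the helper name `time_cuda_ms` is illustrative, not from this PR):

```python
import torch

def time_cuda_ms(fn, n_iters: int = 100) -> float:
    """Return the mean GPU time per call of `fn`, in milliseconds, using CUDA events."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()          # make sure prior work is done before timing
    start.record()
    for _ in range(n_iters):
        fn()
    end.record()
    torch.cuda.synchronize()          # wait for the recorded work to finish
    return start.elapsed_time(end) / n_iters  # elapsed_time() reports milliseconds
```

With the tensors from Test 1, something like `time_cuda_ms(lambda: torch.nn.functional.linear(input_int8, weight_int8, bias))` would give a per-call mean that excludes Python timer jitter; the wall-clock approach in the tests remains useful when end-to-end latency (including launch overhead) is what matters.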