Skip to content

Commit f303f4c

Browse files
authored
Only convert to int4 preshuffled tensor in H100 (#3245)
Summary: A minor fix for `convert_to_packed_tensor_based_on_current_hardware` to only convert the Int4Tensor to Int4PreshuffledTensor when we are on H100 GPU Test Plan: pytest test/prototype/test_tensor_conversion.py Reviewers: Subscribers: Tasks: Tags:
1 parent e9c7bea commit f303f4c

File tree

3 files changed

+17
-8
lines changed

3 files changed

+17
-8
lines changed

test/prototype/test_tensor_conversion.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,10 @@
3434
_is_kernel_library_loaded,
3535
)
3636
from torchao.quantization.utils import compute_error
37-
from torchao.utils import _is_fbgemm_gpu_genai_available
37+
from torchao.utils import (
38+
_is_fbgemm_gpu_genai_available,
39+
is_sm_at_least_90,
40+
)
3841

3942

4043
class ToyLinearModelWithTiedEmbedding(torch.nn.Module):
@@ -206,5 +209,9 @@ def test_int4_tensor_conversion():
206209
convert_to_packed_tensor_based_on_current_hardware(weight), requires_grad=False
207210
)
208211
after_conversion = m(*example_inputs)
209-
assert isinstance(m[0].weight, Int4PreshuffledTensor)
212+
if is_sm_at_least_90():
213+
assert isinstance(m[0].weight, Int4PreshuffledTensor)
214+
else:
215+
assert isinstance(m[0].weight, Int4Tensor)
216+
210217
assert torch.equal(before_conversion, after_conversion)

torchao/prototype/tensor_conversion/api.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414
Int4Tensor,
1515
IntxUnpackedToInt8Tensor,
1616
)
17-
from torchao.utils import TorchAOBaseTensor, _is_fbgemm_gpu_genai_available
17+
from torchao.utils import (
18+
TorchAOBaseTensor,
19+
_is_fbgemm_gpu_genai_available,
20+
is_sm_at_least_90,
21+
)
1822

1923

2024
def _convert_linear_weight_to_int8_lut_tensor(module):
@@ -187,6 +191,7 @@ def convert_to_packed_tensor_based_on_current_hardware(tensor: TorchAOBaseTensor
187191
isinstance(tensor, Int4Tensor)
188192
and is_device("cuda", tensor.device)
189193
and _is_fbgemm_gpu_genai_available()
194+
and is_sm_at_least_90()
190195
):
191196
return Int4PreshuffledTensor.from_int4_tensor(tensor)
192197
return tensor

torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55
# LICENSE file in the root directory of this source tree.
66

77

8-
import importlib.util
98
from typing import List, Optional
109

1110
import torch
1211

1312
from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor
1413
from torchao.utils import (
1514
TorchAOBaseTensor,
15+
_is_fbgemm_gpu_genai_available,
1616
)
1717

1818
__all__ = [
@@ -22,10 +22,7 @@
2222
aten = torch.ops.aten
2323

2424

25-
if (
26-
importlib.util.find_spec("fbgemm_gpu") is None
27-
or importlib.util.find_spec("fbgemm_gpu.experimental") is None
28-
):
25+
if not _is_fbgemm_gpu_genai_available():
2926
quantize_int4_preshuffle = None
3027
quantize_fp8_row = None
3128
pack_int4 = None

0 commit comments

Comments
 (0)