Skip to content
This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit f978bcf

Browse files
qbits support torch version compatibility check (#1607)
* qbits support torch version compatibility check * [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 6042826 commit f978bcf

File tree

5 files changed

+35
-24
lines changed

5 files changed

+35
-24
lines changed

docs/qbits.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,8 @@ If user wants to use QBits, the Pytorch version must meet ITREX requirements, he
7474
| v1.4 | 2.2.0+cpu |
7575
| v1.4.1 | 2.2.0+cpu |
7676
| v1.4.2 | 2.3.0+cpu |
77+
78+
Users can also check whether the current torch version is compatible with QBits by using the `check_torch_compatibility` function provided by QBits.
79+
```python
80+
assert qbits.check_torch_compatibility(str(torch.__version__))
81+
```

intel_extension_for_transformers/qbits/CMakeLists.txt

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,27 @@ project(qbits_py LANGUAGES C CXX)
1616

1717

1818
set(QBITS_TORCH_PATH "" CACHE STRING "Torch install path")
19+
set(torch_info "")
20+
21+
function(get_torch_info python_command)
22+
set(import_torch "import torch:")
23+
string(REPLACE ":" ";" import_torch ${import_torch})
24+
string(CONCAT fin_command "${import_torch}" "${python_command}")
25+
execute_process(COMMAND python -c "${fin_command}"
26+
OUTPUT_VARIABLE torch_info
27+
OUTPUT_STRIP_TRAILING_WHITESPACE)
28+
set(torch_info "${torch_info}" PARENT_SCOPE)
29+
endfunction()
30+
1931

2032
if(QBITS_TORCH_PATH)
2133
set(torch_path ${QBITS_TORCH_PATH})
2234
unset(TORCH_LIBRARY CACHE) # force find_package torch
2335
unset(c10_LIBRARY CACHE)
2436
unset(TORCH_DIR CACHE)
2537
else()
26-
execute_process(COMMAND python -c "import torch; print(torch.__path__[0])"
27-
OUTPUT_VARIABLE torch_path
28-
OUTPUT_STRIP_TRAILING_WHITESPACE)
38+
get_torch_info("print(torch.__path__[0])")
39+
set(torch_path "${torch_info}")
2940
endif()
3041

3142
find_package(Torch REQUIRED
@@ -48,6 +59,10 @@ add_compile_options(-flto=auto)
4859

4960
# Link against LibTorch
5061
pybind11_add_module(qbits_py ${qbits_src})
62+
get_torch_info("print(torch.__version__)")
63+
set(torch_version "${torch_info}")
5164
target_compile_features(qbits_py PRIVATE cxx_std_14)
65+
set(TORCH_VERSION_MACRO COMPATIBLE_TORCH_VERSION="${torch_version}")
66+
target_compile_definitions(qbits_py PUBLIC ${TORCH_VERSION_MACRO})
5267
target_link_directories(qbits_py PRIVATE ${torch_path}/lib)
5368
target_link_libraries(qbits_py PRIVATE bestla_dispatcher torch_python)

intel_extension_for_transformers/qbits/qbits.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ static void woq_linear(const torch::Tensor& activation, const torch::Tensor& wei
114114
torch::Tensor& output, const std::string& compute_type, const std::string& weight_type,
115115
const std::string& scale_type, bool asym) {
116116
woq::woq_config_param p;
117-
118117
torch::Tensor bias_fp32;
119118
torch::Tensor* rt_bias = bias.numel() == 0 ? &output : const_cast<torch::Tensor*>(&bias);
120119
if (bias.scalar_type() != torch::kFloat32 && bias.numel() != 0) {
@@ -180,6 +179,16 @@ static bool check_isa_supported(std::string isa) {
180179
return false;
181180
}
182181

182+
static bool check_torch_compatibility(std::string version) {
183+
static std::string expected_version = COMPATIBLE_TORCH_VERSION;
184+
if (version == expected_version) {
185+
return true;
186+
}
187+
TORCH_CHECK(false,
188+
"QBits: Detected non QBits compiled version Torch, expected" + expected_version + ", but got " + version);
189+
return false;
190+
}
191+
183192
PYBIND11_MODULE(qbits_py, m) {
184193
m.def("quantize_to_packed_weight", &quantize_to_packed_weight);
185194
m.def("woq_linear", &woq_linear);
@@ -193,4 +202,5 @@ PYBIND11_MODULE(qbits_py, m) {
193202
m.def("dropout_fwd", &qbits_dropout_fwd);
194203
m.def("dropout_bwd", &qbits_dropout_bwd);
195204
m.def("check_isa_supported", &check_isa_supported);
205+
m.def("check_torch_compatibility", &check_torch_compatibility);
196206
}

intel_extension_for_transformers/qbits/qbits_ut/test_weightonly.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
@pytest.mark.parametrize("src_dt", ["fp32", "bf16"])
4242
@pytest.mark.parametrize("dst_dt", ["fp32", "bf16"])
4343
def test(m, n, k, blocksize, compute_type, weight_type, scale_type, asym, transpose, add_bias, src_dt, dst_dt, dump_tensor_info=True):
44+
assert qbits.check_torch_compatibility(str(torch.__version__))
4445
if compute_type == "int8" and weight_type == "int8" and (not qbits.check_isa_supported("AVX_VNNI")):
4546
pytest.skip()
4647
if compute_type not in cmpt_configs[weight_type] or scale_type not in scale_configs[weight_type]:

intel_extension_for_transformers/qbits/run_build.sh

Lines changed: 0 additions & 20 deletions
This file was deleted.

0 commit comments

Comments
 (0)