Skip to content

Commit 697283d

Browse files
authored
Add CreateQuantPybindModule into orttraining (microsoft#26401)
### Description As titled ### Motivation and Context As titled
1 parent 1f45838 commit 697283d

File tree

7 files changed

+5
-71
lines changed

7 files changed

+5
-71
lines changed

onnxruntime/test/python/quantization/test_op_matmul_2bits.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import tempfile
99
import unittest
10-
from importlib.util import find_spec
1110
from pathlib import Path
1211

1312
import numpy as np
@@ -205,9 +204,6 @@ def quant_test(
205204
else:
206205
raise exception
207206

208-
@unittest.skipIf(
209-
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_2bits"
210-
)
211207
def test_quantize_matmul_int2_symmetric(self):
212208
np.random.seed(13)
213209

@@ -216,9 +212,6 @@ def test_quantize_matmul_int2_symmetric(self):
216212
data_reader = self.input_feeds(1, {"input": (100, 52)})
217213
self.quant_test(model_fp32_path, data_reader, 32, True, rtol=0.02, atol=0.1)
218214

219-
@unittest.skipIf(
220-
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_2bits"
221-
)
222215
def test_quantize_matmul_int2_offsets(self):
223216
model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute())
224217
self.construct_model_matmul(model_fp32_path, symmetric=False)

onnxruntime/test/python/quantization/test_op_matmul_4bits.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -295,9 +295,6 @@ def quant_test_with_algo(
295295
else:
296296
raise exception
297297

298-
@unittest.skipIf(
299-
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits"
300-
)
301298
def test_quantize_matmul_int4_symmetric(self):
302299
np.random.seed(13)
303300

@@ -306,18 +303,12 @@ def test_quantize_matmul_int4_symmetric(self):
306303
data_reader = self.input_feeds(1, {"input": (100, 52)})
307304
self.quant_test(model_fp32_path, data_reader, 32, True)
308305

309-
@unittest.skipIf(
310-
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits"
311-
)
312306
def test_quantize_matmul_int4_offsets(self):
313307
model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute())
314308
self.construct_model_matmul(model_fp32_path, symmetric=False)
315309
data_reader = self.input_feeds(1, {"input": (100, 52)})
316310
self.quant_test(model_fp32_path, data_reader, 32, False)
317311

318-
@unittest.skipIf(
319-
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits"
320-
)
321312
def test_quantize_gather_int4_symmetric(self):
322313
np.random.seed(13)
323314

@@ -327,19 +318,13 @@ def test_quantize_gather_int4_symmetric(self):
327318
# cover rounding error
328319
self.quant_test(model_fp32_path, data_reader, 32, True, op_types_to_quantize=("Gather",), rtol=0.2, atol=0.5)
329320

330-
@unittest.skipIf(
331-
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits"
332-
)
333321
def test_quantize_gather_int4_offsets(self):
334322
model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("gather_fp32_offset.onnx").absolute())
335323
self.construct_model_gather(model_fp32_path, False, TensorProto.FLOAT16, TensorProto.INT64)
336324
data_reader = self.input_feeds(1, {"input": (100, 1000)}, -545, 535, np.int64)
337325
# cover rounding error
338326
self.quant_test(model_fp32_path, data_reader, 32, False, op_types_to_quantize=("Gather",), rtol=0.2, atol=0.5)
339327

340-
@unittest.skipIf(
341-
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits"
342-
)
343328
def test_quantize_matmul_int4_symmetric_qdq(self):
344329
np.random.seed(13)
345330

@@ -348,18 +333,12 @@ def test_quantize_matmul_int4_symmetric_qdq(self):
348333
data_reader = self.input_feeds(1, {"input": (100, 52)})
349334
self.quant_test(model_fp32_path, data_reader, 32, True, quant_utils.QuantFormat.QDQ)
350335

351-
@unittest.skipIf(
352-
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits"
353-
)
354336
def test_quantize_matmul_int4_offsets_qdq(self):
355337
model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute())
356338
self.construct_model_matmul(model_fp32_path, symmetric=False)
357339
data_reader = self.input_feeds(1, {"input": (100, 52)})
358340
self.quant_test(model_fp32_path, data_reader, 32, False, quant_utils.QuantFormat.QDQ)
359341

360-
@unittest.skipIf(
361-
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits"
362-
)
363342
def test_quantize_matmul_int4_using_rtn_algo(self):
364343
if not find_spec("neural_compressor"):
365344
self.skipTest("skip test_smooth_quant since neural_compressor is not installed")
@@ -370,9 +349,6 @@ def test_quantize_matmul_int4_using_rtn_algo(self):
370349
data_reader = self.input_feeds(1, {"input": (100, 52)})
371350
self.quant_test_with_algo("RTN", model_fp32_path, data_reader, 32, False)
372351

373-
@unittest.skipIf(
374-
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits"
375-
)
376352
def test_quantize_matmul_int4_using_gptq_algo(self):
377353
if not find_spec("neural_compressor"):
378354
self.skipTest("skip test_smooth_quant since neural_compressor is not installed")
@@ -383,9 +359,6 @@ def test_quantize_matmul_int4_using_gptq_algo(self):
383359
data_reader = self.input_feeds(1, {"input": (100, 52)})
384360
self.quant_test_with_algo("GPTQ", model_fp32_path, data_reader, 32, False)
385361

386-
@unittest.skipIf(
387-
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits"
388-
)
389362
def test_quantize_matmul_int4_using_hqq_algo(self):
390363
if not find_spec("torch"):
391364
self.skipTest("skip test_hqq_quant since torch is not installed")

onnxruntime/test/python/quantization/test_op_matmul_bnb4.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import tempfile
99
import unittest
10-
from importlib.util import find_spec
1110
from pathlib import Path
1211

1312
import numpy as np
@@ -166,16 +165,10 @@ def quant_test(self, quant_type: int, block_size: int):
166165
except Exception as exception:
167166
raise exception
168167

169-
@unittest.skipIf(
170-
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_bnb4"
171-
)
172168
def test_quantize_matmul_bnb4_fp4(self):
173169
np.random.seed(13)
174170
self.quant_test(0, 64)
175171

176-
@unittest.skipIf(
177-
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_bnb4"
178-
)
179172
def test_quantize_matmul_bnb4_nf4(self):
180173
np.random.seed(13)
181174
self.quant_test(1, 64)

onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
# --------------------------------------------------------------------------
77

88
import unittest
9-
from importlib.util import find_spec
109

1110
import numpy as np
1211
import numpy.typing as npt
@@ -99,9 +98,6 @@ def quantize_blockwise_4bits_target(matrix_float: npt.ArrayLike, block_size: int
9998

10099

101100
class TestQuantizeBlockwise4Bits(unittest.TestCase):
102-
@unittest.skipIf(
103-
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits"
104-
)
105101
def test_quantize_blockwise_4bits(self):
106102
for rows, cols in [(128, 128), (32, 128), (128, 32), (52, 128), (128, 52), (73, 123)]:
107103
for block_size in [16, 32, 64, 128]:

onnxruntime/test/python/quantization/test_quantizeblockwise_bnb4.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
# --------------------------------------------------------------------------
77

88
import unittest
9-
from importlib.util import find_spec
109

1110
import numpy as np
1211
import numpy.typing as npt
@@ -120,9 +119,6 @@ def quantize_blockwise_bnb4_target(matrix_float: npt.ArrayLike, block_size: int,
120119

121120

122121
class TestQuantizeBlockwiseBnb4(unittest.TestCase):
123-
@unittest.skipIf(
124-
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_bnb4"
125-
)
126122
def test_quantize_blockwise_bnb4(self):
127123
for quant_type in ["FP4", "NF4"]:
128124
for k, n in [(128, 128), (32, 128), (128, 32), (52, 128), (128, 52), (73, 123)]:

onnxruntime/test/python/transformers/test_generation.py

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import os
99
import shutil
1010
import unittest
11-
from importlib.util import find_spec
1211

1312
import onnx
1413
import pytest
@@ -21,16 +20,12 @@
2120
from benchmark_helper import Precision
2221
from convert_generation import main as run
2322
from models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models
24-
25-
if not find_spec("onnxruntime.training"):
26-
from models.whisper.convert_to_onnx import main as run_whisper
23+
from models.whisper.convert_to_onnx import main as run_whisper
2724
else:
2825
from onnxruntime.transformers.benchmark_helper import Precision
2926
from onnxruntime.transformers.convert_generation import main as run
3027
from onnxruntime.transformers.models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models
31-
32-
if not find_spec("onnxruntime.training"):
33-
from onnxruntime.transformers.models.whisper.convert_to_onnx import main as run_whisper
28+
from onnxruntime.transformers.models.whisper.convert_to_onnx import main as run_whisper
3429

3530

3631
def has_cuda_environment():
@@ -514,33 +509,21 @@ def run_configs(self, optional_arguments):
514509
if "--model_impl" not in arguments:
515510
self.run_export(arguments)
516511

517-
@unittest.skipIf(
518-
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_2bits"
519-
)
520512
@pytest.mark.slow
521513
def test_required_args(self):
522514
optional_args = []
523515
self.run_configs(optional_args)
524516

525-
@unittest.skipIf(
526-
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_2bits"
527-
)
528517
@pytest.mark.slow
529518
def test_forced_decoder_ids(self):
530519
decoder_input_ids = ["--use_forced_decoder_ids"]
531520
self.run_configs(decoder_input_ids)
532521

533-
@unittest.skipIf(
534-
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_2bits"
535-
)
536522
@pytest.mark.slow
537523
def test_logits_processor(self):
538524
logits_processor = ["--use_logits_processor"]
539525
self.run_configs(logits_processor)
540526

541-
@unittest.skipIf(
542-
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_2bits"
543-
)
544527
@pytest.mark.slow
545528
def test_cross_qk_overall(self):
546529
cross_qk_input_args = [
@@ -557,9 +540,6 @@ def test_cross_qk_overall(self):
557540
]
558541
self.run_configs(cross_qk_input_args + cross_qk_output_args)
559542

560-
@unittest.skipIf(
561-
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_2bits"
562-
)
563543
@pytest.mark.slow
564544
def test_openai_impl_whisper(self):
565545
optional_args = ["--model_impl", "openai"]

orttraining/orttraining/python/orttraining_python_module.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,8 @@ Status CreateTrainingPybindStateModule(py::module& m) {
300300
return Status::OK();
301301
}
302302

303+
void CreateQuantPybindModule(py::module& m);
304+
303305
PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
304306
auto st = CreateTrainingPybindStateModule(m);
305307
if (!st.IsOK())
@@ -332,6 +334,7 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
332334
"Clean the execution provider instances used in ort training module.");
333335

334336
m.def("has_collective_ops", []() -> bool { return HAS_COLLECTIVE_OPS; });
337+
CreateQuantPybindModule(m);
335338
}
336339

337340
} // namespace python

0 commit comments

Comments
 (0)