@@ -113,8 +113,13 @@ while [ -v 1 ]; do
113113 TRITON_TEST_SKIPLIST_DIR=" $( mkdir -p " $2 " && cd " $2 " && pwd) "
114114 shift 2
115115 ;;
116+ --select-from-file)
117+ # Must be absolute
118+ TRITON_TEST_SELECTFILE=" $( realpath " $2 " ) "
119+ shift 2
120+ ;;
116121 --help)
117- err " Example usage: ./test-triton.sh [--core | --tutorial | --unit | --microbench | --softmax | --gemm | --attention | --venv | --skip-pip-install | --skip-pytorch-install | --reports | --reports-dir DIR | --warning-reports | --ignore-errors | --skip-list SKIPLIST"
122+ err " Example usage: ./test-triton.sh [--core | --tutorial | --unit | --microbench | --softmax | --gemm | --attention | --venv | --skip-pip-install | --skip-pytorch-install | --reports | --reports-dir DIR | --warning-reports | --ignore-errors | --skip-list SKIPLIST | --select-from-file SELECTFILE "
118123 ;;
119124 * )
120125 err " Unknown argument: $1 ."
@@ -181,6 +186,14 @@ run_unit_tests() {
181186 lit -v . || $TRITON_TEST_IGNORE_ERRORS
182187}
183188
189+ run_pytest_command () {
190+ if [[ -n " $TRITON_TEST_SELECTFILE " ]]; then
191+ pytest " $@ " --collect-only > /dev/null 2>&1 && pytest " $@ " || true
192+ else
193+ pytest " $@ "
194+ fi
195+ }
196+
184197run_core_tests () {
185198 echo " ***************************************************"
186199 echo " ****** Running Triton Core tests ******"
@@ -189,31 +202,31 @@ run_core_tests() {
189202 ensure_spirv_dis
190203
191204 TRITON_DISABLE_LINE_INFO=1 TRITON_TEST_SUITE=language \
192- pytest -vvv -n ${PYTEST_MAX_PROCESSES:- 8} --device xpu language/ --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=language/test_warp_specialization.py
205+ run_pytest_command -vvv -n ${PYTEST_MAX_PROCESSES:- 8} --device xpu language/ --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=language/test_warp_specialization.py
193206
194207 TRITON_DISABLE_LINE_INFO=1 TRITON_TEST_SUITE=subprocess \
195- pytest -vvv -n ${PYTEST_MAX_PROCESSES:- 8} --device xpu language/test_subprocess.py
208+ run_pytest_command -vvv -n ${PYTEST_MAX_PROCESSES:- 8} --device xpu language/test_subprocess.py
196209
197210 # run runtime tests serially to avoid race condition with cache handling.
198211 TRITON_DISABLE_LINE_INFO=1 TRITON_TEST_SUITE=runtime \
199- pytest -k " not test_within_2gb" --verbose --device xpu runtime/ --ignore=runtime/test_cublas.py
212+ run_pytest_command -k " not test_within_2gb" --verbose --device xpu runtime/ --ignore=runtime/test_cublas.py
200213
201214 TRITON_TEST_SUITE=debug \
202- pytest --verbose -n ${PYTEST_MAX_PROCESSES:- 8} test_debug.py --forked --device xpu
215+ run_pytest_command --verbose -n ${PYTEST_MAX_PROCESSES:- 8} test_debug.py --forked --device xpu
203216
204217 # run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
205218 TRITON_DISABLE_LINE_INFO=0 TRITON_TEST_SUITE=line_info \
206- pytest -k " not test_line_info_interpreter" --verbose --device xpu language/test_line_info.py
219+ run_pytest_command -k " not test_line_info_interpreter" --verbose --device xpu language/test_line_info.py
207220
208221 TRITON_DISABLE_LINE_INFO=1 TRITON_TEST_SUITE=tools \
209- pytest -k " not test_disam_cubin" --verbose tools
222+ run_pytest_command -k " not test_disam_cubin" --verbose tools
210223
211224 TRITON_DISABLE_LINE_INFO=1 TRITON_TEST_SUITE=intel \
212- pytest -vvv -n ${PYTEST_MAX_PROCESSES:- 8} --device xpu intel/
225+ run_pytest_command -vvv -n ${PYTEST_MAX_PROCESSES:- 8} --device xpu intel/
213226
214227 cd $TRITON_PROJ /third_party/intel/python/test
215228 TRITON_DISABLE_LINE_INFO=1 TRITON_TEST_SUITE=third_party \
216- pytest --device xpu .
229+ run_pytest_command --device xpu .
217230}
218231
219232run_regression_tests () {
@@ -223,7 +236,7 @@ run_regression_tests() {
223236 cd $TRITON_PROJ /python/test/regression
224237
225238 TRITON_DISABLE_LINE_INFO=1 TRITON_TEST_SUITE=regression \
226- pytest -vvv -s --device xpu . --ignore=test_performance.py
239+ run_pytest_command -vvv -s --device xpu . --ignore=test_performance.py
227240}
228241
229242run_interpreter_tests () {
@@ -233,7 +246,7 @@ run_interpreter_tests() {
233246 cd $TRITON_PROJ /python/test/unit
234247
235248 TRITON_INTERPRET=1 TRITON_TEST_SUITE=interpreter \
236- pytest -vvv -n ${PYTEST_MAX_PROCESSES:- 16} -m interpreter language/test_core.py language/test_standard.py \
249+ run_pytest_command -vvv -n ${PYTEST_MAX_PROCESSES:- 16} -m interpreter language/test_core.py language/test_standard.py \
237250 language/test_random.py language/test_line_info.py --device cpu
238251}
239252
@@ -347,7 +360,7 @@ run_instrumentation_tests() {
347360
348361 TRITON_TEST_SUITE=instrumentation \
349362 TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${INSTRUMENTATION_LIB_NAME} \
350- pytest -vvv --device xpu instrumentation/test_gpuhello.py
363+ run_pytest_command -vvv --device xpu instrumentation/test_gpuhello.py
351364}
352365
353366run_inductor_tests () {
0 commit comments