|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | 5 | import os |
| 6 | +import re |
6 | 7 | import sys |
7 | 8 | from pathlib import Path |
8 | 9 |
|
@@ -345,3 +346,59 @@ def test_squeue_output(self) -> None: |
345 | 346 | assert job.user != "" |
346 | 347 | assert job.status in ["R", "PD"] |
347 | 348 | assert job.partition != "" |
| 349 | + |
| 350 | + |
| 351 | +class TestGresParsingWithSocket: |
| 352 | + """Test GRES parsing with socket information (PR #12).""" |
| 353 | + |
| 354 | + def test_parse_gres_with_socket_info(self) -> None: |
| 355 | + """Test parsing GRES string with socket information like 'gpu:8(S:0-1)'.""" |
| 356 | + # Test case from PR #12 - GPU count with socket information |
| 357 | + gres_with_socket = "gpu:8(S:0-1)" |
| 358 | + expected_gpu_count = 8 |
| 359 | + # Remove socket information using the regex from PR #12 |
| 360 | + cleaned_gres = re.sub(r"\(S:[0-9-]+\)", "", gres_with_socket) |
| 361 | + gpu_parts = cleaned_gres.split(":") |
| 362 | + |
| 363 | + # Verify the GPU count is parsed correctly |
| 364 | + assert gpu_parts[-1] == str(expected_gpu_count) |
| 365 | + assert int(gpu_parts[-1]) == expected_gpu_count |
| 366 | + |
| 367 | + def test_parse_gres_without_socket(self) -> None: |
| 368 | + """Test that regular GRES strings still work.""" |
| 369 | + # Regular format without socket info |
| 370 | + gres_regular = "gpu:4" |
| 371 | + expected_gpu_count = 4 |
| 372 | + cleaned_gres = re.sub(r"\(S:[0-9-]+\)", "", gres_regular) |
| 373 | + gpu_parts = cleaned_gres.split(":") |
| 374 | + |
| 375 | + assert gpu_parts[-1] == str(expected_gpu_count) |
| 376 | + assert int(gpu_parts[-1]) == expected_gpu_count |
| 377 | + |
| 378 | + def test_parse_gres_with_model_and_socket(self) -> None: |
| 379 | + """Test GRES with GPU model and socket info.""" |
| 380 | + # Format with GPU model and socket info |
| 381 | + gres_with_model = "gpu:v100:8(S:0-1)" |
| 382 | + expected_gpu_count = 8 |
| 383 | + cleaned_gres = re.sub(r"\(S:[0-9-]+\)", "", gres_with_model) |
| 384 | + gpu_parts = cleaned_gres.split(":") |
| 385 | + |
| 386 | + # The GPU count is still the last part |
| 387 | + assert gpu_parts[-1] == str(expected_gpu_count) |
| 388 | + assert int(gpu_parts[-1]) == expected_gpu_count |
| 389 | + assert gpu_parts[1] == "v100" # Model name preserved |
| 390 | + |
| 391 | + def test_parse_gres_multiple_sockets(self) -> None: |
| 392 | + """Test GRES with different socket patterns.""" |
| 393 | + # Different socket patterns that might appear |
| 394 | + test_cases = [ |
| 395 | + ("gpu:16(S:0-3)", 16), |
| 396 | + ("gpu:a100:4(S:0)", 4), |
| 397 | + ("gpu:2(S:1)", 2), |
| 398 | + ("gpu:v100:32(S:0-7)", 32), |
| 399 | + ] |
| 400 | + |
| 401 | + for gres, expected_count in test_cases: |
| 402 | + cleaned_gres = re.sub(r"\(S:[0-9-]+\)", "", gres) |
| 403 | + gpu_parts = cleaned_gres.split(":") |
| 404 | + assert int(gpu_parts[-1]) == expected_count |
0 commit comments