Skip to content

Commit 564e0f7

Browse files
fix: gpu tasks are unset if <= 0 (#347)
Previously, `--ntasks-per-gpu` was set despite being < 1. This is a fix for #342.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

- Bug Fixes
  - Simplified task flag behavior: task-related SBATCH flags (--ntasks / --ntasks-per-gpu) are no longer conditionally injected in the previous path, producing consistent resource sizing for GPU and non‑GPU jobs.
  - Unset or falsy task values now reliably default to 1, preventing empty or invalid task flags.
  - CPU-related settings are now included in the generated SBATCH command.
- Tests
  - Added tests validating presence or omission of --ntasks and --ntasks-per-gpu across GPU/non‑GPU and unset scenarios.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
1 parent adcd86e commit 564e0f7

File tree

3 files changed

+125
-11
lines changed

3 files changed

+125
-11
lines changed

snakemake_executor_plugin_slurm/__init__.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -328,13 +328,6 @@ def run_job(self, job: JobExecutorInterface):
328328
"- submitting without. This might or might not work on your cluster."
329329
)
330330

331-
# fixes #40 - set ntasks regardless of mpi, because
332-
# SLURM v22.05 introduced the requirement for all jobs
333-
gpu_job = job.resources.get("gpu") or "gpu" in job.resources.get("gres", "")
334-
if gpu_job:
335-
call += f" --ntasks-per-gpu={job.resources.get('tasks', 1)}"
336-
else:
337-
call += f" --ntasks={job.resources.get('tasks', 1)}"
338331
# MPI job
339332
if job.resources.get("mpi", False):
340333
if not job.resources.get("tasks_per_node") and not job.resources.get(

snakemake_executor_plugin_slurm/submit_string.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,15 +55,18 @@ def get_submit_command(job, params):
5555
# fixes #316 - allow unsetting of tasks per gpu
5656
# apparently, python's internal process management interferes with SLURM
5757
# e.g. for pytorch
58-
ntasks_per_gpu = job.resources.get(
59-
"tasks_per_gpu", job.resources.get("tasks", 1)
60-
)
58+
ntasks_per_gpu = job.resources.get("tasks_per_gpu")
59+
if ntasks_per_gpu is None:
60+
ntasks_per_gpu = job.resources.get("tasks")
61+
if ntasks_per_gpu is None:
62+
ntasks_per_gpu = 1
63+
6164
if ntasks_per_gpu >= 1:
6265
call += f" --ntasks-per-gpu={ntasks_per_gpu}"
6366
else:
6467
# fixes #40 - set ntasks regardless of mpi, because
6568
# SLURM v22.05 will require it for all jobs
66-
call += f" --ntasks={job.resources.get('tasks', 1)}"
69+
call += f" --ntasks={job.resources.get('tasks') or 1}"
6770

6871
# we need to set cpus-per-task OR cpus-per-gpu, the function
6972
# will return a string with the corresponding value

tests/tests.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,124 @@ def test_empty_qos(self, mock_job):
520520
# Assert the qos is included (even if empty)
521521
assert "--qos=''" in get_submit_command(job, params)
522522

523+
def test_taks(self, mock_job):
524+
"""Test that tasks are correctly included in the sbatch command."""
525+
# Create a job with tasks
526+
job = mock_job(tasks=4)
527+
params = {
528+
"run_uuid": "test_run",
529+
"slurm_logfile": "test_logfile",
530+
"comment_str": "test_comment",
531+
"account": None,
532+
"partition": None,
533+
"workdir": ".",
534+
"tasks": 4,
535+
}
536+
537+
# Patch subprocess.Popen to capture the sbatch command
538+
with patch("subprocess.Popen") as mock_popen:
539+
# Configure the mock to return successful submission
540+
process_mock = MagicMock()
541+
process_mock.communicate.return_value = ("123", "")
542+
process_mock.returncode = 0
543+
mock_popen.return_value = process_mock
544+
545+
assert "--ntasks=4" in get_submit_command(job, params)
546+
547+
def test_gpu_tasks(self, mock_job):
548+
"""Test that GPU tasks are correctly included in the sbatch command."""
549+
# Create a job with GPU tasks
550+
job = mock_job(gpu=1, tasks_per_gpu=2)
551+
params = {
552+
"run_uuid": "test_run",
553+
"slurm_logfile": "test_logfile",
554+
"comment_str": "test_comment",
555+
"account": None,
556+
"partition": None,
557+
"workdir": ".",
558+
"tasks_per_gpu": 2,
559+
}
560+
561+
# Patch subprocess.Popen to capture the sbatch command
562+
with patch("subprocess.Popen") as mock_popen:
563+
# Configure the mock to return successful submission
564+
process_mock = MagicMock()
565+
process_mock.communicate.return_value = ("123", "")
566+
process_mock.returncode = 0
567+
mock_popen.return_value = process_mock
568+
569+
assert "--ntasks-per-gpu=2" in get_submit_command(job, params)
570+
571+
def test_no_gpu_task(self, mock_job):
572+
"""Test that no GPU tasks are included when not specified."""
573+
# Create a GPU job whose tasks_per_gpu is negative (-1), i.e. unset
574+
job = mock_job(gpu=1, tasks_per_gpu=-1)
575+
params = {
576+
"run_uuid": "test_run",
577+
"slurm_logfile": "test_logfile",
578+
"comment_str": "test_comment",
579+
"account": None,
580+
"partition": None,
581+
"workdir": ".",
582+
"tasks_per_gpu": -1,
583+
}
584+
585+
# Patch subprocess.Popen to capture the sbatch command
586+
with patch("subprocess.Popen") as mock_popen:
587+
# Configure the mock to return successful submission
588+
process_mock = MagicMock()
589+
process_mock.communicate.return_value = ("123", "")
590+
process_mock.returncode = 0
591+
mock_popen.return_value = process_mock
592+
593+
assert "--ntasks-per-gpu" not in get_submit_command(job, params)
594+
595+
def test_task_set_for_unset_tasks(self, mock_job):
596+
"""Test that tasks are set to 1 when unset."""
597+
# Create a job without tasks
598+
job = mock_job(tasks=None)
599+
params = {
600+
"run_uuid": "test_run",
601+
"slurm_logfile": "test_logfile",
602+
"comment_str": "test_comment",
603+
"account": None,
604+
"partition": None,
605+
"workdir": ".",
606+
}
607+
608+
# Patch subprocess.Popen to capture the sbatch command
609+
with patch("subprocess.Popen") as mock_popen:
610+
# Configure the mock to return successful submission
611+
process_mock = MagicMock()
612+
process_mock.communicate.return_value = ("123", "")
613+
process_mock.returncode = 0
614+
mock_popen.return_value = process_mock
615+
616+
assert "--ntasks=1" in get_submit_command(job, params)
617+
618+
def test_gpu_tasks_set_for_unset_tasks(self, mock_job):
619+
"""Test that GPU tasks are set to 1 when unset."""
620+
# Create a GPU job without explicitly passing tasks_per_gpu
621+
job = mock_job(gpu=1)
622+
params = {
623+
"run_uuid": "test_run",
624+
"slurm_logfile": "test_logfile",
625+
"comment_str": "test_comment",
626+
"account": None,
627+
"partition": None,
628+
"workdir": ".",
629+
}
630+
631+
# Patch subprocess.Popen to capture the sbatch command
632+
with patch("subprocess.Popen") as mock_popen:
633+
# Configure the mock to return successful submission
634+
process_mock = MagicMock()
635+
process_mock.communicate.return_value = ("123", "")
636+
process_mock.returncode = 0
637+
mock_popen.return_value = process_mock
638+
639+
assert "--ntasks-per-gpu=1" in get_submit_command(job, params)
640+
523641

524642
class TestWildcardsWithSlashes(snakemake.common.tests.TestWorkflowsLocalStorageBase):
525643
"""

0 commit comments

Comments
 (0)