Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 103 additions & 37 deletions slurm_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from datetime import datetime, timedelta, timezone
from getpass import getuser
from pathlib import Path
from typing import Annotated, Any, NamedTuple
from typing import Annotated, Any, Callable, Literal, Mapping, NamedTuple, cast

import polars as pl
import typer
Expand Down Expand Up @@ -290,7 +290,7 @@ def run_squeue() -> CommandResult:
CommandResult with stdout, stderr, and return code

"""
cmd = ["squeue", "-ro", "%u/%t/%D/%P/%C/%N/%h"]
cmd = ["squeue", "-ro", "%u/%t/%D/%P/%C/%N/%h/%m"]
return _run(cmd)


Expand Down Expand Up @@ -930,11 +930,12 @@ class SlurmJob(NamedTuple):
cores: int
node: str
oversubscribe: str
memory_mb: float

@classmethod
def from_line(cls, line: str) -> SlurmJob:
"""Create a SlurmJob from a squeue output line."""
user, status, nnodes, partition, cores, node, oversubscribe = line.split("/")
user, status, nnodes, partition, cores, node, oversubscribe, memory = line.split("/")
return cls(
user,
status,
Expand All @@ -943,6 +944,7 @@ def from_line(cls, line: str) -> SlurmJob:
int(cores),
node,
oversubscribe,
_parse_memory_mb(memory),
)


Expand All @@ -968,53 +970,78 @@ def get_total_cores(node_name: str) -> int:
return 0 # Return 0 if not found


# Names of the resource metrics accepted by process_data and _resource_formatter.
ResourceMetric = Literal["cores", "nodes", "memory"]
# Tolerance under which summarize_status prints a float value as a whole number.
STATUS_ROUNDING_EPSILON = 1e-9
# Typer option definition used as the default value of `current(resources=...)`.
# NOTE(review): the cast presumably exists only so the default type-checks as
# `list[str] | None`; at runtime the value is whatever typer.Option returns — confirm.
RESOURCE_OPTION = cast(
    list[str] | None,
    typer.Option(
        None,
        "--resource",
        "-r",
        help="Resources to summarize (repeatable). Choose from cores, nodes, memory.",
    ),
)


class ResourceAggregation(NamedTuple):
    """Aggregated resource counts for users, partitions, and overall totals.

    This is the return value of ``process_data``; all three mappings hold
    float sums keyed (innermost) by job status.
    """

    # user -> partition -> job status -> summed value
    per_user: defaultdict[str, defaultdict[str, defaultdict[str, float]]]
    # partition -> job status -> summed value
    per_partition: defaultdict[str, defaultdict[str, float]]
    # job status -> summed value
    totals: defaultdict[str, float]


def process_data(
    output: list[SlurmJob],
    metric: ResourceMetric,
) -> ResourceAggregation:
    """Process SLURM job data and aggregate statistics by resource metric.

    Parameters
    ----------
    output
        Parsed squeue jobs.
    metric
        Which resource to sum: ``"cores"``, ``"nodes"`` or ``"memory"``.

    Returns
    -------
    ResourceAggregation
        Sums keyed by job status, grouped per user and partition, per
        partition, and overall. Memory is reported in GB
        (``memory_mb / 1024``); non-positive memory values count as 0.

    Notes
    -----
    Jobs whose ``oversubscribe`` field is ``"NO"`` or ``"USER"`` hold their
    node exclusively. Each such (user, node) pair is counted once: the whole
    node's core count for the ``"cores"`` metric, and the job's node count
    for the ``"nodes"`` metric. Previously the ``"nodes"`` metric also added
    the node's core count, which inflated the nodes table.

    """
    per_user: defaultdict[str, defaultdict[str, defaultdict[str, float]]] = defaultdict(
        lambda: defaultdict(lambda: defaultdict(float)),
    )
    per_partition: defaultdict[str, defaultdict[str, float]] = defaultdict(
        lambda: defaultdict(float),
    )
    totals: defaultdict[str, float] = defaultdict(float)

    # Track which nodes have been counted for each user when resources are exclusive
    counted_nodes: defaultdict[str, set[str]] = defaultdict(set)

    for job in output:
        if metric == "memory":
            # Convert MB to GB; guard against non-positive parse results
            value = job.memory_mb / 1024 if job.memory_mb > 0 else 0.0
        elif job.oversubscribe in ("NO", "USER"):
            if job.node in counted_nodes[job.user]:
                continue  # Skip to prevent double-counting exclusive nodes
            counted_nodes[job.user].add(job.node)
            if metric == "nodes":
                value = float(job.nnodes)
            else:
                # An exclusive job occupies the whole node, so count all its cores
                value = float(get_total_cores(job.node))
        else:
            value = float(job.nnodes) if metric == "nodes" else float(job.cores)

        per_user[job.user][job.partition][job.status] += value
        per_partition[job.partition][job.status] += value
        totals[job.status] += value

    return ResourceAggregation(per_user, per_partition, totals)


def summarize_status(d: Mapping[str, float], formatter: Callable[[float], str] | None = None) -> str:
    """Render a status-to-value mapping as ``STATUS=value`` pairs joined by `` / ``.

    Parameters
    ----------
    d
        Mapping from job status to its aggregated value.
    formatter
        Optional callable turning a value into text. When omitted, values
        within ``STATUS_ROUNDING_EPSILON`` of a whole number print as
        integers and everything else prints with at most two decimals
        (trailing zeros removed).

    Returns
    -------
    str
        The joined summary; an empty string for an empty mapping.

    """

    def _default_render(amount: float) -> str:
        nearest = round(amount)
        if abs(amount - nearest) < STATUS_ROUNDING_EPSILON:
            return str(int(nearest))
        trimmed = f"{amount:.2f}".rstrip("0")
        return trimmed.rstrip(".")

    render = _default_render if formatter is None else formatter
    return " / ".join(f"{status}={render(amount)}" for status, amount in d.items())

def combine_statuses(d: dict[str, Any]) -> dict[str, int]:

def combine_statuses(d: dict[str, Any]) -> dict[str, float]:
"""Combine multiple status dictionaries into one."""
tot: defaultdict[str, int] = defaultdict(int)
tot: defaultdict[str, float] = defaultdict(float)
for dct in d.values():
for status, n in dct.items():
tot[status] += n
Expand Down Expand Up @@ -2626,25 +2653,64 @@ def status(
console.print(f"\n[bold]Disk Usage:[/bold] {total_size / (1024**2):.1f} MB")


def _resource_formatter(metric: ResourceMetric) -> Callable[[float], str]:
if metric == "memory":
return lambda value: f"{value:.1f} GB"
return lambda value: str(int(round(value)))


def _resolve_resources(resources: object) -> list[ResourceMetric]:
    """Turn raw ``--resource`` values into an ordered, de-duplicated metric list."""
    # Anything that is not a non-empty list falls back to the defaults. This also
    # covers a direct Python call, where the default is still the option object
    # stored in RESOURCE_OPTION rather than a parsed list.
    if not isinstance(resources, list) or not resources:
        return ["cores", "nodes"]

    allowed_metrics: tuple[ResourceMetric, ...] = ("cores", "nodes", "memory")
    # A dict keeps first-seen order while dropping repeated entries.
    selected: dict[str, None] = {}
    for entry in resources:
        normalized = entry.lower()
        if normalized not in allowed_metrics:
            msg = f"Invalid resource '{entry}'. Choose from cores, nodes, memory."
            raise typer.BadParameter(msg)
        selected.setdefault(normalized)
    return [cast(ResourceMetric, name) for name in selected]


@app.command()
def current(
    resources: list[str] | None = RESOURCE_OPTION,
) -> None:
    """Display current cluster usage statistics from squeue.

    Parameters
    ----------
    resources
        Resource metrics to print, one table each, in the order given
        (case-insensitive, duplicates ignored). Defaults to cores and nodes
        when omitted or empty.

    Raises
    ------
    typer.BadParameter
        If an entry is not one of cores, nodes, memory.

    """
    # Validate the requested metrics first so bad input fails before squeue is run.
    ordered_resources = _resolve_resources(resources)

    output = squeue_output()
    me = getuser()

    for which in ordered_resources:
        aggregated = process_data(output, which)
        data = aggregated.per_user
        total_partition = aggregated.per_partition
        totals = aggregated.totals
        formatter = _resource_formatter(which)
        table = Table(title=f"SLURM statistics [b]{which}[/]", show_footer=True)
        partitions = sorted(total_partition.keys())
        # The second positional argument of add_column is the footer text.
        table.add_column("User", f"{len(data)} users", style="cyan")
        for partition in partitions:
            tot = summarize_status(total_partition[partition], formatter)
            table.add_column(partition, tot, style="magenta")
        table.add_column("Total", summarize_status(totals, formatter), style="magenta")

        for user, _stats in sorted(data.items()):
            # Highlight the row of the user running the command.
            kw = {"style": "bold italic"} if user == me else {}
            partition_stats = [summarize_status(_stats[p], formatter) if p in _stats else "-" for p in partitions]
            total_summary = summarize_status(combine_statuses(_stats), formatter)
            table.add_row(user, *partition_stats, total_summary, **kw)
        console.print(table, justify="center")


Expand Down
2 changes: 1 addition & 1 deletion tests/snapshots/command_map.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"squeue -ro %u/%t/%D/%P/%C/%N/%h": "squeue",
"squeue -ro %u/%t/%D/%P/%C/%N/%h/%m": "squeue",
"sinfo -h -N --format='%N,%c'": "sinfo_cpus",
"sinfo -h -N --format='%N,%G'": "sinfo_gpus",
"sacct -a -S 2025-08-21T00:00:00 -E 2025-08-21T23:59:59 --format=JobID,JobIDRaw,JobName,User,UID,Group,GID,Account,Partition,QOS,State,ExitCode,Submit,Eligible,Start,End,Elapsed,ElapsedRaw,CPUTime,CPUTimeRAW,TotalCPU,UserCPU,SystemCPU,AllocCPUS,AllocNodes,NodeList,ReqCPUS,ReqMem,ReqNodes,Timelimit,TimelimitRaw,MaxRSS,MaxVMSize,MaxDiskRead,MaxDiskWrite,AveRSS,AveCPU,AveVMSize,ConsumedEnergy,ConsumedEnergyRaw,Priority,Reservation,ReservationId,WorkDir,Cluster,ReqTRES,AllocTRES,Comment,Constraints,Container,DerivedExitCode,Flags,Layout,MaxRSSNode,MaxVMSizeNode,MinCPU,NCPUS,NNodes,NTasks,Reason,SubmitLine -P -n": "sacct_day_0",
Expand Down
4 changes: 2 additions & 2 deletions tests/snapshots/squeue_output.txt
Git LFS file not shown
5 changes: 4 additions & 1 deletion tests/test_cli_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,10 @@ def test_process_data(self) -> None:
"""Test process_data function."""
jobs = slurm_usage.squeue_output()

data, total_partition, totals = slurm_usage.process_data(jobs, "cores")
aggregated = slurm_usage.process_data(jobs, "cores")
data = aggregated.per_user
total_partition = aggregated.per_partition
totals = aggregated.totals

assert isinstance(data, dict)
assert isinstance(total_partition, dict)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ class TestSqueueParsing:

def test_slurm_job_from_line(self) -> None:
"""Test creating SlurmJob from squeue line."""
line = "alice/R/1/partition-01/4/node-001/OK"
line = "alice/R/1/partition-01/4/node-001/OK/8G"
job = slurm_usage.SlurmJob.from_line(line)

assert job.user == "alice"
Expand All @@ -335,6 +335,7 @@ def test_slurm_job_from_line(self) -> None:
assert job.cores == expected_cores
assert job.node == "node-001"
assert job.oversubscribe == "OK"
assert job.memory_mb == pytest.approx(8192.0)

def test_squeue_output(self) -> None:
"""Test parsing full squeue output."""
Expand Down
6 changes: 3 additions & 3 deletions tests/test_slurm_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def test_run_squeue(self) -> None:
"""Test squeue command."""
result = slurm_usage.run_squeue()
assert result.returncode == 0
assert "USER/ST/NODES/PARTITION" in result.stdout
assert result.command == "squeue -ro %u/%t/%D/%P/%C/%N/%h"
assert "USER/ST/NODES/PARTITION/CPUS/NODELIST/OVER_SUBSCRIBE/MEMORY" in result.stdout
assert result.command == "squeue -ro %u/%t/%D/%P/%C/%N/%h/%m"

# Parse the output
lines = result.stdout.strip().split("\n")
Expand All @@ -40,7 +40,7 @@ def test_run_squeue(self) -> None:
# Check first data line format
if len(lines) > 1:
parts = lines[1].split("/")
expected_parts = 7 # user/status/nodes/partition/cpus/nodelist/oversubscribe
expected_parts = 8 # user/status/nodes/partition/cpus/nodelist/oversubscribe/memory
assert len(parts) == expected_parts

def test_run_sinfo_cpus(self) -> None:
Expand Down
23 changes: 15 additions & 8 deletions tests/test_slurm_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
)

# Mock squeue output
squeue_mock_output = """USER/ST/NODES/PARTITION
bas.nijholt/PD/1/mypartition-10
bas.nijholt/PD/1/mypartition-20
bas.nijholt/PD/1/mypartition-20"""
squeue_mock_output = """USER/ST/NODES/PARTITION/CPUS/NODELIST/OVER_SUBSCRIBE/MEMORY
bas.nijholt/PD/1/mypartition-10/8/node-001/OK/8000M
bas.nijholt/PD/1/mypartition-20/16/node-002/OK/16000M
bas.nijholt/PD/1/mypartition-20/16/node-003/OK/16000M"""


@pytest.fixture
Expand Down Expand Up @@ -63,11 +63,13 @@ def test_process_data() -> None:

# Create proper SlurmJob objects instead of strings
output = [
SlurmJob("user1", "R", 2, "partition1", 10, "node1", "YES"),
SlurmJob("user2", "PD", 1, "partition2", 5, "node2", "YES"),
SlurmJob("user1", "PD", 1, "partition1", 5, "node3", "YES"),
SlurmJob("user1", "R", 2, "partition1", 10, "node1", "YES", 2048.0),
SlurmJob("user2", "PD", 1, "partition2", 5, "node2", "YES", 1024.0),
SlurmJob("user1", "PD", 1, "partition1", 5, "node3", "YES", 1024.0),
]
data, total_partition, totals = process_data(output, "nodes")
aggregated_nodes = process_data(output, "nodes")
data = aggregated_nodes.per_user
totals = aggregated_nodes.totals
expected_r_count = 2
expected_pd_single = 1
expected_pd_total = 2
Expand All @@ -76,6 +78,11 @@ def test_process_data() -> None:
assert totals["PD"] == expected_pd_total
assert totals["R"] == expected_r_count

aggregated_memory = process_data(output, "memory")
totals_memory = aggregated_memory.totals
assert totals_memory["R"] == pytest.approx(2.0)
assert totals_memory["PD"] == pytest.approx(2.0)


def test_summarize_status() -> None:
"""Test summarize_status function."""
Expand Down