2 changes: 1 addition & 1 deletion aws/lambda/pytorch-auto-revert/SIGNAL_EXTRACTION.md
@@ -115,7 +115,7 @@ Event naming (for debuggability):
- Within the same run, separate events for retries via `run_attempt` (name hints like "Attempt #2" are not relied upon).

### Non‑test mapping
-- Similar to test‑track but grouping is coarser (by normalized job base name):
+- Similar to test‑track but grouping is coarser (by normalized job base name plus classification rule):
  - For each (run_id, run_attempt, job_base_name) group in the commit
  - Within each group compute event status:
    - FAILURE if any row concluded failure.
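
For illustration, a minimal sketch of the coarser grouping this section describes, assuming a simplified row shape (the real implementation lives in `signal_extraction.py` and `JobAggIndex`):

```python
from collections import defaultdict

# Hypothetical, simplified rows for illustration: each has run_id,
# run_attempt, job_base_name, rule (classification), and a conclusion
# that is "failure", "success", or None while still pending.
def non_test_events(rows):
    groups = defaultdict(list)
    for r in rows:
        key = (r["run_id"], r["run_attempt"], r["job_base_name"], r["rule"])
        groups[key].append(r)

    events = {}
    for key, rs in groups.items():
        if any(r["conclusion"] == "failure" for r in rs):
            events[key] = "FAILURE"  # FAILURE if any row concluded failure
        elif any(r["conclusion"] is None for r in rs):
            events[key] = "PENDING"
        else:
            events[key] = "SUCCESS"
    return events
```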
aws/lambda/pytorch-auto-revert/pytorch_auto_revert/job_agg_index.py
@@ -1,7 +1,7 @@
from __future__ import annotations

from collections import defaultdict
-from dataclasses import dataclass
+from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import (
@@ -50,13 +50,14 @@ class JobMeta:
- job_id: Optional job_id from the failing job, or from the first job if none failed.
"""

-    started_at: datetime = datetime.min
-    is_pending: bool = False
-    is_cancelled: bool = False
-    has_failures: bool = False
    all_completed_success: bool = False
+    rules: List[str] = field(default_factory=lambda: [])
+    has_failures: bool = False
    has_non_test_failures: bool = False
+    is_cancelled: bool = False
+    is_pending: bool = False
    job_id: Optional[int] = None
+    started_at: datetime = datetime.min

@property
def status(self) -> Optional[SignalStatus]:
@@ -183,6 +184,7 @@ def stats(self, key: KeyT) -> JobMeta:
            has_non_test_failures=(
                any((r.is_failure and not r.is_test_failure) for r in jrows)
            ),
+            rules=[r.rule for r in jrows if r.rule],
            job_id=job_id,
        )
self._meta_cache[key] = meta
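
A side note on the `field(default_factory=...)` import above: `rules` is a mutable list, and dataclasses reject bare mutable defaults. A minimal sketch of the pattern (class name simplified):

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class Meta:  # simplified stand-in for JobMeta
    # `rules: List[str] = []` would raise ValueError at class-definition
    # time: a shared mutable default would leak state across instances,
    # so dataclasses require a factory instead.
    rules: List[str] = field(default_factory=list)
    has_failures: bool = False

a, b = Meta(), Meta()
a.rules.append("some-rule")
assert b.rules == []  # each instance gets its own list
```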
9 changes: 9 additions & 0 deletions aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal.py
@@ -4,6 +4,7 @@
from typing import Callable, List, Optional, Set, Tuple, Union

from .bisection_planner import GapBisectionPlanner
+from .signal_extraction_types import JobBaseName, JobBaseNameRule


class SignalStatus(Enum):
@@ -290,6 +291,14 @@ class Signal:
- job_base_name: optional job base name for job-level signals (recorded when signal is created)
"""

+    @classmethod
+    def derive_base_name_with_rule(
+        cls, base_name: JobBaseName, rule: str | None
+    ) -> JobBaseNameRule:
+        if not rule:
+            return JobBaseNameRule(f"{base_name}::UNDEFINED")
+        return JobBaseNameRule(f"{base_name}::{rule}")
+
def __init__(
self,
key: str,
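The derived key is just the base name and the classification rule joined with `::`, with `UNDEFINED` as the fallback. For example (the rule string below is hypothetical):

```python
# Illustrative only; rule names are hypothetical examples.
base = JobBaseName("linux-jammy-py3.10 / test")
Signal.derive_base_name_with_rule(base, "gha_timeout")
# -> 'linux-jammy-py3.10 / test::gha_timeout'
Signal.derive_base_name_with_rule(base, None)
# -> 'linux-jammy-py3.10 / test::UNDEFINED'
```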
190 changes: 140 additions & 50 deletions aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_extraction.py
@@ -7,9 +7,10 @@
Transforms raw workflow/job/test data into Signal objects used by signal.py.
"""

+from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import DefaultDict, Dict, Iterable, List, Optional, Set, Tuple

from .job_agg_index import JobAggIndex, JobMeta, SignalStatus as AggStatus
from .signal import Signal, SignalCommit, SignalEvent, SignalSource, SignalStatus
@@ -88,7 +89,7 @@ def extract(self) -> List[Signal]:
        )

        test_signals = self._build_test_signals(jobs, test_rows, commits)
-        job_signals = self._build_non_test_signals(jobs, commits)
+        job_signals = self._build_job_signals(jobs, commits)
        # Deduplicate events within commits across all signals as a final step.
        # GitHub-specific behavior like "rerun failed" can reuse job instances for reruns.
        # When that happens, the jobs have identical timestamps but DIFFERENT job ids.
@@ -442,7 +443,7 @@ def _build_test_signals(

        return signals

-    def _build_non_test_signals(
+    def _build_job_signals(
        self, jobs: List[JobRow], commits: List[Tuple[Sha, datetime]]
    ) -> List[Signal]:
        """Build Signals keyed by normalized job base name per workflow.
@@ -453,7 +454,6 @@ def _build_non_test_signals(
            jobs: List of job rows from the datasource
            commits: Ordered list of (sha, timestamp) tuples (newest → older)
        """
-
        commit_timestamps = dict(commits)

        index = JobAggIndex.from_rows(
@@ -479,43 +479,123 @@

        signals: List[Signal] = []
        for wf_name, base_name in wf_base_keys:
-            commit_objs: List[SignalCommit] = []
-            # Track failure types across all attempts/commits for this base
-            has_relevant_failures = False  # at least one non-test failure observed
+            signals += self._build_job_signals_for_wf(
+                commit_timestamps, wf_name, base_name, index, groups_index
+            )

-            for sha, _ in commits:
-                attempt_keys: List[
-                    Tuple[Sha, WorkflowName, JobBaseName, WfRunId, RunAttempt]
-                ] = groups_index.get((sha, wf_name, base_name), [])
-                events: List[SignalEvent] = []
+        return signals

-                for akey in attempt_keys:
-                    meta = index.stats(akey)
-                    if meta.is_cancelled:
-                        # canceled attempts are treated as missing
-                        continue
-                    # Map aggregation verdict to outer SignalStatus
-                    if meta.status is None:
-                        continue
-                    if meta.status == AggStatus.FAILURE and meta.has_non_test_failures:
-                        # mark presence of non-test failures (relevant for job track)
-                        has_relevant_failures = True
-                        ev_status = SignalStatus.FAILURE
-                    elif meta.status == AggStatus.PENDING:
-                        ev_status = SignalStatus.PENDING
+    def _build_job_signals_for_wf(
+        self,
+        commit_timestamps: Dict[Sha, datetime],
+        wf_name: WorkflowName,
+        base_name: JobBaseName,
+        index: JobAggIndex[Tuple[Sha, WorkflowName, JobBaseName, WfRunId, RunAttempt]],
+        groups_index: DefaultDict[
+            Tuple[Sha, WorkflowName, JobBaseName], List[Tuple[WfRunId, RunAttempt]]
+        ],
+    ) -> List[Signal]:
+        # It is simpler to collect the failure rules first and then build the
+        # signals: a single-pass version would have to rename signals and
+        # reshuffle state mid-iteration, which gets messy.
+        found_rules: Set[str] = set()
+        for sha, _ in commit_timestamps.items():
+            attempt_keys: List[
+                Tuple[Sha, WorkflowName, JobBaseName, WfRunId, RunAttempt]
+            ] = groups_index.get((sha, wf_name, base_name), [])
+
+            for akey in attempt_keys:
+                meta = index.stats(akey)
+                if meta.is_cancelled:
+                    # canceled attempts are treated as missing
+                    continue
+                if meta.status == AggStatus.FAILURE:
+                    if meta.rules:
+                        found_rules.update(r or "UNDEFINED" for r in meta.rules)
                    else:
-                        # Note: when all failures are caused by tests, we do NOT emit job-level failures
-                        ev_status = SignalStatus.SUCCESS
+                        found_rules.add("UNDEFINED")

+        # we only build job signals when there are failures
+        if not found_rules:
+            return []
+
-                    # Extract wf_run_id/run_attempt from the attempt key
-                    _, _, _, wf_run_id, run_attempt = akey
+        signals: Dict[str, Signal] = {}
+        for found_rule in found_rules:
+            rule_base_name = Signal.derive_base_name_with_rule(
+                base_name=str(base_name), rule=found_rule
+            )
+            signals[found_rule] = Signal(
+                key=rule_base_name,
+                workflow_name=wf_name,
+                commits=[],
+                job_base_name=rule_base_name,
+                source=SignalSource.JOB,
+            )
+
+        for sha, _ in commit_timestamps.items():
+            rule_events = self._build_job_rule_events_for_sha(
+                sha, wf_name, base_name, index, groups_index, found_rules
+            )

-                    events.append(
+            for found_rule in found_rules:
+                signals[found_rule].commits.append(
+                    SignalCommit(
+                        head_sha=sha,
+                        timestamp=commit_timestamps[sha],
+                        events=rule_events.get(found_rule, []),
+                    )
+                )
+
+        return list(signals.values())
+
+    def _build_job_rule_events_for_sha(
+        self,
+        sha: Sha,
+        wf_name: WorkflowName,
+        base_name: JobBaseName,
+        index: JobAggIndex[Tuple[Sha, WorkflowName, JobBaseName, WfRunId, RunAttempt]],
+        groups_index: DefaultDict[
+            Tuple[Sha, WorkflowName, JobBaseName], List[Tuple[WfRunId, RunAttempt]]
+        ],
+        found_rules: Set[str],
+    ) -> DefaultDict[str, List[SignalEvent]]:
+        attempt_keys: List[
+            Tuple[Sha, WorkflowName, JobBaseName, WfRunId, RunAttempt]
+        ] = groups_index.get((sha, wf_name, base_name), [])
+
+        rule_events: DefaultDict[str, List[SignalEvent]] = defaultdict(list)
+
+        for akey in attempt_keys:
+            meta = index.stats(akey)
+            if meta.is_cancelled:
+                # canceled attempts are treated as missing
+                continue
+            # Map aggregation verdict to outer SignalStatus
+            if meta.status is None:
+                continue
+            if meta.status == AggStatus.FAILURE:
+                ev_status = SignalStatus.FAILURE
+            elif meta.status == AggStatus.PENDING:
+                ev_status = SignalStatus.PENDING
+            else:
+                ev_status = SignalStatus.SUCCESS
+
+            # Extract wf_run_id/run_attempt from the attempt key
+            _, _, _, wf_run_id, run_attempt = akey
+
+            if ev_status != SignalStatus.FAILURE:
+                # A non-failure event is relevant to every per-rule signal,
+                # so fan it out to all discovered rules.
+                for rule in found_rules:
+                    rule_base_name = Signal.derive_base_name_with_rule(
+                        base_name=str(base_name), rule=rule
+                    )
+                    rule_events[rule].append(
                        SignalEvent(
                            name=self._fmt_event_name(
                                workflow=wf_name,
                                kind="job",
-                                identifier=base_name,
+                                identifier=rule_base_name,
                                wf_run_id=wf_run_id,
                                run_attempt=run_attempt,
                            ),
@@ -527,24 +607,34 @@ def _build_non_test_signals(
                            job_id=meta.job_id,
                        )
                    )

-            # important to always include the commit, even if no events
-            commit_objs.append(
-                SignalCommit(
-                    head_sha=sha, timestamp=commit_timestamps[sha], events=events
+            else:
+                # Signals that contain failure rules are relevant
+                # only to those affected rules.
+                # EX:
+                # a signal that first fails with a timeout and later
+                # fails with an infra error tells us nothing about the
+                # timeout rule's status at this stage.
+                for rule_unfiltered in meta.rules or [None]:
+                    rule = rule_unfiltered or "UNDEFINED"
+                    rule_base_name = Signal.derive_base_name_with_rule(
+                        base_name=str(base_name), rule=rule
+                    )
-                )
-            )

-        # Emit job signal when failures were present and failures were NOT exclusively test-caused
-        if has_relevant_failures:
-            signals.append(
-                Signal(
-                    key=base_name,
-                    workflow_name=wf_name,
-                    commits=commit_objs,
-                    job_base_name=str(base_name),
-                    source=SignalSource.JOB,
+                    rule_events[rule].append(
+                        SignalEvent(
+                            name=self._fmt_event_name(
+                                workflow=wf_name,
+                                kind="job",
+                                identifier=rule_base_name,
+                                wf_run_id=wf_run_id,
+                                run_attempt=run_attempt,
+                            ),
+                            status=ev_status,
+                            started_at=meta.started_at,
+                            ended_at=None,
+                            wf_run_id=int(wf_run_id),
+                            run_attempt=int(run_attempt),
+                            job_id=meta.job_id,
+                        )
+                    )
-                )
-            )

-        return signals
+        return rule_events
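
Net effect of the refactor: instead of at most one job-level signal per base name (emitted only when non-test failures were seen), there is now one signal per (base name, classification rule) pair. A condensed, illustrative sketch of the two-pass flow, with simplified types standing in for the real keys and `JobAggIndex`:

```python
from collections import defaultdict

# `attempts` maps sha -> list of objects with .status ("FAILURE" /
# "PENDING" / "SUCCESS" / None) and .rules (list of rule strings);
# both are simplified stand-ins for the real aggregation metadata.
def build_job_signals(commits, attempts):
    # Pass 1: collect every rule observed on a failing attempt.
    found_rules = set()
    for sha in commits:
        for meta in attempts[sha]:
            if meta.status == "FAILURE":
                found_rules.update(meta.rules or ["UNDEFINED"])
    if not found_rules:
        return {}  # job signals are built only when failures exist

    # Pass 2: one signal per rule; every commit appears in every signal.
    signals = {rule: [] for rule in found_rules}
    for sha in commits:
        per_rule = defaultdict(list)
        for meta in attempts[sha]:
            if meta.status is None:
                continue  # treated as missing
            if meta.status != "FAILURE":
                # Non-failure attempts inform every rule's signal.
                for rule in found_rules:
                    per_rule[rule].append(meta.status)
            else:
                # Failures inform only the rules they were classified under.
                for rule in meta.rules or ["UNDEFINED"]:
                    per_rule[rule].append("FAILURE")
        for rule in found_rules:
            signals[rule].append((sha, per_rule.get(rule, [])))
    return signals
```

Keeping the commit list identical across the per-rule signals mirrors the original code's "always include the commit, even if no events" invariant.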
aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_extraction_types.py
@@ -23,6 +23,7 @@
WorkflowName = NewType("WorkflowName", str)
JobName = NewType("JobName", str)
JobBaseName = NewType("JobBaseName", str)
+JobBaseNameRule = NewType("JobBaseNameRule", str)
TestId = NewType("TestId", str)
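
`NewType` aliases like these are zero-cost at runtime but let type checkers keep the plain string keys apart; e.g. a `JobBaseName` cannot silently stand in for a `JobBaseNameRule`. A minimal illustration:

```python
from typing import NewType

JobBaseName = NewType("JobBaseName", str)
JobBaseNameRule = NewType("JobBaseNameRule", str)

def lookup(key: JobBaseNameRule) -> str:
    return key  # NewType values are plain strings at runtime

base = JobBaseName("linux-build / test")
lookup(JobBaseNameRule(f"{base}::UNDEFINED"))  # OK
# lookup(base)  # rejected by mypy/pyright: incompatible type
```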

