diff --git a/aws/lambda/pytorch-auto-revert/SIGNAL_EXTRACTION.md b/aws/lambda/pytorch-auto-revert/SIGNAL_EXTRACTION.md
index 6ef7314397..091b742907 100644
--- a/aws/lambda/pytorch-auto-revert/SIGNAL_EXTRACTION.md
+++ b/aws/lambda/pytorch-auto-revert/SIGNAL_EXTRACTION.md
@@ -115,7 +115,7 @@ Event naming (for debuggability):
 - Within the same run, separate events for retries via `run_attempt` (name hints like "Attempt #2" are not relied upon).
 
 ### Non‑test mapping
-- Similar to test‑track but grouping is coarser (by normalized job base name):
+- Similar to test‑track but grouping is coarser (by normalized job base name plus classification rule):
   - For each (run_id, run_attempt, job_base_name) group in the commit
   - Within each group compute event status:
     - FAILURE if any row concluded failure.
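For reviewers, a minimal sketch of the coarser grouping described above, with the classification rule folded into the grouping key. All names here (`FakeJobRow`, its fields) are invented for illustration, and this is simplified: the real extractor below collects failure rules first and fans non-failure events out to every rule-specific signal.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

@dataclass
class FakeJobRow:
    # hypothetical stand-in for the real job row type
    run_id: int
    run_attempt: int
    job_base_name: str
    rule: str                  # classification rule; "" when unclassified
    conclusion: Optional[str]  # None while the job is still running

def event_status(rows: List[FakeJobRow]) -> str:
    # Mirrors the bullets above: FAILURE if any row concluded failure,
    # PENDING if any row is still running, otherwise SUCCESS.
    if any(r.conclusion == "failure" for r in rows):
        return "FAILURE"
    if any(r.conclusion is None for r in rows):
        return "PENDING"
    return "SUCCESS"

def group_events(rows: List[FakeJobRow]) -> Dict[Tuple, str]:
    # Group by (run_id, run_attempt, job_base_name) plus the rule, then
    # collapse each group into a single event status.
    groups: Dict[Tuple, List[FakeJobRow]] = defaultdict(list)
    for r in rows:
        key = (r.run_id, r.run_attempt, r.job_base_name, r.rule or "UNDEFINED")
        groups[key].append(r)
    return {key: event_status(g) for key, g in groups.items()}
```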
diff --git a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/job_agg_index.py b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/job_agg_index.py
index b17aee0b12..fa05614aca 100644
--- a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/job_agg_index.py
+++ b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/job_agg_index.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from collections import defaultdict
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from datetime import datetime
 from enum import Enum
 from typing import (
@@ -50,13 +50,14 @@ class JobMeta:
     - job_id: Optional job_id from the failing job, or from the first job if none failed.
     """
 
-    started_at: datetime = datetime.min
-    is_pending: bool = False
-    is_cancelled: bool = False
-    has_failures: bool = False
     all_completed_success: bool = False
+    rules: List[str] = field(default_factory=list)
+    has_failures: bool = False
     has_non_test_failures: bool = False
+    is_cancelled: bool = False
+    is_pending: bool = False
     job_id: Optional[int] = None
+    started_at: datetime = datetime.min
 
     @property
     def status(self) -> Optional[SignalStatus]:
@@ -183,6 +184,7 @@ def stats(self, key: KeyT) -> JobMeta:
             has_non_test_failures=(
                 any((r.is_failure and not r.is_test_failure) for r in jrows)
             ),
+            rules=[r.rule for r in jrows if r.rule],
             job_id=job_id,
         )
         self._meta_cache[key] = meta
diff --git a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal.py b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal.py
index cd67868680..48569de7a9 100644
--- a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal.py
+++ b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal.py
@@ -4,6 +4,7 @@
 from typing import Callable, List, Optional, Set, Tuple, Union
 
 from .bisection_planner import GapBisectionPlanner
+from .signal_extraction_types import JobBaseName, JobBaseNameRule
 
 
 class SignalStatus(Enum):
@@ -290,6 +291,14 @@ class Signal:
     - job_base_name: optional job base name for job-level signals (recorded when signal is created)
     """
 
+    @classmethod
+    def derive_base_name_with_rule(
+        cls, base_name: JobBaseName, rule: str | None
+    ) -> JobBaseNameRule:
+        if not rule:
+            return JobBaseNameRule(f"{base_name}::UNDEFINED")
+        return JobBaseNameRule(f"{base_name}::{rule}")
+
    def __init__(
        self,
        key: str,
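The new classmethod gives every job signal a rule-qualified key. A quick usage illustration (the base name is invented; the `::` separator and the `UNDEFINED` fallback come from the code above):

```python
from pytorch_auto_revert.signal import Signal

Signal.derive_base_name_with_rule(base_name="trunk / linux-build", rule="infra")
# -> "trunk / linux-build::infra"

# A missing or empty rule falls back to the reserved UNDEFINED bucket:
Signal.derive_base_name_with_rule(base_name="trunk / linux-build", rule=None)
# -> "trunk / linux-build::UNDEFINED"
```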
diff --git a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_extraction.py b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_extraction.py
index 17052d4356..cd333d8541 100644
--- a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_extraction.py
+++ b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_extraction.py
@@ -7,9 +7,10 @@ Transforms raw workflow/job/test data into Signal objects
 used by signal.py.
 """
 
+from collections import defaultdict
 from dataclasses import dataclass
 from datetime import datetime, timedelta, timezone
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import DefaultDict, Dict, Iterable, List, Optional, Set, Tuple
 
 from .job_agg_index import JobAggIndex, JobMeta, SignalStatus as AggStatus
 from .signal import Signal, SignalCommit, SignalEvent, SignalSource, SignalStatus
@@ -88,7 +89,7 @@ def extract(self) -> List[Signal]:
         )
 
         test_signals = self._build_test_signals(jobs, test_rows, commits)
-        job_signals = self._build_non_test_signals(jobs, commits)
+        job_signals = self._build_job_signals(jobs, commits)
         # Deduplicate events within commits across all signals as a final step
         # GitHub-specific behavior like "rerun failed" can reuse job instances for reruns.
         # When that happens, the jobs have identical timestamps by DIFFERENT job ids.
@@ -442,7 +443,7 @@ def _build_test_signals(
 
         return signals
 
-    def _build_non_test_signals(
+    def _build_job_signals(
         self, jobs: List[JobRow], commits: List[Tuple[Sha, datetime]]
     ) -> List[Signal]:
         """Build Signals keyed by normalized job base name per workflow.
@@ -453,7 +454,6 @@ def _build_non_test_signals(
             jobs: List of job rows from the datasource
             commits: Ordered list of (sha, timestamp) tuples (newest → older)
         """
-
         commit_timestamps = dict(commits)
 
         index = JobAggIndex.from_rows(
@@ -479,43 +479,123 @@ def _build_non_test_signals(
         signals: List[Signal] = []
 
         for wf_name, base_name in wf_base_keys:
-            commit_objs: List[SignalCommit] = []
-            # Track failure types across all attempts/commits for this base
-            has_relevant_failures = False  # at least one non-test failure observed
+            signals += self._build_job_signals_for_wf(
+                commit_timestamps, wf_name, base_name, index, groups_index
+            )
 
-            for sha, _ in commits:
-                attempt_keys: List[
-                    Tuple[Sha, WorkflowName, JobBaseName, WfRunId, RunAttempt]
-                ] = groups_index.get((sha, wf_name, base_name), [])
-                events: List[SignalEvent] = []
+        return signals
 
-                for akey in attempt_keys:
-                    meta = index.stats(akey)
-                    if meta.is_cancelled:
-                        # canceled attempts are treated as missing
-                        continue
-                    # Map aggregation verdict to outer SignalStatus
-                    if meta.status is None:
-                        continue
-                    if meta.status == AggStatus.FAILURE and meta.has_non_test_failures:
-                        # mark presence of non-test failures (relevant for job track)
-                        has_relevant_failures = True
-                        ev_status = SignalStatus.FAILURE
-                    elif meta.status == AggStatus.PENDING:
-                        ev_status = SignalStatus.PENDING
+    def _build_job_signals_for_wf(
+        self,
+        commit_timestamps: Dict[Sha, datetime],
+        wf_name: WorkflowName,
+        base_name: JobBaseName,
+        index: JobAggIndex[Tuple[Sha, WorkflowName, JobBaseName, WfRunId, RunAttempt]],
+        groups_index: DefaultDict[
+            Tuple[Sha, WorkflowName, JobBaseName], List[Tuple[WfRunId, RunAttempt]]
+        ],
+    ) -> List[Signal]:
+        # It is simpler to extract the set of rules first and then build the
+        # signals: the rules determine the signal keys and names, so doing
+        # everything in a single iteration would make for messy code.
+        found_rules: Set[str] = set()
+        for sha, _ in commit_timestamps.items():
+            attempt_keys: List[
+                Tuple[Sha, WorkflowName, JobBaseName, WfRunId, RunAttempt]
+            ] = groups_index.get((sha, wf_name, base_name), [])
+
+            for akey in attempt_keys:
+                meta = index.stats(akey)
+                if meta.is_cancelled:
+                    # canceled attempts are treated as missing
+                    continue
+                if meta.status == AggStatus.FAILURE:
+                    if meta.rules:
+                        found_rules.update(r or "UNDEFINED" for r in meta.rules)
                     else:
-                        # Note: when all failures are caused by tests, we do NOT emit job-level failures
-                        ev_status = SignalStatus.SUCCESS
+                        found_rules.add("UNDEFINED")
+
+        # we only build job signals when there are failures
+        if not found_rules:
+            return []
 
-                    # Extract wf_run_id/run_attempt from the attempt key
-                    _, _, _, wf_run_id, run_attempt = akey
+        signals: Dict[str, Signal] = {}
+        for found_rule in found_rules:
+            rule_base_name = Signal.derive_base_name_with_rule(
+                base_name=str(base_name), rule=found_rule
+            )
+            signals[found_rule] = Signal(
+                key=rule_base_name,
+                workflow_name=wf_name,
+                commits=[],
+                job_base_name=rule_base_name,
+                source=SignalSource.JOB,
+            )
+
+        for sha, _ in commit_timestamps.items():
+            rule_events = self._build_job_rule_events_for_sha(
+                sha, wf_name, base_name, index, groups_index, found_rules
+            )
 
-                    events.append(
+            for found_rule in found_rules:
+                signals[found_rule].commits.append(
+                    SignalCommit(
+                        head_sha=sha,
+                        timestamp=commit_timestamps[sha],
+                        events=rule_events.get(found_rule, []),
+                    )
+                )
+
+        return list(signals.values())
+
+    def _build_job_rule_events_for_sha(
+        self,
+        sha: Sha,
+        wf_name: WorkflowName,
+        base_name: JobBaseName,
+        index: JobAggIndex[Tuple[Sha, WorkflowName, JobBaseName, WfRunId, RunAttempt]],
+        groups_index: DefaultDict[
+            Tuple[Sha, WorkflowName, JobBaseName], List[Tuple[WfRunId, RunAttempt]]
+        ],
+        found_rules: Set[str],
+    ) -> DefaultDict[str, List[SignalEvent]]:
+        attempt_keys: List[
+            Tuple[Sha, WorkflowName, JobBaseName, WfRunId, RunAttempt]
+        ] = groups_index.get((sha, wf_name, base_name), [])
+
+        rule_events: DefaultDict[str, List[SignalEvent]] = defaultdict(list)
+
+        for akey in attempt_keys:
+            meta = index.stats(akey)
+            if meta.is_cancelled:
+                # canceled attempts are treated as missing
+                continue
+            # Map aggregation verdict to outer SignalStatus
+            if meta.status is None:
+                continue
+            if meta.status == AggStatus.FAILURE:
+                ev_status = SignalStatus.FAILURE
+            elif meta.status == AggStatus.PENDING:
+                ev_status = SignalStatus.PENDING
+            else:
+                ev_status = SignalStatus.SUCCESS
+
+            # Extract wf_run_id/run_attempt from the attempt key
+            _, _, _, wf_run_id, run_attempt = akey
+
+            if ev_status != SignalStatus.FAILURE:
+                # A non-failure event is relevant to every rule-specific
+                # signal for this base name, so record it under all rules.
+                for rule in found_rules:
+                    rule_base_name = Signal.derive_base_name_with_rule(
+                        base_name=str(base_name), rule=rule
+                    )
+                    rule_events[rule].append(
                         SignalEvent(
                             name=self._fmt_event_name(
                                 workflow=wf_name,
                                 kind="job",
-                                identifier=base_name,
+                                identifier=rule_base_name,
                                 wf_run_id=wf_run_id,
                                 run_attempt=run_attempt,
                             ),
@@ -527,24 +607,34 @@ def _build_non_test_signals(
                             status=ev_status,
                             started_at=meta.started_at,
                             ended_at=None,
                             wf_run_id=int(wf_run_id),
                             run_attempt=int(run_attempt),
                             job_id=meta.job_id,
                         )
                     )
-
-            # important to always include the commit, even if no events
-            commit_objs.append(
-                SignalCommit(
-                    head_sha=sha, timestamp=commit_timestamps[sha], events=events
+            else:
+                # Failure events are relevant only to the rules they were
+                # classified with.
+                # EX:
+                # A signal that initially fails with a timeout and then
+                # fails with an infra error means the timeout signal's
+                # status cannot be determined at this stage.
+                for rule_unfiltered in meta.rules or [None]:
+                    rule = rule_unfiltered or "UNDEFINED"
+                    rule_base_name = Signal.derive_base_name_with_rule(
+                        base_name=str(base_name), rule=rule
                     )
-            )
-
-            # Emit job signal when failures were present and failures were NOT exclusively test-caused
-            if has_relevant_failures:
-                signals.append(
-                    Signal(
-                        key=base_name,
-                        workflow_name=wf_name,
-                        commits=commit_objs,
-                        job_base_name=str(base_name),
-                        source=SignalSource.JOB,
+                    rule_events[rule].append(
+                        SignalEvent(
+                            name=self._fmt_event_name(
+                                workflow=wf_name,
+                                kind="job",
+                                identifier=rule_base_name,
+                                wf_run_id=wf_run_id,
+                                run_attempt=run_attempt,
+                            ),
+                            status=ev_status,
+                            started_at=meta.started_at,
+                            ended_at=None,
+                            wf_run_id=int(wf_run_id),
+                            run_attempt=int(run_attempt),
+                            job_id=meta.job_id,
+                        )
                     )
-                )
-
-        return signals
+        return rule_events
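To see the routing in `_build_job_rule_events_for_sha` concretely, here is a self-contained sketch of the per-rule fan-out, using plain strings in place of `SignalEvent` (the rules and statuses are hypothetical):

```python
from collections import defaultdict

# (status, rules) for three attempts of one (sha, workflow, base_name) group
attempts = [
    ("FAILURE", ["infra"]),
    ("FAILURE", ["pytest failure"]),
    ("SUCCESS", []),
]

# Pass 1: collect the failure rules that need their own signal
found_rules = {
    rule
    for status, rules in attempts
    if status == "FAILURE"
    for rule in (rules or ["UNDEFINED"])
}

# Pass 2: route events; failures stay with their rule, non-failures fan out
rule_events = defaultdict(list)
for status, rules in attempts:
    if status == "FAILURE":
        for rule in (rules or ["UNDEFINED"]):
            rule_events[rule].append(status)
    else:
        for rule in found_rules:
            rule_events[rule].append(status)

assert rule_events == {
    "infra": ["FAILURE", "SUCCESS"],
    "pytest failure": ["FAILURE", "SUCCESS"],
}
```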
diff --git a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_extraction_types.py b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_extraction_types.py
index 5d8340f8f4..9f5de233c9 100644
--- a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_extraction_types.py
+++ b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_extraction_types.py
@@ -23,6 +23,7 @@
 WorkflowName = NewType("WorkflowName", str)
 JobName = NewType("JobName", str)
 JobBaseName = NewType("JobBaseName", str)
+JobBaseNameRule = NewType("JobBaseNameRule", str)
 TestId = NewType("TestId", str)
 
 
diff --git a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/tests/test_signal_extraction.py b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/tests/test_signal_extraction.py
index 4ba9e608bc..eec545a248 100644
--- a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/tests/test_signal_extraction.py
+++ b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/tests/test_signal_extraction.py
@@ -2,7 +2,7 @@
 from datetime import datetime, timedelta
 from typing import Iterable, List
 
-from pytorch_auto_revert.signal import SignalStatus
+from pytorch_auto_revert.signal import Signal, SignalStatus
 from pytorch_auto_revert.signal_extraction import SignalExtractor
 from pytorch_auto_revert.signal_extraction_datasource import SignalExtractionDatasource
 from pytorch_auto_revert.signal_extraction_types import (
@@ -148,8 +148,12 @@ def test_commit_order_is_stable(self):
             J(sha="C1", run=100, job=2, attempt=1, started_at=ts(self.t0, 5)),
         ]
         signals = self._extract(jobs, tests=[])
-        base = jobs[0].base_name
-        sig = self._find_job_signal(signals, "trunk", base)
+        base = jobs[0]
+        sig = self._find_job_signal(
+            signals,
+            "trunk",
+            Signal.derive_base_name_with_rule(base_name=base.base_name, rule=base.rule),
+        )
         self.assertIsNotNone(sig)
         self.assertEqual([c.head_sha for c in sig.commits], ["C2", "C1"])
 
@@ -174,8 +178,12 @@ def test_attempt_boundary_two_events_time_ordered(self):
             ),
         ]
         signals = self._extract(jobs, tests=[])
-        base = jobs[0].base_name
-        sig = self._find_job_signal(signals, "trunk", base)
+        base = jobs[0]
+        sig = self._find_job_signal(
+            signals,
+            "trunk",
+            Signal.derive_base_name_with_rule(base_name=base.base_name, rule=""),
+        )
         self.assertIsNotNone(sig)
         self.assertEqual(len(sig.commits), 1)
         events = sig.commits[0].events
@@ -187,7 +195,7 @@ def test_attempt_boundary_two_events_time_ordered(self):
         self.assertIn("attempt=1", events[0].name)
         self.assertIn("attempt=2", events[1].name)
 
-    def test_keep_going_failure_test_track_failure_and_no_job_signal(self):
+    def test_keep_going_failure_test_track_failure_and_job_signal(self):
         # in_progress + KG-adjusted failure for a test-classified job
         jobs = [
             J(
@@ -217,8 +225,6 @@ def test_keep_going_failure_test_track_failure_and_no_job_signal(self):
         test_sig = self._find_test_signal(signals, "trunk", "f.py::test_a")
         self.assertIsNotNone(test_sig)
         self.assertEqual(test_sig.commits[0].events[0].status, SignalStatus.FAILURE)
-        # Non-test signal for this base should be omitted due to test-only failure policy
-        self.assertIsNone(self._find_job_signal(signals, "trunk", jobs[0].base_name))
 
     def test_cancelled_attempt_yields_no_event(self):
         # Include a separate failing commit so the job signal is emitted
@@ -243,125 +249,90 @@ def test_cancelled_attempt_yields_no_event(self):
             ),
         ]
         signals = self._extract(jobs, tests=[])
-        base = jobs[0].base_name
-        sig = self._find_job_signal(signals, "trunk", base)
+        base = jobs[0]
+        sig = self._find_job_signal(
+            signals,
+            "trunk",
+            Signal.derive_base_name_with_rule(base_name=base.base_name, rule=base.rule),
+        )
         self.assertIsNotNone(sig)
         # find X1 commit in the signal and ensure it has no events
         x1 = next(c for c in sig.commits if c.head_sha == "X1")
         self.assertEqual(x1.events, [])
 
-    def test_non_test_inclusion_gate(self):
-        # (a) only test failures -> no job signal
-        jobs_a = [
+    def test_changing_rules_across_attempts(self):
+        # One commit with 3 attempts:
+        # - attempt 1: failure with rule "infra"
+        # - attempt 2: failure with rule "pytest failure"
+        # - attempt 3: success with rule ""
+        jobs = [
             J(
-                sha="A2",
+                sha="R1",
                 run=600,
                 job=40,
                 attempt=1,
-                started_at=ts(self.t0, 10),
+                started_at=ts(self.t0, 1),
                 conclusion="failure",
-                rule="pytest failure",
+                rule="infra",
             ),
             J(
-                sha="A1",
-                run=610,
-                job=41,
-                attempt=1,
-                started_at=ts(self.t0, 5),
-                conclusion="failure",
-                rule="pytest failure",
-            ),
-        ]
-        tests_a = [
-            T(
-                job=40,
+                sha="R1",
                 run=600,
-                attempt=1,
-                file="f.py",
-                name="test_x",
-                failure_runs=1,
-                success_runs=0,
-            ),
-            T(
                 job=41,
-                run=610,
-                attempt=1,
-                file="f.py",
-                name="test_x",
-                failure_runs=1,
-                success_runs=0,
-            ),
-        ]
-        signals_a = self._extract(jobs_a, tests_a)
-        self.assertIsNone(
-            self._find_job_signal(signals_a, "trunk", jobs_a[0].base_name)
-        )
-
-        # (b) includes a non-test failure -> job signal emitted
-        jobs_b = [
-            J(
-                sha="B2",
-                run=700,
-                job=50,
-                attempt=1,
-                started_at=ts(self.t0, 10),
+                attempt=2,
+                started_at=ts(self.t0, 2),
                 conclusion="failure",
-                rule="infra-flake",  # non-test classification
+                rule="pytest failure",
             ),
             J(
-                sha="B1",
-                run=710,
-                job=51,
-                attempt=1,
-                started_at=ts(self.t0, 5),
+                sha="R1",
+                run=600,
+                job=42,
+                attempt=3,
+                started_at=ts(self.t0, 3),
                 conclusion="success",
                 rule="",
             ),
         ]
-        signals_b = self._extract(jobs_b, tests=[])
-        self.assertIsNotNone(
-            self._find_job_signal(signals_b, "trunk", jobs_b[0].base_name)
-        )
-
-    def test_job_track_treats_test_failures_as_success(self):
-        # When a base has a non-test (infra) failure somewhere (so a job signal is emitted),
-        # attempts that fail due to tests should NOT appear as FAILURES in the job track.
-        # They should be treated as SUCCESS at the job-track level, leaving the failure to test-track.
-        jobs = [
-            # Newer commit: infra-caused failure (non-test classification)
-            J(
-                sha="Z2",
-                run=9100,
-                job=801,
-                attempt=1,
-                started_at=ts(self.t0, 20),
-                conclusion="failure",
-                rule="infra",  # non-test
-            ),
-            # Older commit: failure caused by tests (test classification)
-            J(
-                sha="Z1",
-                run=9000,
-                job=800,
-                attempt=1,
-                started_at=ts(self.t0, 10),
-                conclusion="failure",
-                rule="pytest failure",  # test-caused
-            ),
-        ]
-        signals = self._extract(jobs, tests=[])
-        base = jobs[0].base_name
-        sig = self._find_job_signal(signals, "trunk", base)
+        signals = self._extract(jobs, tests=[])
+
+        # for 'infra' rule
+        base = jobs[0]
+        sig = self._find_job_signal(
+            signals,
+            "trunk",
+            Signal.derive_base_name_with_rule(base_name=base.base_name, rule=base.rule),
+        )
         self.assertIsNotNone(sig)
-        # Expect commits newest->older
-        self.assertEqual([c.head_sha for c in sig.commits], ["Z2", "Z1"])
-        # Newer infra failure remains FAILURE
-        self.assertEqual(len(sig.commits[0].events), 1)
-        self.assertEqual(sig.commits[0].events[0].status, SignalStatus.FAILURE)
-        # Older test-caused failure is mapped to SUCCESS in job track
-        self.assertEqual(len(sig.commits[1].events), 1)
-        self.assertEqual(sig.commits[1].events[0].status, SignalStatus.SUCCESS)
+        self.assertEqual(len(sig.commits), 1)
+        events = sig.commits[0].events
+        self.assertEqual(len(events), 2)
+        # the first infra failure
+        self.assertEqual(events[0].status, SignalStatus.FAILURE)
+        self.assertIn(base.rule, events[0].name)
+        # the final success
+        self.assertEqual(events[1].status, SignalStatus.SUCCESS)
+        self.assertIn(base.rule, events[1].name)
+
+        # for 'pytest failure' rule
+        base2 = jobs[1]
+        sig2 = self._find_job_signal(
+            signals,
+            "trunk",
+            Signal.derive_base_name_with_rule(
+                base_name=base2.base_name, rule=base2.rule
+            ),
+        )
+        self.assertIsNotNone(sig2)
+        self.assertEqual(len(sig2.commits), 1)
+        events2 = sig2.commits[0].events
+        self.assertEqual(len(events2), 2)
+        # the pytest failure
+        self.assertEqual(events2[0].status, SignalStatus.FAILURE)
+        self.assertIn(base2.rule, events2[0].name)
+        # the final success
+        self.assertEqual(events2[1].status, SignalStatus.SUCCESS)
+        self.assertIn(base2.rule, events2[1].name)
 
     def test_commits_without_jobs_are_included(self):
         # Verify that commits with no jobs at all are still included in signals
@@ -406,8 +376,12 @@ def fetch_commits_in_time_range(
         se._datasource = FakeDatasourceWithExtraCommit(jobs, [])
         signals = se.extract()
 
-        base = jobs[0].base_name
-        sig = self._find_job_signal(signals, "trunk", base)
+        base = jobs[0]
+        sig = self._find_job_signal(
+            signals,
+            "trunk",
+            Signal.derive_base_name_with_rule(base_name=base.base_name, rule=base.rule),
+        )
         self.assertIsNotNone(sig)
         # Should have 3 commits: C2 (with events), C3 (no events), C1 (with events)
         self.assertEqual(len(sig.commits), 3)
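End to end, `test_changing_rules_across_attempts` pins down the shape sketched below (a hypothetical rendering; the base name is invented and exact event names come from `_fmt_event_name`):

```python
from pytorch_auto_revert.signal import Signal

base = "linux-build"  # invented base name
# Three attempts on one commit: infra failure, pytest failure, success.
# The extractor emits one job signal per observed failure rule:
infra_key = Signal.derive_base_name_with_rule(base_name=base, rule="infra")
pytest_key = Signal.derive_base_name_with_rule(base_name=base, rule="pytest failure")
# infra_key  ("linux-build::infra")          -> FAILURE (attempt 1), SUCCESS (attempt 3)
# pytest_key ("linux-build::pytest failure") -> FAILURE (attempt 2), SUCCESS (attempt 3)
# The success attempt fans out to both signals; each failure stays under its own rule.
```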