Skip to content

Commit 955771e

Browse files
authored
Add validation error handling to InterpreterABC class (#575)
Add `_validation_error` field to `InterpreterABC`. ```python _validation_errors: dict[ir.IRNode, set[ir.ValidationError]] = field( default_factory=dict, init=False ) """The validation errors collected during interpretation.""" ``` With 2 API: ```python def add_validation_error(self, node: ir.IRNode, error: ir.ValidationError) -> None: """Add a ValidationError for a given IR node. If the node is not present in the _validation_errors dict, create a new set. Otherwise append to the existing set of errors. """ def get_validation_errors( self, keys: set[ir.IRNode] | None = None ) -> list[ir.ValidationError]: """Get the validation errors collected during interpretation. If keys is provided, only return errors for the given nodes. Otherwise return all errors. """ ```
1 parent 0687436 commit 955771e

File tree

4 files changed

+223
-0
lines changed

4 files changed

+223
-0
lines changed

src/kirin/interp/abc.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ class InterpreterABC(ABC, Generic[FrameType, ValueType]):
5151
"""The interpreter state."""
5252
__eval_lock: bool = field(default=False, init=False, repr=False)
5353
"""Lock for the eval method."""
54+
_validation_errors: dict[ir.IRNode, set[ir.ValidationError]] = field(
55+
default_factory=dict, init=False
56+
)
57+
"""The validation errors collected during interpretation."""
5458

5559
def __init_subclass__(cls) -> None:
5660
super().__init_subclass__()
@@ -330,3 +334,25 @@ def lookup_registry(
330334

331335
def build_signature(self, frame: FrameType, node: ir.Statement) -> Signature:
332336
return Signature(node.__class__, tuple(arg.type for arg in node.args))
337+
338+
def add_validation_error(self, node: ir.IRNode, error: ir.ValidationError) -> None:
339+
"""Add a ValidationError for a given IR node.
340+
341+
If the node is not present in the _validation_errors dict, create a new set.
342+
Otherwise append to the existing set of errors.
343+
"""
344+
self._validation_errors.setdefault(node, set()).add(error)
345+
346+
def get_validation_errors(
347+
self, keys: set[ir.IRNode] | None = None
348+
) -> list[ir.ValidationError]:
349+
"""Get the validation errors collected during interpretation.
350+
351+
If keys is provided, only return errors for the given nodes.
352+
Otherwise return all errors.
353+
"""
354+
if keys is None:
355+
return [err for s in self._validation_errors.values() for err in s]
356+
return [
357+
err for node in keys for err in self._validation_errors.get(node, set())
358+
]

src/kirin/ir/exception.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,15 @@ class TypeCheckError(ValidationError):
6868

6969
class CompilerError(Exception):
7070
pass
71+
72+
73+
class PotentialValidationError(ValidationError):
74+
"""Indicates a potential violation that may occur at runtime."""
75+
76+
pass
77+
78+
79+
class DefiniteValidationError(ValidationError):
80+
"""Indicates a definite violation that will occur at runtime."""
81+
82+
pass

src/kirin/validation/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .validationpass import (
2+
ValidationPass as ValidationPass,
3+
ValidationSuite as ValidationSuite,
4+
)
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Any, Generic, TypeVar
3+
from dataclasses import field, dataclass
4+
5+
from kirin import ir
6+
from kirin.ir.exception import ValidationError
7+
8+
T = TypeVar("T")
9+
10+
11+
class ValidationPass(ABC, Generic[T]):
12+
"""Base class for a validation pass.
13+
14+
Each pass analyzes an IR method and collects validation errors.
15+
"""
16+
17+
@abstractmethod
18+
def name(self) -> str:
19+
"""Return the name of this validation pass."""
20+
...
21+
22+
@abstractmethod
23+
def run(self, method: ir.Method) -> tuple[Any, list[ValidationError]]:
24+
"""Run validation and return (analysis_frame, errors).
25+
26+
Returns:
27+
- analysis_frame: The result frame from the analysis
28+
- errors: List of validation errors (empty if valid)
29+
"""
30+
...
31+
32+
def get_required_analyses(self) -> list[type]:
33+
"""Return list of analysis classes this pass depends on.
34+
35+
Override to declare dependencies (e.g., [AddressAnalysis, AnotherAnalysis]).
36+
The suite will run these analyses once and cache results.
37+
"""
38+
return []
39+
40+
def set_analysis_cache(self, cache: dict[type, Any]) -> None:
41+
"""Receive cached analysis results from the suite.
42+
43+
Override to store cached analysis frames/results.
44+
Example:
45+
self._address_frame = cache.get(AddressAnalysis)
46+
"""
47+
pass
48+
49+
50+
@dataclass
51+
class ValidationSuite:
52+
"""Compose multiple validation passes and run them together.
53+
54+
Caches analysis results to avoid redundant computation when multiple
55+
validation passes depend on the same underlying analysis.
56+
57+
Example:
58+
suite = ValidationSuite([
59+
NoCloningValidation,
60+
AnotherValidation,
61+
])
62+
result = suite.validate(my_kernel)
63+
print(result.format_errors())
64+
"""
65+
66+
passes: list[type[ValidationPass]] = field(default_factory=list)
67+
fail_fast: bool = False
68+
_analysis_cache: dict[type, Any] = field(default_factory=dict, init=False)
69+
70+
def add_pass(self, pass_cls: type[ValidationPass]) -> "ValidationSuite":
71+
"""Add a validation pass to the suite."""
72+
self.passes.append(pass_cls)
73+
return self
74+
75+
def validate(self, method: ir.Method) -> "ValidationResult":
76+
"""Run all validation passes and collect results."""
77+
all_errors: dict[str, list[ValidationError]] = {}
78+
all_frames: dict[str, Any] = {}
79+
self._analysis_cache.clear()
80+
81+
for pass_cls in self.passes:
82+
validator = pass_cls()
83+
pass_name = validator.name()
84+
85+
try:
86+
required = validator.get_required_analyses()
87+
for required_analysis in required:
88+
if required_analysis not in self._analysis_cache:
89+
analysis = required_analysis(method.dialects)
90+
analysis.initialize()
91+
frame, _ = analysis.run(method)
92+
self._analysis_cache[required_analysis] = frame
93+
94+
validator.set_analysis_cache(self._analysis_cache)
95+
96+
frame, errors = validator.run(method)
97+
all_frames[pass_name] = frame
98+
99+
for err in errors:
100+
if isinstance(err, ValidationError):
101+
try:
102+
err.attach(method)
103+
except Exception:
104+
pass
105+
106+
if errors:
107+
all_errors[pass_name] = errors
108+
if self.fail_fast:
109+
break
110+
111+
except Exception as e:
112+
import traceback
113+
114+
tb = traceback.format_exc()
115+
all_errors[pass_name] = [
116+
ValidationError(
117+
method.code, f"Validation pass '{pass_name}' failed: {e}\n{tb}"
118+
)
119+
]
120+
if self.fail_fast:
121+
break
122+
123+
return ValidationResult(all_errors, all_frames)
124+
125+
126+
@dataclass
127+
class ValidationResult:
128+
"""Result of running a validation suite."""
129+
130+
errors: dict[str, list[ValidationError]]
131+
frames: dict[str, Any] = field(default_factory=dict)
132+
is_valid: bool = field(default=True, init=False)
133+
134+
def __post_init__(self):
135+
for _, errors in self.errors.items():
136+
if errors:
137+
self.is_valid = False
138+
break
139+
140+
def error_count(self) -> int:
141+
"""Total number of violations across all passes.
142+
143+
Counts violations directly from frames using the same logic as test helpers.
144+
"""
145+
146+
total = 0
147+
for pass_name, errors in self.errors.items():
148+
if errors is None:
149+
continue
150+
total += len(errors)
151+
return total
152+
153+
def get_frame(self, pass_name: str) -> Any:
154+
"""Get the analysis frame for a specific pass."""
155+
return self.frames.get(pass_name)
156+
157+
def format_errors(self) -> str:
158+
"""Format all errors with their pass names."""
159+
if self.is_valid:
160+
return "\n\033[32mAll validation passes succeeded\033[0m"
161+
162+
lines = [
163+
f"\n\033[31mValidation failed with {self.error_count()} violation(s):\033[0m"
164+
]
165+
for pass_name, pass_errors in self.errors.items():
166+
lines.append(f"\n\033[31m{pass_name}:\033[0m")
167+
for err in pass_errors:
168+
err_msg = err.args[0] if err.args else str(err)
169+
lines.append(f" - {err_msg}")
170+
if hasattr(err, "hint"):
171+
hint = err.hint()
172+
if hint:
173+
lines.append(f" {hint}")
174+
175+
return "\n".join(lines)
176+
177+
def raise_if_invalid(self):
178+
"""Raise an exception if validation failed."""
179+
if not self.is_valid:
180+
first_errors = next(iter(self.errors.values()))
181+
raise first_errors[0]

0 commit comments

Comments
 (0)