Skip to content

Commit b30634b

Browse files
parsshar-RHSilv3S
authored andcommitted
[Dynamo] Imporve-graph-break-skip-logs (pytorch#167067)
Fixes pytorch#150477 ### Summary: - Added frame information (function name, file, line number) to all graph break/skip messages - Standardized message format: "torch.compile will skip tracing the frame <name> (<file> line <N>) and fall back to eager. Reason: <reason>" ### Impacts: module: dynamo Pull Request resolved: pytorch#167067 Approved by: https://github.com/williamwen42
1 parent a7f4da6 commit b30634b

File tree

7 files changed

+259
-16
lines changed

7 files changed

+259
-16
lines changed

test/dynamo/test_error_messages.py

Lines changed: 189 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -952,7 +952,9 @@ def fn(x):
952952
self.assertExpectedInline(
953953
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
954954
"""\
955-
Graph break: skip: from user code at:
955+
Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.
956+
torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
957+
The graph break occurred in the following user code:
956958
File "test_error_messages.py", line N, in fn
957959
assert x is None
958960
""",
@@ -1075,9 +1077,91 @@ def gn():
10751077
File "test_error_messages.py", line N, in gn
10761078
torch._dynamo.graph_break()
10771079
1080+
""",
1081+
)
1082+
1083+
@torch._dynamo.config.patch(verbose=True)
1084+
@make_logging_test(graph_breaks=True)
1085+
def test_skipped_frame_with_verbose_traceback(self, records):
1086+
def fn(x):
1087+
with GenericCtxMgr():
1088+
torch._dynamo.graph_break()
1089+
return x + 1
1090+
1091+
torch.compile(fn, backend="eager")(torch.randn(3))
1092+
self.assertEqual(len(records), 1)
1093+
self.assertExpectedInline(
1094+
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
1095+
"""\
1096+
Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.
1097+
torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
1098+
The graph break occurred in the following user code:
1099+
File "test_error_messages.py", line N, in fn
1100+
torch._dynamo.graph_break()
1101+
""",
1102+
)
1103+
self.assertExpectedInline(
1104+
munge_exc(records[0].exc_info[1], suppress_suffix=True, skip=0),
1105+
"""\
1106+
Graph break under GenericContextWrappingVariable
1107+
Explanation: Attempted to graph break in an active context manager(s) that doesn't support graph breaking.
1108+
Hint: Move the offending context manager(s) to outside the compiled region.
1109+
Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.
1110+
1111+
Developer debug context: Active generic context managers: [GenericContextWrappingVariable(GenericCtxMgr)]
1112+
1113+
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0066.html
1114+
1115+
from user code:
1116+
File "test_error_messages.py", line N, in fn
1117+
torch._dynamo.graph_break()
10781118
""",
10791119
)
10801120

1121+
@make_logging_test(graph_breaks=True)
1122+
def test_skip_frame_in_loop_message(self, records):
1123+
def fn(x):
1124+
for i in range(2):
1125+
with GenericCtxMgr():
1126+
if x.sum() > 0:
1127+
x = x + 1
1128+
return x
1129+
1130+
torch.compile(fn, backend="eager")(torch.randn(3))
1131+
self.assertEqual(len(records), 1)
1132+
self.assertExpectedInline(
1133+
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
1134+
"""\
1135+
Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.
1136+
torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
1137+
The graph break occurred in the following user code:
1138+
File "test_error_messages.py", line N, in fn
1139+
if x.sum() > 0:
1140+
""",
1141+
)
1142+
1143+
@make_logging_test(dynamo=logging.DEBUG)
1144+
def test_skip_frame_empty_function_message(self, records):
1145+
def empty_fn(x):
1146+
pass
1147+
1148+
torch.compile(empty_fn, backend="eager")(torch.randn(3))
1149+
skip_messages = [
1150+
r
1151+
for r in records
1152+
if "intentionally decided to skip the frame" in r.getMessage()
1153+
]
1154+
self.assertEqual(len(skip_messages), 1)
1155+
msg = munge_exc(skip_messages[0].getMessage(), suppress_suffix=True, skip=0)
1156+
msg = re.sub(r" (\d+)$", r" N", msg, flags=re.MULTILINE)
1157+
1158+
self.assertExpectedInline(
1159+
msg,
1160+
"""\
1161+
Skipping frame torch.compile intentionally decided to skip the frame empty_fn (test_error_messages.py line N) and fall back to eager.
1162+
Reason: no content in function call empty_fn test_error_messages.py N""",
1163+
)
1164+
10811165
@make_logging_test(graph_breaks=True)
10821166
def test_nested_compile_user_frames(self, records):
10831167
def fn(x):
@@ -1624,6 +1708,110 @@ def fn(x):
16241708
)
16251709

16261710

1711+
class NestedGraphBreakLoggingTests(
1712+
LoggingTestCase, torch._dynamo.test_case.TestCaseWithNestedGraphBreaks
1713+
):
1714+
@make_logging_test(graph_breaks=True)
1715+
def test_skipped_frame_with_verbose_traceback_nested(self, records):
1716+
global f1, f2, f3
1717+
1718+
class GenericCtxMgr:
1719+
def __enter__(self):
1720+
return self
1721+
1722+
def __exit__(self, exc_type, exc_value, traceback):
1723+
pass
1724+
1725+
def f1(x):
1726+
with GenericCtxMgr():
1727+
torch._dynamo.graph_break()
1728+
return x + 1
1729+
1730+
def f2(x):
1731+
return f1(x + 2)
1732+
1733+
def f3(x):
1734+
return f2(x + 3)
1735+
1736+
torch.compile(f3, backend="eager")(torch.randn(3))
1737+
self.assertEqual(len(records), 1)
1738+
self.assertExpectedInline(
1739+
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
1740+
"""\
1741+
Graph break in user code at test_error_messages.py:N
1742+
Graph Break Reason: Encountered graph break that we cannot resume from. Compiling up to the previous resumable state, then skipping the rest of the function. Graph break encountered:
1743+
Graph break under GenericContextWrappingVariable
1744+
Explanation: Attempted to graph break in an active context manager(s) that doesn't support graph breaking.
1745+
Hint: Move the offending context manager(s) to outside the compiled region.
1746+
Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.
1747+
1748+
Developer debug context: Active generic context managers: [GenericContextWrappingVariable(GenericCtxMgr)]
1749+
1750+
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0066.html
1751+
User code traceback:
1752+
File "test_error_messages.py", line N, in test_skipped_frame_with_verbose_traceback_nested
1753+
torch.compile(f3, backend="eager")(torch.randn(3))
1754+
File "test_error_messages.py", line N, in f3
1755+
return f2(x + 3)
1756+
File "test_error_messages.py", line N, in f2
1757+
return f1(x + 2)
1758+
File "test_error_messages.py", line N, in f1
1759+
torch._dynamo.graph_break()
1760+
""",
1761+
)
1762+
1763+
@make_logging_test(graph_breaks=True)
1764+
def test_skip_frame_in_loop_message_nested(self, records):
1765+
global f1, f2, f3
1766+
1767+
class GenericCtxMgr:
1768+
def __enter__(self):
1769+
return self
1770+
1771+
def __exit__(self, exc_type, exc_value, traceback):
1772+
pass
1773+
1774+
def f1(x):
1775+
for i in range(2):
1776+
with GenericCtxMgr():
1777+
if x.sum() > 0:
1778+
x = x + 1
1779+
return x
1780+
1781+
def f2(x):
1782+
return f1(x + 4)
1783+
1784+
def f3(x):
1785+
return f2(x + 5)
1786+
1787+
result = torch.compile(f3, backend="eager")(torch.randn(3)) # noqa: F841
1788+
self.assertEqual(len(records), 1)
1789+
self.assertExpectedInline(
1790+
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
1791+
"""\
1792+
Graph break in user code at test_error_messages.py:N
1793+
Graph Break Reason: Encountered graph break that we cannot resume from. Compiling up to the previous resumable state, then skipping the rest of the function. Graph break encountered:
1794+
Data-dependent branching
1795+
Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
1796+
Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
1797+
Hint: Use `torch.cond` to express dynamic control flow.
1798+
1799+
Developer debug context: attempted to jump with TensorVariable()
1800+
1801+
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0170.html
1802+
User code traceback:
1803+
File "test_error_messages.py", line N, in test_skip_frame_in_loop_message_nested
1804+
result = torch.compile(f3, backend="eager")(torch.randn(3)) # noqa: F841
1805+
File "test_error_messages.py", line N, in f3
1806+
return f2(x + 5)
1807+
File "test_error_messages.py", line N, in f2
1808+
return f1(x + 4)
1809+
File "test_error_messages.py", line N, in f1
1810+
if x.sum() > 0:
1811+
""",
1812+
)
1813+
1814+
16271815
if __name__ == "__main__":
16281816
from torch._dynamo.test_case import run_tests
16291817

torch/_dynamo/convert_frame.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1870,7 +1870,7 @@ def __call__(
18701870
raise
18711871

18721872
soft_fail = isinstance(e, Unsupported)
1873-
1873+
code = frame.f_code
18741874
# This is a soft failure. In the sense, the code path reaches here
18751875
# when we do not support graph breaks on bytecodes like LOAD_ATTR,
18761876
# BUILD_SET etc. In such case, we can fallback to eager without
@@ -1885,7 +1885,13 @@ def __call__(
18851885
user_stack_formatted = "".join(
18861886
traceback.format_list(user_stack)
18871887
)
1888-
user_stack_trace = f"Graph break: skip: from user code at:\n{user_stack_formatted}"
1888+
frame_info = exc.format_frame_info(code)
1889+
user_stack_trace = (
1890+
"Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.\n"
1891+
f"torch.compile will skip tracing the frame {frame_info} and fall back to eager.\n"
1892+
"The graph break occurred in the following user code:\n"
1893+
f"{user_stack_formatted}"
1894+
)
18891895
torch._logging.trace_structured(
18901896
"artifact",
18911897
metadata_fn=lambda: {
@@ -1897,6 +1903,7 @@ def __call__(
18971903
graph_break_log.debug(
18981904
user_stack_trace,
18991905
exc_info=True,
1906+
stack_info=config.verbose,
19001907
)
19011908

19021909
if not config.suppress_errors and not soft_fail:

torch/_dynamo/exc.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -794,6 +794,38 @@ def format_error_msg_verbose(
794794
return msg
795795

796796

797+
def format_frame_info(code: types.CodeType) -> str:
798+
return (
799+
f"{getattr(code, 'co_name', '<unknown>')} "
800+
f"({getattr(code, 'co_filename', '<unknown>')} "
801+
f"line {getattr(code, 'co_firstlineno', 0)})"
802+
)
803+
804+
805+
def format_skip_frame_message(code: Optional[types.CodeType], reason: str) -> str:
806+
if code is not None:
807+
frame_info = format_frame_info(code)
808+
return (
809+
f"torch.compile intentionally decided to skip the frame {frame_info} and fall back to eager.\n"
810+
f"Reason: {reason}"
811+
)
812+
else:
813+
return (
814+
f"torch.compile intentionally decided to skip the frame and fall back to eager.\n"
815+
f"Reason: {reason}"
816+
)
817+
818+
819+
def format_loop_skip_frame_message(code: types.CodeType, frame_summary: str) -> str:
820+
frame_info = format_frame_info(code)
821+
return (
822+
"Skipping frame because there is a graph break in a for/while loop\n"
823+
f"torch.compile intentionally decided to skip the frame {frame_info} and fall back to eager.\n"
824+
f"Reason: Skipping frame because there is a graph break in a for/while loop.\n"
825+
f"{frame_summary}"
826+
)
827+
828+
797829
def format_error_msg(
798830
exc: Exception,
799831
code: types.CodeType,

torch/_dynamo/symbolic_convert.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@
9494
BackendCompilerFailed,
9595
collapse_resume_frames,
9696
format_graph_break_message,
97+
format_loop_skip_frame_message,
98+
format_skip_frame_message,
9799
get_stack_above_dynamo,
98100
ResumePrologueTracingError,
99101
StepUnsupported,
@@ -605,9 +607,9 @@ def jump_graph_break(
605607
)
606608
# compile a partial subgraph prefix then jump into user code
607609
if self.maybe_has_backedge():
608-
msg = (
609-
"Skipping frame because there is a graph break in a for/while loop\n"
610-
f"{self.frame_summary()}"
610+
msg = format_loop_skip_frame_message(
611+
self.f_code,
612+
"".join(traceback.format_list([self.frame_summary()])),
611613
)
612614
log.info(msg)
613615
raise exc.SkipFrame(msg)
@@ -883,9 +885,9 @@ def wrapper(self: InstructionTranslatorBase, inst: Instruction) -> None:
883885
)
884886

885887
if self.maybe_has_backedge():
886-
msg = (
887-
"Skipping frame because there is a graph break in a for/while loop\n"
888-
f"{self.frame_summary()}"
888+
msg = format_loop_skip_frame_message(
889+
self.f_code,
890+
"".join(traceback.format_list([self.frame_summary()])),
889891
)
890892
log.info(msg)
891893
raise exc.SkipFrame(msg) from excp
@@ -4626,8 +4628,9 @@ def _return(self, inst: Instruction) -> None:
46264628
and not self.error_on_graph_break
46274629
and not self.is_tracing_resume_prologue
46284630
):
4629-
raise exc.SkipFrame("because no content in function call")
4630-
4631+
raise exc.SkipFrame(
4632+
format_skip_frame_message(self.f_code, "no content in function call")
4633+
)
46314634
self.instruction_pointer = None
46324635
_step_logger()(
46334636
logging.INFO,

torch/_dynamo/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2248,12 +2248,15 @@ def skip_frame_if_in_functorch_mode(val: torch.Tensor) -> None:
22482248
try:
22492249
val.data_ptr() # will throw for functorch tensors
22502250
except RuntimeError as e:
2251-
from .exc import SkipFrame
2251+
from .exc import format_skip_frame_message, SkipFrame
22522252

22532253
# This will be GradTrackingTensor/BatchedTensor/etc
22542254
functorch_subclass_name = re.sub(r"\(.*", "", repr(val))
22552255
raise SkipFrame(
2256-
f"torch.compile cannot be run in context: {functorch_subclass_name}"
2256+
format_skip_frame_message(
2257+
None,
2258+
f"torch.compile cannot be run in context: {functorch_subclass_name}",
2259+
)
22572260
) from e
22582261

22592262

torch/_dynamo/variables/functions.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from .. import config, graph_break_hints, polyfills, variables
4343
from ..bytecode_transformation import create_call_function, create_rot_n, is_generator
4444
from ..exc import (
45+
format_skip_frame_message,
4546
get_dynamo_observed_exception,
4647
handle_observed_exception,
4748
InfiniteGeneratorError,
@@ -1652,8 +1653,13 @@ def call_function(
16521653
skip_frame_msg = kwargs.get("msg")
16531654
if skip_frame_msg:
16541655
skip_frame_msg = skip_frame_msg.as_python_constant()
1656+
else:
1657+
skip_frame_msg = ""
16551658
raise SkipFrame(
1656-
f"Skip frame due to `torch._dynamo.skip_frame()`. Message: {skip_frame_msg}"
1659+
format_skip_frame_message(
1660+
tx.f_code,
1661+
f"Skip frame due to `torch._dynamo.skip_frame()`. Message: {skip_frame_msg}",
1662+
)
16571663
)
16581664
elif self.value is torch._dynamo.step_unsupported:
16591665
raise StepUnsupported

torch/_logging/_internal.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -891,10 +891,14 @@ def format(self, record):
891891
# exception handling - copied from logging.Formatter.format
892892
s = record.message
893893
if record.exc_info:
894+
from torch._dynamo import config
895+
896+
should_format_exc = config.verbose or artifact_name != "graph_breaks"
894897
# Cache the traceback text to avoid converting it multiple times
895898
# (it's constant anyway)
896-
if not record.exc_text:
897-
record.exc_text = self.formatException(record.exc_info)
899+
if should_format_exc:
900+
if not record.exc_text:
901+
record.exc_text = self.formatException(record.exc_info)
898902
if record.exc_text:
899903
if s[-1:] != "\n":
900904
s = s + "\n"

0 commit comments

Comments
 (0)