Skip to content

Commit 28f1a5f

Browse files
authored
Add TorchLogsumexpVisitor (#89)
Suggest using `torch.logsumexp(x)` instead of `torch.log(torch.sum(torch.exp(x))`. https://pytorch.org/docs/stable/generated/torch.logsumexp.html
1 parent 4ff3caf commit 28f1a5f

File tree

6 files changed

+59
-2
lines changed

6 files changed

+59
-2
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import torch
2+
a = torch.randn(5)
3+
b = torch.randn(5)
4+
5+
# logsumexp
6+
y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True))
7+
y = torch.log(torch.sum(torch.exp(2.5 + x), 1))
8+
9+
# not logsumexp
10+
y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True) + 2.5)
11+
y = torch.log(torch.sum(torch.exp(x) + 2.5, 1))
12+
y = torch.log(2 + x)
13+
y = torch.sum(torch.log(torch.exp(x)), 1)
14+
y = torch.exp(torch.sum(torch.log(x), 1, keepdim=True))
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
6:5 TOR108 Use numerically stabilized `torch.logsumexp`.
2+
7:5 TOR108 Use numerically stabilized `torch.logsumexp`.

tests/test_torchfix.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def pytest_generate_tests(metafunc):
4747
"TOR105",
4848
"TOR106",
4949
"TOR107",
50+
"TOR108",
5051
},
5152
),
5253
(None, set(GET_ALL_ERROR_CODES()) - exclude_set),

torchfix/torchfix.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
TorchDeprecatedSymbolsVisitor,
1212
TorchExpm1Visitor,
1313
TorchLog1pVisitor,
14+
TorchLogsumexpVisitor,
1415
TorchNonPublicAliasVisitor,
1516
TorchReentrantCheckpointVisitor,
1617
TorchRequireGradVisitor,
@@ -32,6 +33,7 @@
3233
TorchDeprecatedSymbolsVisitor,
3334
TorchExpm1Visitor,
3435
TorchLog1pVisitor,
36+
TorchLogsumexpVisitor,
3537
TorchNonPublicAliasVisitor,
3638
TorchRequireGradVisitor,
3739
TorchReentrantCheckpointVisitor,

torchfix/visitors/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .misc import (
44
TorchExpm1Visitor,
55
TorchLog1pVisitor,
6+
TorchLogsumexpVisitor,
67
TorchReentrantCheckpointVisitor,
78
TorchRequireGradVisitor,
89
)
@@ -19,6 +20,7 @@
1920
"TorchDeprecatedSymbolsVisitor",
2021
"TorchExpm1Visitor",
2122
"TorchLog1pVisitor",
23+
"TorchLogsumexpVisitor",
2224
"TorchNonPublicAliasVisitor",
2325
"TorchReentrantCheckpointVisitor",
2426
"TorchRequireGradVisitor",

torchfix/visitors/misc/__init__.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,6 @@ class TorchLog1pVisitor(TorchVisitor):
9696

9797
def visit_Call(self, node):
9898
if self.get_qualified_name_for_call(node) == "torch.log":
99-
10099
if m.matches(
101100
node,
102101
m.Call(
@@ -114,7 +113,6 @@ def visit_Call(self, node):
114113
],
115114
),
116115
):
117-
118116
self.add_violation(
119117
node,
120118
error_code=self.ERRORS[0].error_code,
@@ -154,3 +152,41 @@ def visit_BinaryOperation(self, node):
154152
message=self.ERRORS[0].message(),
155153
replacement=None,
156154
)
155+
156+
157+
class TorchLogsumexpVisitor(TorchVisitor):
158+
"""
159+
Suggest using `torch.logsumexp(x)` instead of `torch.log(torch.sum(torch.exp(x))`.
160+
"""
161+
162+
ERRORS = [
163+
TorchError(
164+
"TOR108",
165+
("Use numerically stabilized `torch.logsumexp`."),
166+
)
167+
]
168+
169+
def visit_Call(self, node):
170+
if self.get_qualified_name_for_call(node) == "torch.log":
171+
if m.matches(
172+
node,
173+
m.Call(
174+
args=[
175+
m.Arg(m.Call(args=[m.Arg(m.Call()), m.ZeroOrMore()])),
176+
m.ZeroOrMore(),
177+
]
178+
),
179+
):
180+
if self.get_qualified_name_for_call(node.args[0].value) == "torch.sum":
181+
if (
182+
self.get_qualified_name_for_call(
183+
node.args[0].value.args[0].value
184+
)
185+
== "torch.exp"
186+
):
187+
self.add_violation(
188+
node,
189+
error_code=self.ERRORS[0].error_code,
190+
message=self.ERRORS[0].message(),
191+
replacement=None,
192+
)

0 commit comments

Comments
 (0)