@@ -96,7 +96,6 @@ class TorchLog1pVisitor(TorchVisitor):
9696
9797 def visit_Call (self , node ):
9898 if self .get_qualified_name_for_call (node ) == "torch.log" :
99-
10099 if m .matches (
101100 node ,
102101 m .Call (
@@ -114,7 +113,6 @@ def visit_Call(self, node):
114113 ],
115114 ),
116115 ):
117-
118116 self .add_violation (
119117 node ,
120118 error_code = self .ERRORS [0 ].error_code ,
@@ -154,3 +152,41 @@ def visit_BinaryOperation(self, node):
154152 message = self .ERRORS [0 ].message (),
155153 replacement = None ,
156154 )
155+
156+
157+ class TorchLogsumexpVisitor (TorchVisitor ):
158+ """
159+ Suggest using `torch.logsumexp(x)` instead of `torch.log(torch.sum(torch.exp(x))`.
160+ """
161+
162+ ERRORS = [
163+ TorchError (
164+ "TOR108" ,
165+ ("Use numerically stabilized `torch.logsumexp`." ),
166+ )
167+ ]
168+
169+ def visit_Call (self , node ):
170+ if self .get_qualified_name_for_call (node ) == "torch.log" :
171+ if m .matches (
172+ node ,
173+ m .Call (
174+ args = [
175+ m .Arg (m .Call (args = [m .Arg (m .Call ()), m .ZeroOrMore ()])),
176+ m .ZeroOrMore (),
177+ ]
178+ ),
179+ ):
180+ if self .get_qualified_name_for_call (node .args [0 ].value ) == "torch.sum" :
181+ if (
182+ self .get_qualified_name_for_call (
183+ node .args [0 ].value .args [0 ].value
184+ )
185+ == "torch.exp"
186+ ):
187+ self .add_violation (
188+ node ,
189+ error_code = self .ERRORS [0 ].error_code ,
190+ message = self .ERRORS [0 ].message (),
191+ replacement = None ,
192+ )
0 commit comments