File tree Expand file tree Collapse file tree 6 files changed +44
-1
lines changed Expand file tree Collapse file tree 6 files changed +44
-1
lines changed Original file line number Diff line number Diff line change 1+ import torch
2+
3+
4+ a = torch .randn (5 )
5+ b = 1 / torch .sqrt (a )
6+ b = 1.0 / torch .sqrt (a )
7+ b = a / torch .sqrt (a )
8+ # False negative
9+ b = 1 / a .sqrt ()
10+ b = 1.0 / a .sqrt ()
Original file line number Diff line number Diff line change 1+ 5:5 TOR109 Consider using `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`, which is faster.
2+ 6:5 TOR109 Consider using `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`, which is faster.
3+ 7:5 TOR109 Consider using `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`, which is faster.
Original file line number Diff line number Diff line change @@ -48,6 +48,7 @@ def pytest_generate_tests(metafunc):
4848 "TOR106" ,
4949 "TOR107" ,
5050 "TOR108" ,
51+ "TOR109" ,
5152 },
5253 ),
5354 (None , set (GET_ALL_ERROR_CODES ()) - exclude_set ),
Original file line number Diff line number Diff line change 1818 TorchScopedLibraryVisitor ,
1919 TorchSynchronizedDataLoaderVisitor ,
2020 TorchUnsafeLoadVisitor ,
21+ TorchRsqrtVisitor ,
2122 TorchVisionDeprecatedPretrainedVisitor ,
2223 TorchVisionDeprecatedToTensorVisitor ,
2324 TorchVisionSingletonImportVisitor ,
3536 TorchExpm1Visitor ,
3637 TorchLog1pVisitor ,
3738 TorchLogsumexpVisitor ,
39+ TorchRsqrtVisitor ,
3840 TorchNonPublicAliasVisitor ,
3941 TorchRequireGradVisitor ,
4042 TorchReentrantCheckpointVisitor ,
Original file line number Diff line number Diff line change 66 TorchLogsumexpVisitor ,
77 TorchReentrantCheckpointVisitor ,
88 TorchRequireGradVisitor ,
9+ TorchRsqrtVisitor ,
910)
1011from .nonpublic import TorchNonPublicAliasVisitor
1112from .performance import (
3031 "TorchScopedLibraryVisitor" ,
3132 "TorchSynchronizedDataLoaderVisitor" ,
3233 "TorchUnsafeLoadVisitor" ,
34+ "TorchRsqrtVisitor" ,
3335 "TorchVisionDeprecatedPretrainedVisitor" ,
3436 "TorchVisionDeprecatedToTensorVisitor" ,
3537 "TorchVisionSingletonImportVisitor" ,
Original file line number Diff line number Diff line change @@ -184,7 +184,6 @@ def visit_Call(self, node):
184184 )
185185 == "torch.exp"
186186 ):
187-
188187 # if `dim` is not provided or None for sum, skip:
189188 # https://github.com/pytorch/pytorch/issues/144339
190189 dim_arg = self .get_specific_arg (
@@ -201,3 +200,29 @@ def visit_Call(self, node):
201200 message = self .ERRORS [0 ].message (),
202201 replacement = None ,
203202 )
203+
204+
205+ class TorchRsqrtVisitor (TorchVisitor ):
206+ """
207+ Suggest using `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`.
208+ """
209+
210+ ERRORS = [
211+ TorchError (
212+ "TOR109" ,
213+ ("Consider using `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`, which is faster." ),
214+ )
215+ ]
216+
217+ def visit_BinaryOperation (self , node ):
218+ if m .matches (
219+ node ,
220+ m .BinaryOperation (operator = m .Divide (), right = m .Call ()),
221+ ):
222+ if self .get_qualified_name_for_call (node .right ) == "torch.sqrt" :
223+ self .add_violation (
224+ node ,
225+ error_code = self .ERRORS [0 ].error_code ,
226+ message = self .ERRORS [0 ].message (),
227+ replacement = None ,
228+ )
You can’t perform that action at this time.
0 commit comments