diff --git a/tests/fixtures/misc/checker/rsqrt.py b/tests/fixtures/misc/checker/rsqrt.py new file mode 100644 index 0000000..d596e49 --- /dev/null +++ b/tests/fixtures/misc/checker/rsqrt.py @@ -0,0 +1,10 @@ +import torch + + +a = torch.randn(5) +b = 1 / torch.sqrt(a) +b = 1.0 / torch.sqrt(a) +b = a / torch.sqrt(a) +# False negative +b = 1 / a.sqrt() +b = 1.0 / a.sqrt() diff --git a/tests/fixtures/misc/checker/rsqrt.txt b/tests/fixtures/misc/checker/rsqrt.txt new file mode 100644 index 0000000..1e480be --- /dev/null +++ b/tests/fixtures/misc/checker/rsqrt.txt @@ -0,0 +1,3 @@ +5:5 TOR109 Consider faster `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`. +6:5 TOR109 Consider faster `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`. +7:5 TOR109 Consider faster `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`. diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index 5baa12a..b008ecd 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -48,6 +48,7 @@ def pytest_generate_tests(metafunc): "TOR106", "TOR107", "TOR108", + "TOR109", }, ), (None, set(GET_ALL_ERROR_CODES()) - exclude_set), diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 5e96e38..3968213 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -18,6 +18,7 @@ TorchScopedLibraryVisitor, TorchSynchronizedDataLoaderVisitor, TorchUnsafeLoadVisitor, + TorchRsqrtVisitor, TorchVisionDeprecatedPretrainedVisitor, TorchVisionDeprecatedToTensorVisitor, TorchVisionSingletonImportVisitor, @@ -35,6 +36,7 @@ TorchExpm1Visitor, TorchLog1pVisitor, TorchLogsumexpVisitor, + TorchRsqrtVisitor, TorchNonPublicAliasVisitor, TorchRequireGradVisitor, TorchReentrantCheckpointVisitor, diff --git a/torchfix/visitors/__init__.py b/torchfix/visitors/__init__.py index 45f2438..0a488f6 100644 --- a/torchfix/visitors/__init__.py +++ b/torchfix/visitors/__init__.py @@ -6,6 +6,7 @@ TorchLogsumexpVisitor, TorchReentrantCheckpointVisitor, TorchRequireGradVisitor, + TorchRsqrtVisitor, ) from .nonpublic import TorchNonPublicAliasVisitor from .performance import ( @@ -30,6 +31,7 @@ "TorchScopedLibraryVisitor", "TorchSynchronizedDataLoaderVisitor", "TorchUnsafeLoadVisitor", + "TorchRsqrtVisitor", "TorchVisionDeprecatedPretrainedVisitor", "TorchVisionDeprecatedToTensorVisitor", "TorchVisionSingletonImportVisitor", diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index 8f0c70c..e5dafcb 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -184,7 +184,6 @@ def visit_Call(self, node): ) == "torch.exp" ): - # if `dim` is not provided or None for sum, skip: # https://github.com/pytorch/pytorch/issues/144339 dim_arg = self.get_specific_arg( @@ -201,3 +200,29 @@ def visit_Call(self, node): message=self.ERRORS[0].message(), replacement=None, ) + + +class TorchRsqrtVisitor(TorchVisitor): + """ + Suggest using `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`. + """ + + ERRORS = [ + TorchError( + "TOR109", + ("Consider faster `a*torch.rsqrt(b)` instead of `a/torch.sqrt(b)`."), + ) + ] + + def visit_BinaryOperation(self, node): + if m.matches( + node, + m.BinaryOperation(operator=m.Divide(), right=m.Call()), + ): + if self.get_qualified_name_for_call(node.right) == "torch.sqrt": + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + replacement=None, + )