Skip to content

Commit 50d7e87

Browse files
authored
[torchlib] Mark atan2 as trace_only and map NaN to 0 (#2557)
Fix pytorch/pytorch#162570 --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 710d597 commit 50d7e87

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -925,16 +925,21 @@ def aten_atan(self: TFloat) -> TFloat:
925925
return op.Atan(self)
926926

927927

928-
@torch_op("aten::atan2")
928+
@torch_op("aten::atan2", trace_only=True)
929929
def aten_atan2(self: TFloat, other: TFloat) -> TFloat:
930930
"""atan2(Tensor self, Tensor other) -> Tensor"""
931931

932932
# self is y, and other is x on coordinate
933933
slope = op.Div(self, other)
934934
atan = op.Atan(slope)
935+
zero = common_ops.constant(0.0, dtype=self.dtype)
936+
pi = common_ops.constant(_MATH_PI, dtype=self.dtype)
935937

936-
second_third_quadrant = op.Where(self > 0.0, atan + _MATH_PI, atan - _MATH_PI)
937-
result = op.Where(other < 0.0, second_third_quadrant, atan)
938+
second_third_quadrant = op.Where(op.Greater(self, zero), atan + pi, atan - pi)
939+
result = op.Where(op.Less(other, zero), second_third_quadrant, atan)
940+
941+
# Map NaN to 0 to match PyTorch behavior
942+
result = op.Where(op.IsNaN(result), zero, result)
938943

939944
return result
940945

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ def _where_input_wrangler(
578578
TorchLibOpInfo("asin", core_ops.aten_asin),
579579
TorchLibOpInfo("asinh", core_ops.aten_asinh),
580580
TorchLibOpInfo("atan", core_ops.aten_atan),
581-
TorchLibOpInfo("atan2", core_ops.aten_atan2, tolerance={torch.float16: (1e-3, 1e-3)}),
581+
TorchLibOpInfo("atan2", core_ops.aten_atan2),
582582
TorchLibOpInfo("atanh", core_ops.aten_atanh),
583583
TorchLibOpInfo("atleast_1d", core_ops.aten_atleast_1d).skip(
584584
matcher=lambda sample: isinstance(sample.input, (list, tuple)),

0 commit comments

Comments
 (0)