From d6eccaa50475cdcf06175c92b38419efc04258b6 Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Sun, 26 Oct 2025 19:38:51 +0530 Subject: [PATCH] Fix attention mask to use float_lowest instead of -inf and add unit test for softmax NaN case --- onnxscript/function_libs/torch_lib/ops/nn.py | 6 +++++- tests/common/testutils.py | 7 +++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 4f81cc7907..65bb2aa079 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -14,6 +14,7 @@ from __future__ import annotations +import numpy as np import math from typing import Optional, Sequence, Tuple, TypeVar, Union @@ -2048,6 +2049,9 @@ def _aten_scaled_dot_product_attention_no_mask_onnx( attn_weight, _ = op.Dropout(attn_weight, dropout_p) return op.MatMul(attn_weight, value) +def float_lowest(dtype): + """Returns the lowest representable value for the given numpy dtype.""" + return np.finfo(np.dtype(dtype)).min def _aten_scaled_dot_product_attention_bool_mask_onnx( query: TFloat, @@ -2078,7 +2082,7 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx( key_transposed_scaled = op.Mul(key_transposed, op.Sqrt(scale)) # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf')) zero = op.Constant(value=ir.tensor(0.0, dtype=query.dtype)) - neg_inf = op.Constant(value=ir.tensor(-float("inf"), dtype=query.dtype)) + neg_inf = op.Constant(value=ir.tensor(float_lowest(query.dtype)), dtype=query.dtype) attn_mask = op.Where(attn_mask, zero, neg_inf) attn_weight = op.Softmax( op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask), diff --git a/tests/common/testutils.py b/tests/common/testutils.py index 2a2697b240..1db673eab8 100644 --- a/tests/common/testutils.py +++ b/tests/common/testutils.py @@ -14,6 +14,7 @@ import torch from onnxscript import optimizer +from onnxscript.onnx_opset import opset18 as op from onnxscript.rewriter import onnxruntime as ort_rewriter from onnxscript.utils import evaluation_utils @@ -101,3 +102,9 @@ def test_onnxruntime_rewrite( f"Failed for model {model_name} and output {i} with rtol={rtol} and atol={atol}\n{e}" ) raise + +def test_softmax_with_all_inf_mask(): + # GH #2561 + input = np.array([[-float("inf"), -float("inf")]], dtype=np.float32) + output = op.Softmax(input, axis=-1) + assert np.isnan(output).all(), "Softmax should return NaN when all inputs are -inf"