Fix TransformerEncoderLayer Full Mask UT Failure on XPU #2336
+10
−1
The Problem Solved
The nn.TransformerEncoderLayer unit test failed on XPU devices when checking the expected output for a fully masked input (sequence length 1, mask [[True]]). See #2015.
Root Cause:
Because XPU is missing from the device whitelist of core PyTorch's fast path, XPU execution is forced onto the slow Python fallback path instead of the optimized C++ fused kernel. See pytorch/pytorch#168234.
Semantic Mismatch:
On the fallback path, the XPU implementation produces finite (non-NaN) values for fully masked inputs, reflecting the well-defined mathematical result of X+Attention(0).
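The difference between the two semantics can be illustrated with a toy model in plain Python. This is only a sketch, not the PyTorch kernels: `masked_softmax`, `residual_attention`, and the `zero_fully_masked` flag are hypothetical names, and it assumes that X+Attention(0) means a fully masked row contributes zero attention, so the residual connection passes the input through unchanged.

```python
import math

def masked_softmax(scores, mask, zero_fully_masked):
    """Softmax over scores, where mask[i] == True means position i is ignored.

    zero_fully_masked is a hypothetical switch for this sketch:
    True  -> a fully masked row yields all-zero weights (fallback-style semantics)
    False -> a fully masked row yields NaN weights (naive softmax over -inf)
    """
    if zero_fully_masked and all(mask):
        return [0.0] * len(scores)
    masked = [-math.inf if m else s for s, m in zip(scores, mask)]
    top = max(masked)
    # For a fully masked row, -inf - (-inf) is NaN, and NaN propagates from here.
    exps = [math.exp(s - top) for s in masked]
    total = sum(exps)
    return [e / total for e in exps]

def residual_attention(x, values, scores, mask, zero_fully_masked):
    """Scalar toy version of x + Attention(x) with a skip connection."""
    weights = masked_softmax(scores, mask, zero_fully_masked)
    attended = sum(w * v for w, v in zip(weights, values))
    return x + attended

# Sequence length 1 with mask [[True]], as in the failing test case.
finite_out = residual_attention(1.5, [2.0], [0.3], [True], zero_fully_masked=True)
nan_out = residual_attention(1.5, [2.0], [0.3], [True], zero_fully_masked=False)
```

Under these assumptions, `finite_out` equals the input `1.5` because attention adds nothing, while `nan_out` is NaN because the softmax normalizes over an empty set of positions.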
Assertion Error:
However, the unit test's internal logic incorrectly classified XPU execution as a fast-path run, for which it expects a NaN result. The actual finite result did not satisfy the NaN assertion, causing the test failure.
Fix Implementation
The internal UT logic now explicitly tracks the execution semantics that actually apply to the device (non-NaN for XPU). For XPU, the test's assertion therefore switches from the incorrect self.assertTrue(np.isnan(result).all()) to the correct self.assertTrue(not np.isnan(result).any()).
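As a rough illustration of the idea, the sketch below picks the assertion from the device type. It is not the code changed in this PR: `FAST_PATH_DEVICE_TYPES`, `expects_nan`, and `check_fully_masked_output` are hypothetical names, the device list is an assumption, and the real test may take further conditions into account. Plain `math.isnan` over a list stands in for `np.isnan` on an array so the example has no dependencies.

```python
import math

# Assumption for this sketch: device types treated as running the fused fast path.
# XPU is deliberately absent, so it gets the fallback (non-NaN) expectation.
FAST_PATH_DEVICE_TYPES = ("cpu", "cuda")

def expects_nan(device_type):
    """Return True when a fully masked input is expected to produce all-NaN output."""
    return device_type in FAST_PATH_DEVICE_TYPES

def check_fully_masked_output(values, device_type):
    """Apply the assertion that matches the execution path of the device."""
    nan_flags = [math.isnan(v) for v in values]
    if expects_nan(device_type):
        # Fast path: every element is expected to be NaN.
        assert all(nan_flags), "fast path: expected all-NaN output"
    else:
        # Fallback path: no element may be NaN.
        assert not any(nan_flags), "fallback path: expected no NaN in output"

# Finite output on XPU passes, and all-NaN output on a fast-path device passes.
check_fully_masked_output([0.5, -1.2], "xpu")
check_fully_masked_output([math.nan, math.nan], "cuda")
```

Keeping the path decision in one helper means the expectation follows the path the device actually runs, rather than being inferred separately inside the test body.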