diff --git a/test/xpu/test_nn_xpu.py b/test/xpu/test_nn_xpu.py index 4ff4bcef2..3e20e7eab 100644 --- a/test/xpu/test_nn_xpu.py +++ b/test/xpu/test_nn_xpu.py @@ -3132,8 +3132,17 @@ def perm_fn(x): # will result in nan. mask = torch.tensor([[1]], device=device) == 1 result = model(encoder_input, src_key_padding_mask=mask) + fast_path_device = result.is_xpu or result.is_cpu result = result.cpu().detach().numpy() - self.assertTrue(np.isnan(result).all()) + # Non Fast Paths + if training or not batch_first or TEST_WITH_CROSSREF or not fast_path_device: + # We changed the semenatic, on the non fast path so that fully masked out rows return + # 0 from attention thus NaNs should no longer be present and the output should be nonzero + # due to skip connections + self.assertTrue(not np.isnan(result).any()) + else: + # Fast Paths + self.assertTrue(np.isnan(result).all()) # deterministic input encoder_input = perm_fn(