Skip to content

Commit 7371b2e

Browse files
strenghen condition for round trip test
1 parent 54584dd commit 7371b2e

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

test/test_transforms_v2.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6801,13 +6801,15 @@ def test_round_trip(self, dtype):
68016801
original_tensor = torch.rand(3, 4, 4, dtype=dtype, device="cuda")
68026802

68036803
# Execute: Convert to NVCV and back to tensor
6804-
# CHW -> (to_nvcv_tensor) -> NVCV NHWC -> (nvcv_to_tensor) -> CHW
6804+
# CHW -> (to_nvcv_tensor) -> NVCV NHWC -> (nvcv_to_tensor) -> NCHW
68056805
nvcv_tensor = F.to_nvcv_tensor(original_tensor)
68066806
result_tensor = F.nvcv_to_tensor(nvcv_tensor)
68076807

6808-
# Assert: The round-trip conversion preserves the original tensor
6809-
# Use allclose for robust comparison that handles floating-point precision
6810-
assert torch.allclose(result_tensor, original_tensor, rtol=1e-5, atol=1e-7)
6808+
# Remove batch dimension that was added during conversion since original was unbatched
6809+
result_tensor = result_tensor.squeeze(0)
6810+
6811+
# Assert: The round-trip conversion preserves the original tensor exactly
6812+
torch.testing.assert_close(result_tensor, original_tensor, rtol=0, atol=0)
68116813

68126814
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
68136815
@pytest.mark.parametrize("batch_size", [1, 2, 4])
@@ -6823,9 +6825,8 @@ def test_round_trip_batched(self, dtype, batch_size):
68236825
nvcv_tensor = F.to_nvcv_tensor(original_tensor)
68246826
result_tensor = F.nvcv_to_tensor(nvcv_tensor)
68256827

6826-
# Assert: The round-trip conversion preserves the original batched tensor
6827-
# Use allclose for robust comparison that handles floating-point precision
6828-
assert torch.allclose(result_tensor, original_tensor, rtol=1e-5, atol=1e-7)
6828+
# Assert: The round-trip conversion preserves the original batched tensor exactly
6829+
torch.testing.assert_close(result_tensor, original_tensor, rtol=0, atol=0)
68296830
# Also verify batch size is preserved
68306831
assert result_tensor.shape[0] == batch_size
68316832

0 commit comments

Comments
 (0)