@@ -456,6 +456,20 @@ def test_mask_along_axis_iid(self, mask_param, mask_value, axis, p):
         assert mask_specgrams.size() == specgrams.size()
         assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel()
 
+    @parameterized.expand(list(itertools.product([100], [0.0, 30.0], [2, 3], [0.2, 1.0])))
+    def test_mask_along_axis_iid_mask_value(self, mask_param, mask_value, axis, p):
+        specgrams = torch.randn(4, 2, 1025, 400, dtype=self.dtype, device=self.device)
+        mask_value_tensor = torch.tensor(mask_value, dtype=self.dtype, device=self.device)
+        torch.manual_seed(0)
+        # as this operation is random, we need to fix the seed for the results to match
+        mask_specgrams = F.mask_along_axis_iid(specgrams, mask_param, mask_value_tensor, axis, p=p)
+        torch.manual_seed(0)
+        mask_specgrams_float = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis, p=p)
+        assert torch.allclose(
+            mask_specgrams, mask_specgrams_float
+        ), f"""Masking with float and tensor should be the same, diff = {
+            torch.abs(mask_specgrams - mask_specgrams_float).max()}"""
+
     @parameterized.expand(list(itertools.product([(2, 1025, 400), (1, 201, 100)], [100], [0.0, 30.0], [1, 2])))
     def test_mask_along_axis_preserve(self, shape, mask_param, mask_value, axis):
         """mask_along_axis should not alter original input Tensor