diff --git a/deeptrack/math.py b/deeptrack/math.py index eb9be8040..ff43577a0 100644 --- a/deeptrack/math.py +++ b/deeptrack/math.py @@ -1069,8 +1069,7 @@ def __init__( super().__init__(ndimage.median_filter, size=ksize, **kwargs) -#TODO ***AL*** revise Pool - torch, typing, docstring, unit test -class Pool(Feature): +class Pool(Feature): # Deprecated, children will be independent in the future. """Downsamples the image by applying a function to local regions of the image. @@ -1078,15 +1077,15 @@ class Pool(Feature): non-overlapping blocks of size `ksize` and applying the specified pooling function to each block. The result is a downsampled image where each pixel value represents the result of the pooling function applied to the - corresponding block. + corresponding block. This pooling only works with numpy functions. Parameters ---------- - pooling_function: function + pooling_function: Numpy function A function that is applied to each local region of the image. DOES NOT NEED TO BE WRAPPED IN ANOTHER FUNCTION. - The `pooling_function` must accept the input image as a keyword argument - named `input`, as it is called via `utils.safe_call`. + The `pooling_function` must accept the input image as a keyword + argument named `input`, as it is called via `utils.safe_call`. Examples include `np.mean`, `np.max`, `np.min`, etc. ksize: int Size of the pooling kernel. @@ -1095,7 +1094,8 @@ class Pool(Feature): Methods ------- - `get(image: np.ndarray | Image, ksize: int, **kwargs: Any) --> np.ndarray` + `get(image: NDArray, + ksize: int, **kwargs: Any) --> NDArray` Applies the pooling function to the input image. Examples @@ -1152,17 +1152,17 @@ def __init__( def get( self: Pool, - image: np.ndarray | Image, + image: NDArray, ksize: int, **kwargs: Any, - ) -> np.ndarray: + ) -> NDArray: """Applies the pooling function to the input image. - This method applies the pooling function to the input image. + This method applies `pooling_function` to the input image. Parameters ---------- - image: np.ndarray + image: NDArray | torch.Tensor The input image to pool. ksize: int Size of the pooling kernel. @@ -1171,7 +1171,7 @@ def get( Returns ------- - np.ndarray + NDArray | torch.Tensor The pooled image. """ @@ -1188,54 +1188,54 @@ def get( ) -#TODO ***AL*** revise AveragePooling - torch, typing, docstring, unit test class AveragePooling(Pool): """Apply average pooling to an image. - This class reduces the resolution of an image by dividing it into - non-overlapping blocks of size `ksize` and applying the average function to - each block. The result is a downsampled image where each pixel value + `AveragePooling` reduces the resolution of an image by dividing it into + non-overlapping blocks of size `ksize` and applying the `average` function + to each block. The result is a downsampled image where each pixel value represents the average value within the corresponding block of the - original image. + original image. This is useful for reducing the size of an image while + retaining the most significant features. + + If the backend is NumPy, the downsampling is performed using + `skimage.measure.block_reduce`. + If the backend is PyTorch, the downsampling + is performed using `torch.nn.functional.avg_pool2d`. Parameters ---------- ksize: int Size of the pooling kernel. - **kwargs: dict + **kwargs: Any Additional parameters sent to the pooling function. Examples -------- >>> import deeptrack as dt - >>> import numpy as np Create an input image: + >>> import numpy as np + >>> >>> input_image = np.random.rand(32, 32) - Define an average pooling feature: + Define and use a average-pooling feature: + >>> average_pooling = dt.AveragePooling(ksize=4) >>> output_image = average_pooling(input_image) >>> print(output_image.shape) (8, 8) - Notes - ----- - Calling this feature returns a `np.ndarray` by default. If - `store_properties` is set to `True`, the returned array will be - automatically wrapped in an `Image` object. This behavior is handled - internally and does not affect the return type of the `get()` method. - """ def __init__( - self: Pool, + self: AveragePooling, ksize: PropertyLike[int] = 3, **kwargs: Any, ): """Initialize the parameters for average pooling. - This constructor initializes the parameters for average pooling. + This constructor initializes the parameters for average-pooling. Parameters ---------- @@ -1248,6 +1248,114 @@ def __init__( super().__init__(np.mean, ksize=ksize, **kwargs) + def get( + self: AveragePooling, + image: NDArray[Any] | torch.Tensor, + ksize: int=3, + **kwargs: Any, + ) -> NDArray[Any] | torch.Tensor: + """Average pooling of input. + + Checks the current backend and chooses the appropriate function to pool + the input image, either `._get_torch()` or `._get_numpy()`. + + Parameters + ---------- + image: array or tensor + Input array or tensor to be pooled. + ksize: int + Kernel size of the pooling operation. + + Returns + ------- + array or tensor + The pooled input as `NDArray` or `torch.Tensor` depending on + the backend. + + """ + + if self.get_backend() == "numpy": + return self._get_numpy(image, ksize, **kwargs) + + if self.get_backend() == "torch": + return self._get_torch(image, ksize, **kwargs) + + raise NotImplementedError(f"Backend {self.backend} not supported") + + def _get_numpy( + self: AveragePooling, + image: NDArray[Any], + ksize: int = 3, + **kwargs: Any, + ) -> NDArray[Any]: + """Average pooling with the NumPy backend enabled. + + Returns the result of the image passed to the scikit image + `block_reduce()` function with `np.mean()` as the pooling function. + + Parameters + ---------- + image: NDArray + Input array to be pooled. + ksize: int + Kernel size of the pooling operation. + + Returns + ------- + array + The pooled image as a NumPy array. + + """ + + return utils.safe_call( + skimage.measure.block_reduce, + image=image, + func=np.average, + block_size=ksize, + **kwargs, + ) + + def _get_torch( + self: AveragePooling, + image: torch.Tensor, + ksize: int=3, + **kwargs: Any, + ) -> torch.Tensor: + """Average pooling with the PyTorch backend enabled. + + Returns the result of the image passed to a Pytorch average pooling + layer. + + Parameters + ---------- + image: torch.Tensor + Input tensor to be pooled. + ksize: int + Kernel size of the pooling operation. + + Returns + ------- + torch.Tensor + The pooled image as a `torch.Tensor`. + + """ + + # If input tensor is 2D + if len(image.shape) == 2: + # Add batch dimension for max pooling. + expanded_image = image.unsqueeze(0) + + pooled_image = torch.nn.functional.avg_pool2d( + expanded_image, kernel_size=ksize, + ) + # Remove the expanded dim. + return pooled_image.squeeze(0) + + return torch.nn.functional.avg_pool2d( + image, + kernel_size=ksize, + ) + class MaxPooling(Pool): """Apply max-pooling to images. diff --git a/deeptrack/tests/test_math.py b/deeptrack/tests/test_math.py index 09799e292..227f56bb2 100644 --- a/deeptrack/tests/test_math.py +++ b/deeptrack/tests/test_math.py @@ -83,6 +83,17 @@ def test_Blur(self): #self.assertTrue(xp.all(blurred_image == expected_output)) + + def test_AveragePooling(self): + input_image = xp.asarray([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=float) + feature = math.AveragePooling(ksize=2) + pooled_image = feature.resolve(input_image) + + expected = xp.asarray([[3.5, 5.5]]) + + self.assertTrue(xp.all(pooled_image == expected)) + self.assertEqual(pooled_image.shape, (1, 2)) + def test_MaxPooling(self): input_image = xp.asarray([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=float) feature = math.MaxPooling(ksize=2) @@ -109,8 +120,7 @@ def test_MinPooling(self): class TestMath_Torch(TestMath_Numpy): BACKEND = "torch" pass - - + class TestMath(unittest.TestCase): def test_GaussianBlur(self): @@ -130,6 +140,7 @@ def test_AveragePooling(self): pooled_image = feature.resolve(input_image) self.assertTrue(np.all(pooled_image == [[3.5, 5.5]])) + def test_MaxPooling(self): input_image = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) feature = math.MaxPooling(ksize=2)