@@ -1663,75 +1663,154 @@ def __init__(
16631663 super ().__init__ (np .median , ksize = ksize , ** kwargs )
16641664
16651665
1666- #TODO ***MG*** revise Resize - torch, typing, docstring, unit test
class Resize(Feature):
    """Resize an image to a specified size.

    `Resize` resizes an image using:
    - OpenCV (`cv2.resize`) for NumPy arrays.
    - PyTorch (`torch.nn.functional.interpolate`) for PyTorch tensors.

    The interpretation of the `dsize` parameter follows the convention
    of the underlying backend:
    - **NumPy (OpenCV)**: `dsize` is given as `(width, height)` to match
      OpenCV's default.
    - **PyTorch**: `dsize` is given as `(height, width)`.

    Parameters
    ----------
    dsize: PropertyLike[tuple[int, int]]
        The target size. Format depends on backend: `(width, height)` for
        NumPy, `(height, width)` for PyTorch.
    **kwargs: Any
        Additional parameters sent to the underlying resize function:
        - NumPy: the keyword arguments accepted by `cv2.resize`.
        - PyTorch: `mode`, `align_corners`, and `antialias` are forwarded
          to `torch.nn.functional.interpolate`; other entries are ignored.

    Methods
    -------
    get(
        image: np.ndarray | torch.Tensor, dsize: tuple[int, int], **kwargs
    ) -> np.ndarray | torch.Tensor
        Resize the input image to the specified size.

    Examples
    --------
    >>> import deeptrack as dt

    NumPy example:

    >>> import numpy as np
    >>>
    >>> input_image = np.random.rand(16, 16)  # Create image
    >>> feature = dt.math.Resize(dsize=(8, 4))  # (width=8, height=4)
    >>> resized_image = feature.resolve(input_image)  # Resize it to (4, 8)
    >>> print(resized_image.shape)
    (4, 8)

    PyTorch example:

    >>> import torch
    >>>
    >>> input_image = torch.rand(1, 1, 16, 16)  # Create image
    >>> feature = dt.math.Resize(dsize=(4, 8))  # (height=4, width=8)
    >>> resized_image = feature.resolve(input_image)  # Resize it to (4, 8)
    >>> print(resized_image.shape)
    torch.Size([1, 1, 4, 8])

    """

    def __init__(
        self: Resize,
        dsize: PropertyLike[tuple[int, int]] = (256, 256),
        **kwargs: Any,
    ):
        """Initialize the parameters for the Resize feature.

        Parameters
        ----------
        dsize: PropertyLike[tuple[int, int]]
            The target size. Format depends on backend: `(width, height)` for
            NumPy, `(height, width)` for PyTorch. Default is (256, 256).
        **kwargs: Any
            Additional arguments passed to the parent `Feature` class.

        """

        super().__init__(dsize=dsize, **kwargs)

    def get(
        self: Resize,
        image: NDArray | torch.Tensor,
        dsize: tuple[int, int],
        **kwargs: Any,
    ) -> NDArray | torch.Tensor:
        """Resize the input image to the specified size.

        Parameters
        ----------
        image: np.ndarray or torch.Tensor
            The input image to resize.
            - NumPy arrays may be grayscale (H, W) or color (H, W, C).
            - Torch tensors are expected in one of the following formats:
              (N, C, H, W), (C, H, W), or (H, W).
        dsize: tuple[int, int]
            Desired output size of the image.
            - NumPy: (width, height)
            - PyTorch: (height, width)
        **kwargs: Any
            Additional keyword arguments for the underlying resize function.
            - NumPy: filtered through `utils.safe_call` and handed to
              `cv2.resize`.
            - PyTorch: only `mode` (default "bilinear"), `align_corners`
              (default False, used only for interpolating modes), and
              `antialias` are forwarded to
              `torch.nn.functional.interpolate`.

        Returns
        -------
        np.ndarray or torch.Tensor
            The resized image in the same type and dimensionality format as
            input.

        Raises
        ------
        ValueError
            If a torch tensor does not have 2, 3, or 4 dimensions.

        Notes
        -----
        - For PyTorch tensors, resizing defaults to bilinear interpolation
          with `align_corners=False`. This choice matches OpenCV's
          `cv2.resize` default behavior when resizing NumPy arrays, aiming to
          produce nearly identical results between both backends.

        """

        if self._wrap_array_with_image:
            image = strip(image)

        if not apc.is_torch_array(image):
            # NumPy backend. OpenCV is imported lazily so that the torch
            # path does not require it to be installed.
            import cv2

            return utils.safe_call(
                cv2.resize, positional_args=[image, dsize], **kwargs
            )

        # PyTorch backend.
        original_ndim = image.ndim
        if original_ndim not in (2, 3, 4):
            raise ValueError(
                "Resize only supports tensors with shape (N, C, H, W), "
                "(C, H, W), or (H, W)."
            )

        # interpolate expects (N, C, H, W): prepend the missing singleton
        # axes and remember how many were added so they can be removed.
        added_axes = 4 - original_ndim
        for _ in range(added_axes):
            image = image.unsqueeze(0)

        # Forward only the options interpolate understands. kwargs also
        # carries unrelated feature properties, so it cannot be splatted.
        mode = kwargs.get("mode", "bilinear")
        options = {"mode": mode}
        if mode in ("linear", "bilinear", "bicubic", "trilinear"):
            # align_corners is rejected by torch for non-interpolating modes
            # such as "nearest" or "area", so only set it when valid.
            options["align_corners"] = kwargs.get("align_corners", False)
        if "antialias" in kwargs:
            options["antialias"] = kwargs["antialias"]

        resized = torch.nn.functional.interpolate(
            image,
            size=dsize,
            **options,
        )

        # Restore original dimensionality.
        for _ in range(added_axes):
            resized = resized.squeeze(0)

        return resized
17351814
17361815
17371816if OPENCV_AVAILABLE :
0 commit comments