
Commit a1ffaf9

Update math/resize (#406)
* update math/resize
* formatting
* update unittesting
* update resize
* update test resize
* update test resize
* update resize
* update resize
* update resize
1 parent 8981c62 commit a1ffaf9

File tree

2 files changed: 142 additions & 32 deletions


deeptrack/math.py

Lines changed: 109 additions & 30 deletions
@@ -1663,75 +1663,154 @@ def __init__(
         super().__init__(np.median, ksize=ksize, **kwargs)
 
 
-#TODO ***MG*** revise Resize - torch, typing, docstring, unit test
 class Resize(Feature):
     """Resize an image to a specified size.
 
-    This class is a wrapper around cv2.resize and resizes an image to a
-    specified size. The `dsize` parameter specifies the desired output size of
-    the image.
-    Note that the order of the axes is different in cv2 and numpy. In cv2, the
-    first axis is the vertical axis, while in numpy it is the horizontal axis.
-    This is reflected in the default values of the arguments.
+    `Resize` resizes an image using:
+    - OpenCV (`cv2.resize`) for NumPy arrays.
+    - PyTorch (`torch.nn.functional.interpolate`) for PyTorch tensors.
+
+    The interpretation of the `dsize` parameter follows the convention
+    of the underlying backend:
+    - **NumPy (OpenCV)**: `dsize` is given as `(width, height)` to match
+      OpenCV’s default.
+    - **PyTorch**: `dsize` is given as `(height, width)`.
 
     Parameters
     ----------
-    dsize: tuple
-        Size to resize to.
+    dsize: PropertyLike[tuple[int, int]]
+        The target size. Format depends on backend: `(width, height)` for
+        NumPy, `(height, width)` for PyTorch.
     **kwargs: Any
-        Additional parameters sent to the resizing function.
+        Additional parameters sent to the underlying resize function:
+        - NumPy: passed to `cv2.resize`.
+        - PyTorch: passed to `torch.nn.functional.interpolate`.
+
+    Methods
+    -------
+    get(
+        image: np.ndarray | torch.Tensor, dsize: tuple[int, int], **kwargs
+    ) -> np.ndarray | torch.Tensor
+        Resize the input image to the specified size.
+
+    Examples
+    --------
+    >>> import deeptrack as dt
+
+    Numpy example:
+    >>> import numpy as np
+    >>>
+    >>> input_image = np.random.rand(16, 16)  # Create image
+    >>> feature = dt.math.Resize(dsize=(8, 4))  # (width=8, height=4)
+    >>> resized_image = feature.resolve(input_image)  # Resize it to (4, 8)
+    >>> print(resized_image.shape)
+    (4, 8)
+
+    PyTorch example:
+    >>> import torch
+    >>>
+    >>> input_image = torch.rand(1, 1, 16, 16)  # Create image
+    >>> feature = dt.math.Resize(dsize=(4, 8))  # (height=4, width=8)
+    >>> resized_image = feature.resolve(input_image)  # Resize it to (4, 8)
+    >>> print(resized_image.shape)
+    torch.Size([1, 1, 4, 8])
 
     """
 
     def __init__(
         self: Resize,
-        dsize: PropertyLike[tuple] = (256, 256),
+        dsize: PropertyLike[tuple[int, int]] = (256, 256),
         **kwargs: Any,
     ):
-        """Initialize the parameters for resizing input features.
-
-        This constructor initializes the parameters for resizing input
-        features.
+        """Initialize the parameters for the Resize feature.
 
         Parameters
         ----------
-        dsize: tuple
-            Size to resize to.
+        dsize: PropertyLike[tuple[int, int]]
+            The target size. Format depends on backend: `(width, height)` for
+            NumPy, `(height, width)` for PyTorch. Default is (256, 256).
         **kwargs: Any
-            Additional keyword arguments.
+            Additional arguments passed to the parent `Feature` class.
 
         """
 
         super().__init__(dsize=dsize, **kwargs)
 
-    def get(self: Resize, image: np.ndarray, dsize: tuple, **kwargs: Any) -> np.ndarray:
+    def get(
+        self: Resize,
+        image: NDArray | torch.Tensor,
+        dsize: tuple[int, int],
+        **kwargs: Any,
+    ) -> NDArray | torch.Tensor:
         """Resize the input image to the specified size.
 
-        This method resizes the input image to the specified size.
-
         Parameters
         ----------
-        image: np.ndarray
+        image: np.ndarray or torch.Tensor
             The input image to resize.
-        dsize: tuple
+            - NumPy arrays may be grayscale (H, W) or color (H, W, C).
+            - Torch tensors are expected in one of the following formats:
+              (N, C, H, W), (C, H, W), or (H, W).
+        dsize: tuple[int, int]
             Desired output size of the image.
+            - NumPy: (width, height)
+            - PyTorch: (height, width)
         **kwargs: Any
-            Additional keyword arguments.
+            Additional keyword arguments passed to the underlying resize
+            function (`cv2.resize` or `torch.nn.functional.interpolate`).
 
         Returns
        -------
-        np.ndarray
-            The resized image.
+        np.ndarray or torch.Tensor
+            The resized image in the same type and dimensionality format as
+            input.
 
-        """
+        Notes
+        -----
+        - For PyTorch tensors, resizing uses bilinear interpolation with
+          `align_corners=False`. This choice matches OpenCV’s `cv2.resize`
+          default behavior when resizing NumPy arrays, aiming to produce nearly
+          identical results between both backends.
 
-        import cv2
-        from deeptrack import config
+        """
 
         if self._wrap_array_with_image:
             image = strip(image)
 
-        return utils.safe_call(cv2.resize, positional_args=[image, dsize], **kwargs)
+        if apc.is_torch_array(image):
+            original_shape = image.shape
+
+            # Reshape input to (N, C, H, W)
+            if image.ndim == 2:  # (H, W)
+                image = image.unsqueeze(0).unsqueeze(0)
+            elif image.ndim == 3:  # (C, H, W)
+                image = image.unsqueeze(0)
+            elif image.ndim != 4:
+                raise ValueError(
+                    "Resize only supports tensors with shape (N, C, H, W), "
+                    "(C, H, W), or (H, W)."
+                )
+
+            resized = torch.nn.functional.interpolate(
+                image,
+                size=dsize,
+                mode="bilinear",
+                align_corners=False,
+            )
+
+            # Restore original dimensionality
+            if len(original_shape) == 2:
+                resized = resized.squeeze(0).squeeze(0)
+            elif len(original_shape) == 3:
+                resized = resized.squeeze(0)
+
+            return resized
+
+        else:
+            import cv2
+            return utils.safe_call(
+                cv2.resize, positional_args=[image, dsize], **kwargs
+            )
 
 
 if OPENCV_AVAILABLE:
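The updated docstring ties the `dsize` convention to the backend: `(width, height)` for OpenCV, `(height, width)` for PyTorch. The following is a minimal standalone sketch (not part of this commit) that produces the same (4, 8) output from both backends, assuming `cv2`, `numpy`, and `torch` are installed:

import cv2
import numpy as np
import torch

image_np = np.random.rand(16, 16).astype(np.float32)

# OpenCV: dsize = (width=8, height=4) -> output array of shape (4, 8).
resized_np = cv2.resize(image_np, (8, 4))
print(resized_np.shape)  # (4, 8)

# PyTorch: size = (height=4, width=8) on an (N, C, H, W) tensor,
# mirroring the reshape Resize.get performs for 2D inputs.
image_t = torch.from_numpy(image_np)[None, None]  # (1, 1, 16, 16)
resized_t = torch.nn.functional.interpolate(
    image_t, size=(4, 8), mode="bilinear", align_corners=False
)
print(resized_t.shape)  # torch.Size([1, 1, 4, 8])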

deeptrack/tests/test_math.py

Lines changed: 33 additions & 2 deletions
@@ -157,9 +157,40 @@ def test_MedianPooling(self):
     @unittest.skipUnless(OPENCV_AVAILABLE, "OpenCV is not installed.")
     def test_Resize(self):
         input_image = np.random.rand(16, 16)
-        feature = math.Resize(dsize=(8, 8))
+        feature = math.Resize(dsize=(8, 4))
         resized = feature.resolve(input_image)
-        self.assertEqual(resized.shape, (8, 8))
+
+        self.assertIsInstance(resized, np.ndarray)
+        self.assertEqual(resized.shape, (4, 8))
+
+    @unittest.skipUnless(TORCH_AVAILABLE, "PyTorch is not installed.")
+    def test_Resize_torch(self):
+
+        feature = math.Resize(dsize=(4, 8))
+
+        input_image = torch.rand(16, 16)
+        resized = feature.resolve(input_image)
+        self.assertIsInstance(resized, torch.Tensor)
+        self.assertEqual(tuple(resized.shape), (4, 8))
+
+        if OPENCV_AVAILABLE:
+            # Compare with NumPy version:
+            feature_np = math.Resize(dsize=(8, 4))
+            input_image_np = input_image.numpy()
+            resized_np = feature_np.resolve(input_image_np)
+            np.testing.assert_allclose(
+                resized_np, resized.numpy(), rtol=1e-5, atol=1e-5
+            )
+
+        input_image = torch.rand(3, 16, 16)
+        resized = feature.resolve(input_image)
+        self.assertIsInstance(resized, torch.Tensor)
+        self.assertEqual(tuple(resized.shape), (3, 4, 8))
+
+        input_image = torch.rand(1, 1, 16, 16)
+        resized = feature.resolve(input_image)
+        self.assertIsInstance(resized, torch.Tensor)
+        self.assertEqual(tuple(resized.shape), (1, 1, 4, 8))
 
     @unittest.skipUnless(OPENCV_AVAILABLE, "OpenCV is not installed.")
     def test_BlurCV2_GaussianBlur(self):
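The new `test_Resize_torch` compares the two backends through the feature itself. The equivalence claimed in the `Notes` section of `Resize.get` (bilinear interpolation with `align_corners=False` tracking `cv2.resize`'s default interpolation) can also be sketched directly against the raw libraries; this is an illustrative check under the same tolerances as the new test, not part of the commit:

import cv2
import numpy as np
import torch

image_t = torch.rand(16, 16)
image_np = image_t.numpy()

# OpenCV default interpolation, dsize = (width=8, height=4) -> shape (4, 8).
resized_cv = cv2.resize(image_np, (8, 4))

# PyTorch bilinear resize with align_corners=False, size = (height=4, width=8).
resized_pt = torch.nn.functional.interpolate(
    image_t[None, None], size=(4, 8), mode="bilinear", align_corners=False
)[0, 0]

# Same tolerances used by test_Resize_torch.
np.testing.assert_allclose(resized_cv, resized_pt.numpy(), rtol=1e-5, atol=1e-5)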
