Skip to content

Commit 2a89062

Browse files
Mg/features loadimage (#398)
* update Loadimage * update test_LoadImage * incorporated feedback from Alex * remove one test to avoid problems when unittesting on windows * adding gc.collect() to avoid problems when unittesting * update LoadImage * update LoadImage * Update Loadimage
1 parent 6a3df3d commit 2a89062

File tree

2 files changed

+121
-33
lines changed

2 files changed

+121
-33
lines changed

deeptrack/features.py

Lines changed: 63 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def propagate_data_to_dependencies(
216216
"Merge",
217217
"OneOf",
218218
"OneOfDict",
219-
"LoadImage", # TODO ***MG***
219+
"LoadImage",
220220
"SampleToMasks", # TODO ***MG***
221221
"AsType", # TODO ***MG***
222222
"ChannelFirst2d",
@@ -7195,29 +7195,29 @@ def get(
71957195
class LoadImage(Feature):
71967196
"""Load an image from disk and preprocess it.
71977197
7198-
This feature loads an image file using multiple fallback file readers
7199-
(`imageio`, `numpy`, `Pillow`, and `OpenCV`) until a suitable reader is
7200-
found. The image can be optionally converted to grayscale, reshaped to
7201-
ensure a minimum number of dimensions, or treated as a list of images if
7198+
`LoadImage` loads an image file using multiple fallback file readers
7199+
(`imageio`, `numpy`, `Pillow`, and `OpenCV`) until a suitable reader is
7200+
found. The image can be optionally converted to grayscale, reshaped to
7201+
ensure a minimum number of dimensions, or treated as a list of images if
72027202
multiple paths are provided.
72037203
72047204
Parameters
72057205
----------
72067206
path: PropertyLike[str or list[str]]
7207-
The path(s) to the image(s) to load. Can be a single string or a list
7207+
The path(s) to the image(s) to load. Can be a single string or a list
72087208
of strings.
72097209
load_options: PropertyLike[dict[str, Any]], optional
72107210
Additional options passed to the file reader. It defaults to `None`.
72117211
as_list: PropertyLike[bool], optional
7212-
If `True`, the first dimension of the image will be treated as a list.
7212+
If `True`, the first dimension of the image will be treated as a list.
72137213
It defaults to `False`.
72147214
ndim: PropertyLike[int], optional
72157215
Ensures the image has at least this many dimensions. It defaults to
72167216
`3`.
72177217
to_grayscale: PropertyLike[bool], optional
72187218
If `True`, converts the image to grayscale. It defaults to `False`.
72197219
get_one_random: PropertyLike[bool], optional
7220-
If `True`, extracts a single random image from a stack of images. Only
7220+
If `True`, extracts a single random image from a stack of images. Only
72217221
used when `as_list` is `True`. It defaults to `False`.
72227222
72237223
Attributes
@@ -7228,22 +7228,36 @@ class LoadImage(Feature):
72287228
72297229
Methods
72307230
-------
7231-
`get(image: Any, path: str or list[str], load_options: dict[str, Any] | None, ndim: int, to_grayscale: bool, as_list: bool, get_one_random: bool, **kwargs: Any) -> array`
7231+
`get(
7232+
path: str | list[str],
7233+
load_options: dict[str, Any] | None,
7234+
ndim: int,
7235+
to_grayscale: bool,
7236+
as_list: bool,
7237+
get_one_random: bool,
7238+
**kwargs: Any,
7239+
) -> NDArray | list[NDArray] | torch.Tensor | list[torch.Tensor]`
72327240
Load the image(s) from disk and process them.
72337241
72347242
Raises
72357243
------
72367244
IOError
72377245
If no file reader could parse the file or the file does not exist.
72387246
7247+
Notes
7248+
----
7249+
By default, `LoadImage` returns a NumPy array. If you want the output as
7250+
a PyTorch tensor, convert the feature to torch by calling `.torch()` before
7251+
resolving.
7252+
72397253
Examples
72407254
--------
72417255
>>> import deeptrack as dt
72427256
72437257
Create a temporary image file:
72447258
>>> import numpy as np
72457259
>>> import os, tempfile
7246-
>>>
7260+
>>>
72477261
>>> temp_file = tempfile.NamedTemporaryFile(suffix=".npy", delete=False)
72487262
>>> np.save(temp_file.name, np.random.rand(100, 100, 3))
72497263
@@ -7271,7 +7285,14 @@ class LoadImage(Feature):
72717285
... )
72727286
>>> loaded_image = load_image_feature.resolve()
72737287
>>> loaded_image.shape
7274-
(2, 2, 3, 1)
7288+
(100, 100, 3, 1)
7289+
7290+
Load an image as a PyTorch tensor by setting the backend of the feature:
7291+
>>> load_image_feature = dt.LoadImage(path=temp_file.name)
7292+
>>> load_image_feature.torch()
7293+
>>> loaded_image = load_image_feature.resolve()
7294+
>>> type(loaded_image)
7295+
<class 'torch.Tensor'>
72757296
72767297
Cleanup the temporary file:
72777298
>>> os.remove(temp_file.name)
@@ -7313,7 +7334,7 @@ def __init__(
73137334
If `True`, selects a single random image from a stack when
73147335
`as_list=True`. It defaults to `False`.
73157336
**kwargs: Any
7316-
Additional keyword arguments passed to the parent `Feature` class,
7337+
Additional keyword arguments passed to the parent `Feature` class,
73177338
allowing further customization.
73187339
73197340
"""
@@ -7338,31 +7359,36 @@ def get(
73387359
as_list: bool,
73397360
get_one_random: bool,
73407361
**kwargs: Any,
7341-
) -> NDArray | torch.Tensor:
7362+
) -> NDArray[Any] | torch.Tensor | list:
73427363
"""Load and process an image or a list of images from disk.
73437364
7344-
This method attempts to load an image using multiple file readers
7345-
(`imageio`, `numpy`, `Pillow`, and `OpenCV`) until a valid format is
7365+
This method attempts to load an image using multiple file readers
7366+
(`imageio`, `numpy`, `Pillow`, and `OpenCV`) until a valid format is
73467367
found. It supports optional processing steps such as ensuring a minimum
7347-
number of dimensions, grayscale conversion, and treating multi-frame
7368+
number of dimensions, grayscale conversion, and treating multi-frame
73487369
images as lists.
73497370
7371+
The output is returned as a NumPy array by default. If `as_list=True`,
7372+
the result is a Python list of arrays. If the backend of the feature is
7373+
`"torch"`, the image is returned as a PyTorch tensor.
7374+
73507375
Parameters
73517376
----------
73527377
path: str or list[str]
7353-
The file path(s) to the image(s) to be loaded. A single string
7378+
The file path(s) to the image(s) to be loaded. A single string
73547379
loads one image, while a list of paths loads multiple images.
73557380
load_options: dict of str to Any, optional
7356-
Additional options passed to the file reader (e.g., `allow_pickle`
7381+
Additional options passed to the file reader (e.g., `allow_pickle`
73577382
for NumPy, `mode` for OpenCV). It defaults to `None`.
73587383
ndim: int
7359-
Ensures the image has at least this many dimensions. If the loaded
7360-
image has fewer dimensions, extra dimensions are added.
7384+
Ensures the image has at least this many dimensions. If the loaded
7385+
image has fewer dimensions, extra dimensions are added. It defaults
7386+
to `3`.
73617387
to_grayscale: bool
73627388
If `True`, converts the image to grayscale. It defaults to `False`.
73637389
as_list: bool
7364-
If `True`, treats the first dimension as a list of images instead
7365-
of stacking them into a NumPy array.
7390+
If `True`, treats the first dimension as a list of images instead
7391+
of stacking them into a NumPy array. It defaults to `False`.
73667392
get_one_random: bool
73677393
If `True`, selects a single random image from a multi-frame stack
73687394
when `as_list=True`. It defaults to `False`.
@@ -7372,14 +7398,14 @@ def get(
73727398
Returns
73737399
-------
73747400
array
7375-
The loaded and processed image(s). If `as_list=True`, returns a
7401+
The loaded and processed image(s). If `as_list=True`, returns a
73767402
list of images; otherwise, returns a single NumPy array or PyTorch
73777403
tensor.
73787404
73797405
Raises
73807406
------
73817407
IOError
7382-
If no valid file reader is found or if the specified file does not
7408+
If no valid file reader is found or if the specified file does not
73837409
exist.
73847410
73857411
"""
@@ -7402,8 +7428,9 @@ def get(
74027428
try:
74037429
import PIL.Image
74047430

7405-
image = [PIL.Image.open(file, **load_options)
7406-
for file in path]
7431+
image = [
7432+
PIL.Image.open(file, **load_options) for file in path
7433+
]
74077434
except (IOError, ImportError):
74087435
import cv2
74097436

@@ -7439,11 +7466,18 @@ def get(
74397466
)
74407467

74417468
# Ensure the image has at least `ndim` dimensions.
7442-
while ndim and image.ndim < ndim:
7443-
image = np.expand_dims(image, axis=-1)
7469+
if not isinstance(image, list) and ndim:
7470+
while image.ndim < ndim:
7471+
image = np.expand_dims(image, axis=-1)
74447472

74457473
# Convert to PyTorch tensor if needed.
7446-
#TODO
7474+
if self.get_backend() == "torch":
7475+
7476+
# Convert to stack if needed.
7477+
if isinstance(image, list):
7478+
image = np.stack(image, axis=0)
7479+
7480+
image = torch.from_numpy(image)
74477481

74487482
return image
74497483

deeptrack/tests/test_features.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1836,9 +1836,14 @@ def test_LoadImage(self):
18361836

18371837
try:
18381838
with NamedTemporaryFile(suffix=".npy", delete=False) as temp_npy:
1839-
np.save(temp_npy.name, test_image_array)
1839+
pass
1840+
np.save(temp_npy.name, test_image_array)
18401841
# npy_filename = temp_npy.name
18411842

1843+
with NamedTemporaryFile(suffix=".npy", delete=False) as temp_npy2:
1844+
pass
1845+
np.save(temp_npy2.name, test_image_array)
1846+
18421847
with NamedTemporaryFile(suffix=".png", delete=False) as temp_png:
18431848
PIL_Image.fromarray(test_image_array).save(temp_png.name)
18441849
# png_filename = temp_png.name
@@ -1877,12 +1882,61 @@ def test_LoadImage(self):
18771882
loaded_image = load_feature.resolve()
18781883
self.assertGreaterEqual(len(loaded_image.shape), 4)
18791884

1885+
# Test loading a list of images
1886+
load_feature = features.LoadImage(
1887+
path=[temp_npy.name, temp_npy2.name], as_list=True
1888+
)
1889+
loaded_list = load_feature.resolve()
1890+
self.assertIsInstance(loaded_list, list)
1891+
self.assertEqual(len(loaded_list), 2)
1892+
1893+
for img in loaded_list:
1894+
self.assertTrue(isinstance(img, np.ndarray))
1895+
1896+
# Test loading a random image from a list of images
1897+
load_feature = features.LoadImage(
1898+
path=[temp_npy.name, temp_npy2.name],
1899+
ndim=4,
1900+
as_list=True,
1901+
get_one_random=True,
1902+
)
1903+
loaded_image = load_feature.resolve()
1904+
self.assertTrue(
1905+
np.allclose(
1906+
loaded_image[:, :, 0, 0], test_image_array, rtol=1.e-3
1907+
)
1908+
)
1909+
self.assertEqual(loaded_image.shape, (50, 50, 1, 1))
1910+
1911+
import gc
1912+
gc.collect()
1913+
1914+
# Test loading an image as a torch tensor.
1915+
if TORCH_AVAILABLE:
1916+
load_feature = features.LoadImage(path=temp_png.name)
1917+
load_feature.torch()
1918+
loaded_image = load_feature.resolve()
1919+
self.assertIsInstance(loaded_image, torch.Tensor)
1920+
self.assertEqual(
1921+
loaded_image.shape[:2], test_image_array.shape
1922+
)
1923+
1924+
loaded_image_np = loaded_image.numpy()
1925+
self.assertTrue(
1926+
np.allclose(
1927+
test_image_array, loaded_image_np[:, :, 0], rtol=1.e-3
1928+
)
1929+
)
1930+
18801931
finally:
1881-
for file in [temp_npy.name, temp_png.name, temp_jpg.name]:
1932+
for file in [
1933+
temp_npy.name,
1934+
temp_png.name,
1935+
temp_jpg.name,
1936+
temp_npy2.name
1937+
]:
18821938
os.remove(file)
18831939

1884-
#TODO: Add a test for loading a list of images.
1885-
18861940

18871941
def test_SampleToMasks(self):
18881942
# Parameters

0 commit comments

Comments
 (0)