
Commit 8981c62

Mg/features plot (#404)
* formatting
* added examples
* minor changes
* update of example
* Added one more example
* update features_plot
* update plot()
* update plot
1 parent 12c091c commit 8981c62

1 file changed: deeptrack/features.py (70 additions, 18 deletions)

@@ -387,7 +387,19 @@ class Feature(DeepTrackNode):
         behavior.
     `bind_arguments(arguments: Feature) -> Feature`
         It binds another feature’s properties as arguments to this feature.
-    `plot(input_image: np.ndarray | list[np.ndarray] | Image | list[Image] | None = None, resolve_kwargs: dict | None = None, interval: float | None = None, **kwargs: Any) -> Any`
+    `plot(
+        input_image: (
+            NDArray
+            | list[NDArray]
+            | torch.Tensor
+            | list[torch.Tensor]
+            | Image
+            | list[Image]
+        ) = None,
+        resolve_kwargs: dict | None = None,
+        interval: float | None = None,
+        **kwargs: Any,
+    ) -> Any`
         It visualizes the output of the feature.
 
     **Private and internal methods.**
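
The widened signature above says that `plot()` now also accepts PyTorch tensors and lists of tensors. As a minimal sketch based only on these type hints (the docstring examples further down show the NumPy video and single-tensor cases; the tensor-video case here is an assumption):

>>> import deeptrack as dt
>>> import torch
>>>
>>> feature = dt.DummyFeature()  # returns its input unchanged
>>> # Assumed to animate like a list of NumPy arrays, per the list[torch.Tensor] hint:
>>> frames = [torch.randint(0, 256, (64, 64)) for _ in range(10)]
>>> feature.plot(frames, interval=100, cmap="gray");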
@@ -1754,54 +1766,91 @@ def bind_arguments(
 
         return self
 
-    #TODO ***MG***
     def plot(
         self: Feature,
-        input_image: np.ndarray | list[np.ndarray] | Image | list[Image] = None,
+        input_image: (
+            NDArray
+            | list[NDArray]
+            | torch.Tensor
+            | list[torch.Tensor]
+            | Image
+            | list[Image]
+        ) = None,
         resolve_kwargs: dict = None,
         interval: float = None,
         **kwargs: Any,
     ) -> Any:
-        """Visualizes the output of the feature.
+        """Visualize the output of the feature.
 
-        This method resolves the feature and visualizes the result. If the output is
-        an `Image`, it displays it using `pyplot.imshow`. If the output is a list, it
-        creates an animation. In Jupyter notebooks, the animation is played inline
-        using `to_jshtml()`. In scripts, the animation is displayed using the
+        `plot()` resolves the feature and visualizes the result. If the output
+        is a single image (NumPy array, PyTorch tensor, or Image), it is
+        displayed using `pyplot.imshow`. If the output is a list, an animation
+        is created. In Jupyter notebooks, the animation is played inline using
+        `to_jshtml()`. In scripts, the animation is displayed using the
         matplotlib backend.
 
         Any parameters in `kwargs` are passed to `pyplot.imshow`.
 
         Parameters
         ----------
-        input_image: np.ndarray or Image or list[np.ndarray or Image], optional
-            The input image or list of images passed as an argument to the `resolve`
-            call. If `None`, uses previously set input values or propagates properties.
+        input_image: np.ndarray, torch.tensor, or Image or list[np.ndarray,
+        torch.tensor, or Image], optional
+            The input image or list of images passed as an argument to the
+            `resolve` call. If `None`, uses previously set input values or
+            propagates properties.
         resolve_kwargs: dict, optional
             Additional keyword arguments passed to the `resolve` call.
         interval: float, optional
-            The time between frames in the animation, in milliseconds. The default
-            value is 33 ms.
+            The time between frames in the animation, in milliseconds. The
+            default value is 33 ms.
         **kwargs: dict, optional
             Additional keyword arguments passed to `pyplot.imshow`.
-
+
         Returns
         -------
         Any
             The output of the feature or pipeline after execution.
+
+        Examples
+        --------
+        >>> import deeptrack as dt
+
+        Create an instance of a dummy feature that returns the input:
+        >>> feature = dt.DummyFeature()
+
+        Generate and plot a grayscale image:
+        >>> import numpy as np
+        >>>
+        >>> img = np.random.randint(0, 256, (64, 64))
+        >>> feature.plot(img, cmap="gray");
+
+        Generate and plot a grayscale video:
+        >>> video = [np.random.randint(0, 256, (64, 64)) for _ in range(10)]
+        >>> feature.plot(video, interval=100, cmap="gray");
+
+        Generate a grayscale image using torch and plot it:
+        >>> import torch
+        >>>
+        >>> img = torch.randint(0, 256, size=(64, 64))
+        >>> feature.plot(img, cmap="gray");
+
+        Generate a simulated image of a point particle visualized using
+        brightfield microscopy and plot it:
+        >>> particle = dt.PointParticle()
+        >>> optics = dt.Brightfield()
+        >>> imaged_particle = optics(particle)
+        >>> imaged_particle.plot(cmap="gray");
 
         """
 
         from IPython.display import HTML, display
 
-        # if input_image is not None:
-        #     input_image = [Image(input_image)]
-
         output_image = self.resolve(input_image, **(resolve_kwargs or {}))
 
         # If a list, assume video
         if not isinstance(output_image, list):
             # Single image
+            output_image = xp.squeeze(output_image)
             plt.imshow(output_image, **kwargs)
             return plt.gca()
 
@@ -1810,11 +1859,14 @@ def plot(
         images = []
         plt.axis("off")
         for image in output_image:
+            image = xp.squeeze(image)
             images.append([plt.imshow(image, **kwargs)])
 
         if not interval:
             if isinstance(output_image[0], Image):
-                interval = output_image[0].get_property("interval") or (1 / 30 * 1000)
+                interval = (
+                    output_image[0].get_property("interval") or (1 / 30 * 1000)
+                )
             else:
                 interval = 1 / 30 * 1000
 
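
The hunks above end before the animation itself is built from `images` and `interval`. For orientation only, a minimal sketch of how such a frame list is typically turned into an inline animation with matplotlib; the figure handle, the `blit` flag, and the display call are assumptions and not part of this commit:

    import matplotlib.pyplot as plt
    from matplotlib import animation
    from IPython.display import HTML, display

    # `images` is a list of single-artist lists, one entry per frame,
    # as collected in the loop above; `interval` is in milliseconds.
    fig = plt.gcf()
    anim = animation.ArtistAnimation(fig, images, interval=interval, blit=True)

    # In a Jupyter notebook, render the animation inline as HTML/JavaScript;
    # in a script, plt.show() would display it through the matplotlib backend instead.
    display(HTML(anim.to_jshtml()))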
