@@ -387,7 +387,19 @@ class Feature(DeepTrackNode):
387387 behavior.
388388 `bind_arguments(arguments: Feature) -> Feature`
389389 It binds another feature’s properties as arguments to this feature.
390- `plot(input_image: np.ndarray | list[np.ndarray] | Image | list[Image] | None = None, resolve_kwargs: dict | None = None, interval: float | None = None, **kwargs: Any) -> Any`
390+ `plot(
391+ input_image: (
392+ NDArray
393+ | list[NDArray]
394+ | torch.Tensor
395+ | list[torch.Tensor]
396+ | Image
397+ | list[Image]
398+ ) = None,
399+ resolve_kwargs: dict | None = None,
400+ interval: float | None = None,
401+ **kwargs: Any,
402+ ) -> Any`
391403 It visualizes the output of the feature.
392404
393405 **Private and internal methods.**
@@ -1754,54 +1766,91 @@ def bind_arguments(
17541766
17551767 return self
17561768
1757- #TODO ***MG***
17581769 def plot (
17591770 self : Feature ,
1760- input_image : np .ndarray | list [np .ndarray ] | Image | list [Image ] = None ,
1771+ input_image : (
1772+ NDArray
1773+ | list [NDArray ]
1774+ | torch .Tensor
1775+ | list [torch .Tensor ]
1776+ | Image
1777+ | list [Image ]
1778+ ) = None ,
17611779 resolve_kwargs : dict = None ,
17621780 interval : float = None ,
17631781 ** kwargs : Any ,
17641782 ) -> Any :
1765- """Visualizes the output of the feature.
1783+ """Visualize the output of the feature.
17661784
1767- This method resolves the feature and visualizes the result. If the output is
1768- an `Image`, it displays it using `pyplot.imshow`. If the output is a list, it
1769- creates an animation. In Jupyter notebooks, the animation is played inline
1770- using `to_jshtml()`. In scripts, the animation is displayed using the
1785+ `plot()` resolves the feature and visualizes the result. If the output
1786+ is a single image (NumPy array, PyTorch tensor, or Image), it is
1787+ displayed using `pyplot.imshow`. If the output is a list, an animation
1788+ is created. In Jupyter notebooks, the animation is played inline using
1789+ `to_jshtml()`. In scripts, the animation is displayed using the
17711790 matplotlib backend.
17721791
17731792 Any parameters in `kwargs` are passed to `pyplot.imshow`.
17741793
17751794 Parameters
17761795 ----------
1777- input_image: np.ndarray or Image or list[np.ndarray or Image], optional
1778- The input image or list of images passed as an argument to the `resolve`
1779- call. If `None`, uses previously set input values or propagates properties.
1796+ input_image: np.ndarray, torch.tensor, or Image or list[np.ndarray,
1797+ torch.tensor, or Image], optional
1798+ The input image or list of images passed as an argument to the
1799+ `resolve` call. If `None`, uses previously set input values or
1800+ propagates properties.
17801801 resolve_kwargs: dict, optional
17811802 Additional keyword arguments passed to the `resolve` call.
17821803 interval: float, optional
1783- The time between frames in the animation, in milliseconds. The default
1784- value is 33 ms.
1804+ The time between frames in the animation, in milliseconds. The
1805+ default value is 33 ms.
17851806 **kwargs: dict, optional
17861807 Additional keyword arguments passed to `pyplot.imshow`.
1787-
1808+
17881809 Returns
17891810 -------
17901811 Any
17911812 The output of the feature or pipeline after execution.
1813+
1814+ Examples
1815+ --------
1816+ >>> import deeptrack as dt
1817+
1818+ Create an instance of a dummy feature that returns the input:
1819+ >>> feature = dt.DummyFeature()
1820+
1821+ Generate and plot a grayscale image:
1822+ >>> import numpy as np
1823+ >>>
1824+ >>> img = np.random.randint(0, 256, (64, 64))
1825+ >>> feature.plot(img, cmap="gray");
1826+
1827+ Generate and plot a grayscale video:
1828+ >>> video = [np.random.randint(0, 256, (64, 64)) for _ in range(10)]
1829+ >>> feature.plot(video, interval=100, cmap="gray");
1830+
1831+ Generate a grayscale image using torch and plot it:
1832+ >>> import torch
1833+ >>>
1834+ >>> img = torch.randint(0, 256, size=(64, 64))
1835+ >>> feature.plot(img, cmap="gray");
1836+
1837+ Generate a simulated image of a point particle visualized using
1838+ brightfield microscopy and plot it:
1839+ >>> particle = dt.PointParticle()
1840+ >>> optics = dt.Brightfield()
1841+ >>> imaged_particle = optics(particle)
1842+ >>> imaged_particle.plot(cmap="gray");
17921843
17931844 """
17941845
17951846 from IPython .display import HTML , display
17961847
1797- # if input_image is not None:
1798- # input_image = [Image(input_image)]
1799-
18001848 output_image = self .resolve (input_image , ** (resolve_kwargs or {}))
18011849
18021850 # If a list, assume video
18031851 if not isinstance (output_image , list ):
18041852 # Single image
1853+ output_image = xp .squeeze (output_image )
18051854 plt .imshow (output_image , ** kwargs )
18061855 return plt .gca ()
18071856
@@ -1810,11 +1859,14 @@ def plot(
18101859 images = []
18111860 plt .axis ("off" )
18121861 for image in output_image :
1862+ image = xp .squeeze (image )
18131863 images .append ([plt .imshow (image , ** kwargs )])
18141864
18151865 if not interval :
18161866 if isinstance (output_image [0 ], Image ):
1817- interval = output_image [0 ].get_property ("interval" ) or (1 / 30 * 1000 )
1867+ interval = (
1868+ output_image [0 ].get_property ("interval" ) or (1 / 30 * 1000 )
1869+ )
18181870 else :
18191871 interval = 1 / 30 * 1000
18201872
0 commit comments