diff --git a/deeptrack/aberrations.py b/deeptrack/aberrations.py index 22905c99f..385f4d595 100644 --- a/deeptrack/aberrations.py +++ b/deeptrack/aberrations.py @@ -77,14 +77,41 @@ from __future__ import annotations import math -from typing import Any +from typing import Any, TYPE_CHECKING import numpy as np +from numpy.typing import NDArray +from deeptrack.backend import TORCH_AVAILABLE from deeptrack.features import Feature +from deeptrack.image import Image from deeptrack.types import PropertyLike from deeptrack.utils import as_list +if TORCH_AVAILABLE: + import torch + +__all__ = [ + "Aberration", + "GaussianApodization", + "Zernike", + "Piston", + "VerticalTilt", + "HorizontalTilt", + "ObliqueAstigmatism", + "Defocus", + "Astigmatism", + "ObliqueTrefoil", + "VerticalComa", + "HorizontalComa", + "Trefoil", + "SphericalAberration", +] + + +if TYPE_CHECKING: + import torch + #TODO ***??*** revise Aberration - torch, docstring, unit test class Aberration(Feature): @@ -114,13 +141,14 @@ class Aberration(Feature): superclass method for further processing. """ + __distributed__: bool = True def _process_and_get( - self: Feature, - image_list: list[np.ndarray], - **kwargs: dict[str, np.ndarray] - ) -> list[np.ndarray]: + self: Aberration, + image_list: list[NDArray[Any] | torch.Tensor | Image], + **kwargs: Any, + ) -> list[NDArray[Any] | torch.Tensor | Image]: """Computes pupil coordinates. Computes pupil coordinates (rho and theta) for each input image and @@ -130,7 +158,7 @@ def _process_and_get( ---------- image_list: list[np.ndarray] A list of 2D input images to be processed. - **kwargs: dict[str, np.ndarray] + **kwargs: Any Additional parameters to be passed to the superclass's `_process_and_get` method. @@ -151,8 +179,12 @@ def _process_and_get( theta = np.arctan2(Y, X) new_list += super()._process_and_get( - [image], rho=rho, theta=theta, **kwargs + [image], + rho=rho, + theta=theta, + **kwargs, ) + return new_list @@ -177,7 +209,7 @@ class GaussianApodization(Aberration): Methods ------- - `get(pupil: np.ndarray, offset: tuple[float, float], sigma: float, rho: np.ndarray, **kwargs: dict[str, Any]) -> np.ndarray` + `get(pupil: np.ndarray, offset: tuple[float, float], sigma: float, rho: np.ndarray, **kwargs: Any) -> np.ndarray` Applies Gaussian apodization to the input pupil function. Examples @@ -199,7 +231,7 @@ def __init__( self: GaussianApodization, sigma: PropertyLike[float] = 1, offset: PropertyLike[tuple[int, int]] = (0, 0), - **kwargs: dict[str, Any] + **kwargs: Any, ) -> None: """Initializes the GaussianApodization class. @@ -222,12 +254,12 @@ def __init__( super().__init__(sigma=sigma, offset=offset, **kwargs) def get( - self: GaussianApodization, - pupil: np.ndarray, - offset: tuple[float, float], - sigma: float, - rho: np.ndarray, - **kwargs: dict[str, Any] + self: GaussianApodization, + pupil: np.ndarray, + offset: tuple[float, float], + sigma: float, + rho: np.ndarray, + **kwargs: Any, ) -> np.ndarray: """Applies Gaussian apodization to the input pupil function. @@ -299,6 +331,7 @@ def get( rho[rho > 1] = np.inf pupil = pupil * np.exp(-((rho / sigma) ** 2)) + return pupil @@ -336,7 +369,7 @@ class Zernike(Aberration): Methods ------- - `get(pupil: np.ndarray, rho: np.ndarray, theta: np.ndarray, n: int | list[int], m: int | list[int], coefficient: float | list[float], **kwargs: dict[str, Any]) -> np.ndarray` + `get(pupil: np.ndarray, rho: np.ndarray, theta: np.ndarray, n: int | list[int], m: int | list[int], coefficient: float | list[float], **kwargs: str) -> np.ndarray` Applies the Zernike phase aberration to the input pupil function. Notes @@ -369,7 +402,7 @@ def __init__( n: PropertyLike[int | list[int]], m: PropertyLike[int | list[int]], coefficient: PropertyLike[float | list[float]] = 1, - **kwargs: dict[str, Any] + **kwargs: str, ) -> None: """ Initializes the Zernike class. diff --git a/deeptrack/tests/test_aberrations.py b/deeptrack/tests/test_aberrations.py index 29f775d9f..fa1807710 100644 --- a/deeptrack/tests/test_aberrations.py +++ b/deeptrack/tests/test_aberrations.py @@ -33,7 +33,7 @@ def testGaussianApodization(self): im = aberrated_particle.resolve(z=z) self.assertIsInstance(im, np.ndarray) self.assertEqual(im.shape, (64, 48, 1)) - + aberrated_particle.store_properties(True) for z in (-100, 0, 100): im = aberrated_particle.resolve(z=z)