160160
161161from __future__ import annotations
162162
163- from typing import Callable
163+ from typing import Any , TYPE_CHECKING
164164import warnings
165165
166- from pint import Quantity
167166import numpy as np
167+ from numpy .typing import NDArray
168+ from pint import Quantity
168169
169170from deeptrack .holography import get_propagation_matrix
170171from deeptrack .backend .units import (
179180from deeptrack import units_registry as u
180181
181182
183+ __all__ = [
184+ "Scatterer" ,
185+ "PointParticle" ,
186+ "Ellipse" ,
187+ "Sphere" ,
188+ "Ellipsoid" ,
189+ "MieScatterer" ,
190+ "MieSphere" ,
191+ "MieStratifiedSphere" ,
192+ ]
193+
194+
195+ if TYPE_CHECKING :
196+ import torch
197+
198+
182199#TODO ***??*** revise Scatterer - torch, typing, docstring, unit test
183200class Scatterer (Feature ):
184201 """Base abstract class for scatterers.
@@ -248,6 +265,7 @@ def __init__(
248265 )
249266
250267 self ._processed_properties = False
268+
251269 super ().__init__ (
252270 position = position ,
253271 z = z ,
@@ -264,7 +282,7 @@ def _process_properties(
264282 self ,
265283 properties : dict
266284 ) -> dict :
267-
285+
268286 # Rescales the position property.
269287 properties = super ()._process_properties (properties )
270288 self ._processed_properties = True
@@ -323,14 +341,14 @@ def _no_wrap_format_input(
323341 ** kwargs
324342 ) -> list :
325343 return self ._image_wrapped_format_input (* args , ** kwargs )
326-
344+
327345 def _no_wrap_process_and_get (
328346 self ,
329347 * args ,
330348 ** feature_input
331349 ) -> list :
332350 return self ._image_wrapped_process_and_get (* args , ** feature_input )
333-
351+
334352 def _no_wrap_process_output (
335353 self ,
336354 * args ,
@@ -341,15 +359,15 @@ def _no_wrap_process_output(
341359
342360#TODO ***??*** revise PointParticle - torch, typing, docstring, unit test
343361class PointParticle (Scatterer ):
344- """Generates a point particle
362+ """Generate a diffraction-limited point particle.
345363
346- A point particle is approximated by the size of a pixel. For subpixel
347- positioning, the position is interpolated linearly.
364+ A point particle is approximated by the size of a single pixel or voxel.
365+ For subpixel positioning, the position is interpolated linearly.
348366
349367 Parameters
350368 ----------
351369 position: ArrayLike[float, float (, float)]
352- The position of the particle, length 2 or 3 . Third index is optional,
370+ Particle position in 2D or 3D . Third index is optional,
353371 and represents the position in the direction normal to the
354372 camera plane.
355373
@@ -365,17 +383,21 @@ class PointParticle(Scatterer):
365383 """
366384
367385 def __init__ (
368- self ,
369- ** kwargs
386+ self : PointParticle ,
387+ ** kwargs : Any ,
370388 ):
389+ """
390+
391+ """
392+
371393 super ().__init__ (upsample = 1 , upsample_axes = (), ** kwargs )
372394
373395 def get (
374- self ,
396+ self : PointParticle ,
375397 image : Image | np .ndarray ,
376- ** kwarg
377- ) -> ArrayLike [ float ] :
378- """Abstract method to initialize the point scatterer"""
398+ ** kwarg : Any ,
399+ ) -> NDArray [ Any ] | torch . Tensor :
400+ """Evaluate and return the scatterer volume. """
379401
380402 scale = get_active_scale ()
381403
0 commit comments