Skip to content

Commit 791fc82

Browse files
committed
Update scatterers.py
1 parent 0463017 commit 791fc82

File tree

1 file changed

+37
-15
lines changed

1 file changed

+37
-15
lines changed

deeptrack/scatterers.py

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -160,11 +160,12 @@
160160

161161
from __future__ import annotations
162162

163-
from typing import Callable
163+
from typing import Any, TYPE_CHECKING
164164
import warnings
165165

166-
from pint import Quantity
167166
import numpy as np
167+
from numpy.typing import NDArray
168+
from pint import Quantity
168169

169170
from deeptrack.holography import get_propagation_matrix
170171
from deeptrack.backend.units import (
@@ -179,6 +180,22 @@
179180
from deeptrack import units_registry as u
180181

181182

183+
__all__ = [
184+
"Scatterer",
185+
"PointParticle",
186+
"Ellipse",
187+
"Sphere",
188+
"Ellipsoid",
189+
"MieScatterer",
190+
"MieSphere",
191+
"MieStratifiedSphere",
192+
]
193+
194+
195+
if TYPE_CHECKING:
196+
import torch
197+
198+
182199
#TODO ***??*** revise Scatterer - torch, typing, docstring, unit test
183200
class Scatterer(Feature):
184201
"""Base abstract class for scatterers.
@@ -248,6 +265,7 @@ def __init__(
248265
)
249266

250267
self._processed_properties = False
268+
251269
super().__init__(
252270
position=position,
253271
z=z,
@@ -264,7 +282,7 @@ def _process_properties(
264282
self,
265283
properties: dict
266284
) -> dict:
267-
285+
268286
# Rescales the position property.
269287
properties = super()._process_properties(properties)
270288
self._processed_properties = True
@@ -323,14 +341,14 @@ def _no_wrap_format_input(
323341
**kwargs
324342
) -> list:
325343
return self._image_wrapped_format_input(*args, **kwargs)
326-
344+
327345
def _no_wrap_process_and_get(
328346
self,
329347
*args,
330348
**feature_input
331349
) -> list:
332350
return self._image_wrapped_process_and_get(*args, **feature_input)
333-
351+
334352
def _no_wrap_process_output(
335353
self,
336354
*args,
@@ -341,15 +359,15 @@ def _no_wrap_process_output(
341359

342360
#TODO ***??*** revise PointParticle - torch, typing, docstring, unit test
343361
class PointParticle(Scatterer):
344-
"""Generates a point particle
362+
"""Generate a diffraction-limited point particle.
345363
346-
A point particle is approximated by the size of a pixel. For subpixel
347-
positioning, the position is interpolated linearly.
364+
A point particle is approximated by the size of a single pixel or voxel.
365+
For subpixel positioning, the position is interpolated linearly.
348366
349367
Parameters
350368
----------
351369
position: ArrayLike[float, float (, float)]
352-
The position of the particle, length 2 or 3. Third index is optional,
370+
Particle position in 2D or 3D. Third index is optional,
353371
and represents the position in the direction normal to the
354372
camera plane.
355373
@@ -365,17 +383,21 @@ class PointParticle(Scatterer):
365383
"""
366384

367385
def __init__(
368-
self,
369-
**kwargs
386+
self: PointParticle,
387+
**kwargs: Any,
370388
):
389+
"""
390+
391+
"""
392+
371393
super().__init__(upsample=1, upsample_axes=(), **kwargs)
372394

373395
def get(
374-
self,
396+
self: PointParticle,
375397
image: Image | np.ndarray,
376-
**kwarg
377-
) -> ArrayLike[float]:
378-
"""Abstract method to initialize the point scatterer"""
398+
**kwarg: Any,
399+
) -> NDArray[Any] | torch.Tensor:
400+
"""Evaluate and return the scatterer volume."""
379401

380402
scale = get_active_scale()
381403

0 commit comments

Comments
 (0)