Skip to content

Commit a9f1b29

Browse files
committed
Update optics.py
1 parent 9dbd195 commit a9f1b29

File tree

1 file changed

+45
-37
lines changed

1 file changed

+45
-37
lines changed

deeptrack/optics.py

Lines changed: 45 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,11 @@ def _pad_volume(
137137
from __future__ import annotations
138138

139139
from pint import Quantity
140-
from typing import Any
140+
from typing import Any, TYPE_CHECKING
141141
import warnings
142142

143143
import numpy as np
144+
from numpy.typing import NDArray
144145
from scipy.ndimage import convolve
145146

146147
from deeptrack.backend.units import (
@@ -158,6 +159,9 @@ def _pad_volume(
158159
from deeptrack import image
159160
from deeptrack import units_registry as u
160161

162+
if TYPE_CHECKING:
163+
import torch
164+
161165

162166
#TODO ***??*** revise Microscope - torch, typing, docstring, unit test
163167
class Microscope(StructuralFeature):
@@ -207,59 +211,60 @@ class Microscope(StructuralFeature):
207211

208212
__distributed__ = False
209213

214+
_sample: Feature
215+
_objective: Feature
216+
210217
def __init__(
211218
self: Microscope,
212219
sample: Feature,
213220
objective: Feature,
214221
**kwargs: Any,
215222
):
216-
"""Initialize the `Microscope` instance.
223+
"""Initialize a microscope feature combining sample and optics.
224+
225+
This constructor attaches a sample feature (typically a combination of
226+
scatterers) and an objective feature (optical system) to the
227+
microscope.
217228
218229
Parameters
219230
----------
220231
sample: Feature
221-
A feature-set resolving a list of images describing the sample to be
222-
imaged.
232+
Feature that resolves one or more scatterer volumes or fields
233+
representing the sample to be imaged.
223234
objective: Feature
224-
A feature-set defining the optical device that images the sample.
235+
Feature describing the optical system used to image the sample
236+
(e.g., brightfield, fluorescence).
225237
**kwargs: Any
226-
Additional parameters passed to the base `StructuralFeature` class.
227-
228-
Attributes
229-
----------
230-
_sample: Feature
231-
The feature-set defining the sample to be imaged.
232-
_objective: Feature
233-
The feature-set defining the optical system imaging the sample.
238+
Additional keyword arguments passed to the base `StructuralFeature`
239+
class.
234240
235241
"""
236242

237243
super().__init__(**kwargs)
238244

239245
self._sample = self.add_feature(sample)
240246
self._objective = self.add_feature(objective)
247+
248+
#TODO: erase following line when rid of Image
241249
self._sample.store_properties()
242250

243251
def get(
244252
self: Microscope,
245-
image: Image | None,
253+
input: Any = None, # Ignored, kept for API compatibility
246254
**kwargs: Any,
247-
) -> Image:
255+
) -> NDArray[Any] | torch.Tensor:
248256
"""Generate an image of the sample using the defined optical system.
249257
250-
This method processes the sample through the optical system to
251-
produce a simulated image.
252-
253258
Parameters
254259
----------
255-
image: Image | None
256-
The input image to be processed. If None, a new image is created.
260+
image: Any, optional
261+
Ignored. Kept for API compatibility. Defaults to None.
257262
**kwargs: Any
258263
Additional parameters for the imaging process.
259264
260265
Returns
261266
-------
262-
Image: Image
267+
array or tensor
263268
The processed image after applying the optical system.
264269
265270
Examples
@@ -277,47 +282,46 @@ def get(
277282
278283
"""
279284

280-
# Grab properties from the objective to pass to the sample
281-
additional_sample_kwargs = self._objective.properties()
285+
# Grab objective properties to pass to sample
286+
objective_properties = self._objective.properties()
282287

283-
# Calculate required output image for the given upscale
284-
# This way of providing the upscale will be deprecated in the future
285-
# in favor of dt.Upscale().
286-
_upscale_given_by_optics = additional_sample_kwargs["upscale"]
288+
# Calculate required output image for the given upscale.
289+
# This upscale way will be deprecated in favor of dt.Upscale().
290+
_upscale_given_by_optics = objective_properties["upscale"]
287291
if np.array(_upscale_given_by_optics).size == 1:
288292
_upscale_given_by_optics = (_upscale_given_by_optics,) * 3
289293

290294
with u.context(
291295
create_context(
292-
*additional_sample_kwargs["voxel_size"], *_upscale_given_by_optics
296+
*objective_properties["voxel_size"], *_upscale_given_by_optics
293297
)
294298
):
295299

296300
upscale = np.round(get_active_scale())
297301

298-
output_region = additional_sample_kwargs.pop("output_region")
299-
additional_sample_kwargs["output_region"] = [
302+
output_region = objective_properties.pop("output_region")
303+
objective_properties["output_region"] = [
300304
int(o * upsc)
301305
for o, upsc in zip(
302306
output_region, (upscale[0], upscale[1], upscale[0], upscale[1])
303307
)
304308
]
305309

306-
padding = additional_sample_kwargs.pop("padding")
307-
additional_sample_kwargs["padding"] = [
310+
padding = objective_properties.pop("padding")
311+
objective_properties["padding"] = [
308312
int(p * upsc)
309313
for p, upsc in zip(
310314
padding, (upscale[0], upscale[1], upscale[0], upscale[1])
311315
)
312316
]
313317

314318
self._objective.output_region.set_value(
315-
additional_sample_kwargs["output_region"]
319+
objective_properties["output_region"]
316320
)
317-
self._objective.padding.set_value(additional_sample_kwargs["padding"])
321+
self._objective.padding.set_value(objective_properties["padding"])
318322

319323
propagate_data_to_dependencies(
320-
self._sample, **{"return_fft": True, **additional_sample_kwargs}
324+
self._sample, **{"return_fft": True, **objective_properties}
321325
)
322326

323327
list_of_scatterers = self._sample()
@@ -342,7 +346,7 @@ def get(
342346
# Merge all volumes into a single volume.
343347
sample_volume, limits = _create_volume(
344348
volume_samples,
345-
**additional_sample_kwargs,
349+
**objective_properties,
346350
)
347351
sample_volume = Image(sample_volume)
348352

@@ -359,12 +363,16 @@ def get(
359363

360364
imaged_sample = self._objective.resolve(sample_volume)
361365

362-
# Upscale given by the optics needs to be handled separately.
366+
# Handling separately upscale given by optics.
367+
# This upscale way will be deprecated in favor of dt.Upscale().
363368
if _upscale_given_by_optics != (1, 1, 1):
364369
imaged_sample = AveragePooling((*_upscale_given_by_optics[:2], 1))(
365370
imaged_sample
366371
)
367372

373+
return imaged_sample
374+
375+
#TODO: erase rest of the method
368376
# Merge with input
369377
if not image:
370378
if not self._wrap_array_with_image and isinstance(imaged_sample, Image):

0 commit comments

Comments
 (0)