Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 106 additions & 2 deletions flamingo_tools/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import warnings
from concurrent import futures
from functools import partial
from multiprocessing import cpu_count
from typing import List, Optional, Tuple, Union

import numpy as np
Expand All @@ -22,13 +23,13 @@
from tqdm import tqdm

from .file_utils import read_image_data
from .segmentation.postprocessing import compute_table_on_the_fly
from .postprocessing.label_components import compute_table_on_the_fly
import flamingo_tools.s3_utils as s3_utils


def _measure_volume_and_surface(mask, resolution):
# Use marching_cubes for 3D data
verts, faces, normals, _ = marching_cubes(mask, spacing=(resolution,) * 3)
verts, faces, normals, _ = marching_cubes(mask, spacing=resolution)

mesh = trimesh.Trimesh(vertices=verts, faces=faces, vertex_normals=normals)
surface = mesh.area
Expand Down Expand Up @@ -166,6 +167,8 @@ def _default_object_features(

# Do the volume and surface measurement.
if not median_only:
if isinstance(resolution, float):
resolution = (resolution,) * 3
volume, surface = _measure_volume_and_surface(mask, resolution)
measures["volume"] = volume
measures["surface"] = surface
Expand All @@ -181,6 +184,8 @@ def _morphology_features(seg_id, table, image, segmentation, resolution, **kwarg
# Hard-coded value for LaVision cochleae. This is a hack for the wrong voxel size in MoBIE.
# resolution = (3.0, 0.76, 0.76)

if isinstance(resolution, float):
resolution = (resolution,) * 3
volume, surface = _measure_volume_and_surface(mask, resolution)
measures["volume"] = volume
measures["surface"] = surface
Expand Down Expand Up @@ -498,3 +503,102 @@ def _compute_block(block_id):

mask = ResizedVolume(low_res_mask, shape=original_shape, order=0)
return mask


def object_measures_single(
table_path: str,
seg_path: str,
image_paths: List[str],
out_paths: List[str],
force_overwrite: bool = False,
component_list: List[int] = [1],
background_mask: Optional[np.typing.ArrayLike] = None,
resolution: List[float] = [0.38, 0.38, 0.38],
s3: bool = False,
s3_credentials: Optional[str] = None,
s3_bucket_name: Optional[str] = None,
s3_service_endpoint: Optional[str] = None,
**_
):
"""Compute object measures for a single or multiple image channels in respect to a single segmentation channel.

Args:
table_path: File path to segmentationt table.
seg_path: Path to segmentation channel in ome.zarr format.
image_paths: Path(s) to image channel(s) in ome.zarr format.
out_paths: Paths(s) for calculated object measures.
force_overwrite: Forcefully overwrite existing files.
component_list: Only calculate object measures for specific components.
background_mask: Use background mask for calculating object measures.
resolution: Resolution of input in micrometer.
s3: Use S3 file paths.
s3_credentials:
s3_bucket_name:
s3_service_endpoint:
"""
input_key = "s0"
out_paths = [os.path.realpath(o) for o in out_paths]

if not isinstance(resolution, float):
if len(resolution) == 1:
resolution = resolution * 3
assert len(resolution) == 3
resolution = np.array(resolution)[::-1]
else:
resolution = (resolution,) * 3

for (img_path, out_path) in zip(image_paths, out_paths):
n_threads = int(os.environ.get("SLURM_CPUS_ON_NODE", cpu_count()))

# overwrite input file
if os.path.realpath(out_path) == os.path.realpath(table_path) and not s3:
force_overwrite = True

if os.path.isfile(out_path) and not force_overwrite:
print(f"Skipping {out_path}. Table already exists.")

else:
if background_mask is None:
feature_set = "default"
dilation = None
median_only = False
else:
print("Using background mask for calculating object measures.")
feature_set = "default_background_subtract"
dilation = 4
median_only = True

if s3:
img_path, fs = s3_utils.get_s3_path(img_path, bucket_name=s3_bucket_name,
service_endpoint=s3_service_endpoint,
credential_file=s3_credentials)
seg_path, fs = s3_utils.get_s3_path(seg_path, bucket_name=s3_bucket_name,
service_endpoint=s3_service_endpoint,
credential_file=s3_credentials)

mask_cache_path = os.path.join(os.path.dirname(out_path), "bg-mask.zarr")
background_mask = compute_sgn_background_mask(
image_path=img_path,
segmentation_path=seg_path,
image_key=input_key,
segmentation_key=input_key,
n_threads=n_threads,
cache_path=mask_cache_path,
)

compute_object_measures(
image_path=img_path,
segmentation_path=seg_path,
segmentation_table_path=table_path,
output_table_path=out_path,
image_key=input_key,
segmentation_key=input_key,
feature_set=feature_set,
s3_flag=s3,
component_list=component_list,
dilation=dilation,
median_only=median_only,
background_mask=background_mask,
n_threads=n_threads,
resolution=resolution,
)
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
import os
from typing import List, Optional, Tuple

import networkx as nx
Expand All @@ -7,7 +8,8 @@
from scipy.ndimage import distance_transform_edt, binary_dilation, binary_closing
from scipy.interpolate import interp1d

from flamingo_tools.segmentation.postprocessing import downscaled_centroids
from flamingo_tools.postprocessing.label_components import downscaled_centroids
from flamingo_tools.s3_utils import get_s3_path


def find_most_distant_nodes(G: nx.classes.graph.Graph, weight: str = 'weight') -> Tuple[float, float]:
Expand Down Expand Up @@ -528,6 +530,7 @@ def map_frequency(table: pd.DataFrame, animal: str = "mouse", otof: bool = False
Args:
table: Dataframe containing the segmentation.
animal: Select the Greenwood function parameters specific to a species. Either "mouse" or "gerbil".
otof: Use mapping by *Mueller, Hearing Research 202 (2005) 63-73* for OTOF cochleae.

Returns:
Dataframe containing frequency in an additional column 'frequency[kHz]'.
Expand Down Expand Up @@ -749,14 +752,18 @@ def tonotopic_mapping(
apex_higher: bool = True,
otof: bool = False,
) -> pd.DataFrame:
"""Tonotopic mapping of IHCs by supplying a table with component labels.
The mapping assigns a tonotopic label to each IHC according to the position along the length of the cochlea.
"""Tonotopic mapping of SGNs or IHCs by supplying a table with component labels.
The mapping assigns a tonotopic label to each instance according to the position along the length of the cochlea.

Args:
table: Dataframe of segmentation table.
component_label: List of component labels to evaluate.
components_mapping: Components to use for tonotopic mapping. Ignore components torn parallel to main canal.
cell_type: Cell type of segmentation.
animal: Animal specifier for species specific frequency mapping. Either "mouse" or "gerbil".
max_edge_distance: Maximal edge distance between graph nodes to create an edge between nodes.
apex_higher: Flag for identifying apex and base. Apex is set to node with higher y-value if True.
otof: Use mapping by *Mueller, Hearing Research 202 (2005) 63-73* for OTOF cochleae.

Returns:
Table with tonotopic label for cells.
Expand Down Expand Up @@ -811,3 +818,74 @@ def tonotopic_mapping(
table = map_frequency(table, animal=animal, otof=otof)

return table


def tonotopic_mapping_single(
table_path: str,
out_path: str,
force_overwrite: bool = False,
cell_type: str = "sgn",
animal: str = "mouse",
otof: bool = False,
apex_position: str = "apex_higher",
component_list: List[int] = [1],
component_mapping: Optional[List[int]] = None,
max_edge_distance: float = 30,
s3: bool = False,
s3_credentials: Optional[str] = None,
s3_bucket_name: Optional[str] = None,
s3_service_endpoint: Optional[str] = None,
**_
):
"""Tonotopic mapping of a single cochlea.
Each segmentation instance within a given component list is assigned a frequency[kHz], a run length and an offset.
The components used for the mapping itself can be a subset of the component list to adapt to broken components
along the Rosenthal's canal.
If the cochlea is broken in the direction of the Rosenthal's canal, the components have to be provided in a
continuous order which reflects the positioning within 3D.
The frequency is calculated using the Greenwood function using animal specific parameters.
The orientation of the mapping can be reversed using the apex position in reference to the y-coordinate.

Args:
table_path: File path to segmentation table.
out_path: Output path to segmentation table with new column "component_labels".
force_overwrite: Forcefully overwrite existing output path.
cell_type: Cell type of the segmentation. Currently supports "sgn" and "ihc".
animal: Animal for species specific frequency mapping. Either "mouse" or "gerbil".
otof: Use mapping by *Mueller, Hearing Research 202 (2005) 63-73* for OTOF cochleae.
apex_position: Identify position of apex and base. Apex is set to node with higher y-value per default.
component_list: List of components. Can be passed to obtain the number of instances within the component list.
components_mapping: Components to use for tonotopic mapping. Ignore components torn parallel to main canal.
max_edge_distance: Maximal edge distance between graph nodes to create an edge between nodes.
s3: Use S3 bucket.
s3_credentials:
s3_bucket_name:
s3_service_endpoint:
"""
if os.path.isdir(out_path):
raise ValueError(f"Output path {out_path} is a directory. Provide a path to a single output file.")

if s3:
tsv_path, fs = get_s3_path(table_path, bucket_name=s3_bucket_name,
service_endpoint=s3_service_endpoint, credential_file=s3_credentials)
with fs.open(tsv_path, "r") as f:
table = pd.read_csv(f, sep="\t")
else:
table = pd.read_csv(table_path, sep="\t")

apex_higher = (apex_position == "apex_higher")

# overwrite input file
if os.path.realpath(out_path) == os.path.realpath(table_path) and not s3:
force_overwrite = True

if os.path.isfile(out_path) and not force_overwrite:
print(f"Skipping {out_path}. Table already exists.")

else:
table = tonotopic_mapping(table, component_label=component_list, animal=animal,
cell_type=cell_type, component_mapping=component_mapping,
apex_higher=apex_higher, max_edge_distance=max_edge_distance,
otof=otof)

table.to_csv(out_path, sep="\t", index=False)
Loading
Loading