From ff2d5511671324ac1f0b2c6d6cbe57da52446b79 Mon Sep 17 00:00:00 2001 From: Martin Schilling Date: Fri, 21 Nov 2025 16:27:54 +0100 Subject: [PATCH 1/5] Re-factor component labeling for usage with CLI --- flamingo_tools/segmentation/postprocessing.py | 75 +------- .../repro_label_components.py | 175 +++++++++++------- 2 files changed, 111 insertions(+), 139 deletions(-) diff --git a/flamingo_tools/segmentation/postprocessing.py b/flamingo_tools/segmentation/postprocessing.py index 888944f..a585315 100644 --- a/flamingo_tools/segmentation/postprocessing.py +++ b/flamingo_tools/segmentation/postprocessing.py @@ -363,110 +363,43 @@ def graph_connected_components(coords: dict, max_edge_distance: float, min_compo def components_sgn( table: pd.DataFrame, - keyword: str = "distance_nn100", - threshold_erode: Optional[float] = None, min_component_length: int = 50, max_edge_distance: float = 30, - iterations_erode: int = 0, - postprocess_threshold: Optional[float] = None, - postprocess_components: Optional[List[int]] = None, ) -> List[List[int]]: """Eroding the SGN segmentation. Args: table: Dataframe of segmentation table. - keyword: Keyword of the dataframe column for erosion. - threshold_erode: Threshold of column value after erosion step with spatial statistics. min_component_length: Minimal length for filtering out connected components. max_edge_distance: Maximal distance in micrometer between points to create edges for connected components. - iterations_erode: Number of iterations for erosion. - postprocess_threshold: Post-process graph connected components by searching for points closer than threshold. - postprocess_components: Post-process specific graph connected components ([0] for largest component only). Returns: Subgraph components as lists of label_ids of dataframe. 
""" - if keyword not in table: - distance_avg = nearest_neighbor_distance(table, n_neighbors=100) - table.loc[:, keyword] = list(distance_avg) - centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"])) labels = [int(i) for i in list(table["label_id"])] - - distance_nn = list(table[keyword]) - distance_nn.sort() - - if len(table) < 20000: - min_cells = None - average_dist = int(distance_nn[int(len(table) * 0.8)]) - threshold = threshold_erode if threshold_erode is not None else average_dist - else: - min_cells = 20000 - threshold = threshold_erode if threshold_erode is not None else 40 - - if iterations_erode != 0 and iterations_erode is not None: - print(f"Using threshold of {threshold} micrometer for eroding segmentation with keyword {keyword}.") - new_subset = erode_subset(table.copy(), iterations=iterations_erode, - threshold=threshold, min_cells=min_cells, keyword=keyword) - else: - new_subset = table.copy() - - # create graph from coordinates of eroded subset - centroids_subset = list(zip(new_subset["anchor_x"], new_subset["anchor_y"], new_subset["anchor_z"])) - labels_subset = [int(i) for i in list(new_subset["label_id"])] coords = {} - for index, element in zip(labels_subset, centroids_subset): + for index, element in zip(labels, centroids): coords[index] = element components, _ = graph_connected_components(coords, max_edge_distance, min_component_length) - # add original coordinates closer to eroded component than threshold - if postprocess_threshold is not None: - if postprocess_components is None: - pp_components = components - else: - pp_components = [components[i] for i in postprocess_components] - - add_coords = [] - for label_id, centr in zip(labels, centroids): - if label_id not in labels_subset: - add_coord = [] - for comp_index, component in enumerate(pp_components): - for comp_label in component: - dist = math.dist(centr, centroids[comp_label - 1]) - if dist <= postprocess_threshold: - add_coord.append([comp_index, label_id]) 
- break - if len(add_coord) != 0: - add_coords.append(add_coord) - if len(add_coords) != 0: - for c in add_coords: - components[c[0][0]].append(c[0][1]) - return components def label_components_sgn( table: pd.DataFrame, min_size: int = 1000, - threshold_erode: Optional[float] = None, min_component_length: int = 50, max_edge_distance: float = 30, - iterations_erode: int = 0, - postprocess_threshold: Optional[float] = None, - postprocess_components: Optional[List[int]] = None, ) -> List[int]: """Label SGN components using graph connected components. Args: table: Dataframe of segmentation table. min_size: Minimal number of pixels for filtering small instances. - threshold_erode: Threshold of column value after erosion step with spatial statistics. min_component_length: Minimal length for filtering out connected components. max_edge_distance: Maximal distance in micrometer between points to create edges for connected components. - iterations_erode: Number of iterations for erosion. - postprocess_threshold: Post-process graph connected components by searching for points closer than threshold. - postprocess_components: Post-process specific graph connected components ([0] for largest component only). Returns: List of component label for each point in dataframe. 
0 - background, then in descending order of size @@ -476,10 +409,8 @@ def label_components_sgn( entries_filtered = table[table.n_pixels < min_size] table = table[table.n_pixels >= min_size] - components = components_sgn(table, threshold_erode=threshold_erode, min_component_length=min_component_length, - max_edge_distance=max_edge_distance, iterations_erode=iterations_erode, - postprocess_threshold=postprocess_threshold, - postprocess_components=postprocess_components) + components = components_sgn(table, min_component_length=min_component_length, + max_edge_distance=max_edge_distance) # add size-filtered objects to have same initial length table = pd.concat([table, entries_filtered], ignore_index=True) diff --git a/reproducibility/label_components/repro_label_components.py b/reproducibility/label_components/repro_label_components.py index be98510..34912b0 100644 --- a/reproducibility/label_components/repro_label_components.py +++ b/reproducibility/label_components/repro_label_components.py @@ -1,7 +1,7 @@ import argparse import json import os -from typing import Optional +from typing import List, Optional import pandas as pd from flamingo_tools.s3_utils import get_s3_path @@ -67,86 +67,117 @@ def label_custom_components(tsv_table, custom_dict): return tsv_table -def repro_label_components( - ddict: dict, - output_dir: str, +def _load_json_as_list(ddict_path: str) -> List[dict]: + with open(ddict_path, "r") as f: + data = json.loads(f.read()) + # ensure the result is always a list + return data if isinstance(data, list) else [data] + + +def label_components_single( + table_path: str, + out_path: str, + cell_type: str = "sgn", + component_list: List[int] = [1], + max_edge_distance: float = 30, + min_component_length: int = 50, + min_size: int = 1000, + s3: bool = False, s3_credentials: Optional[str] = None, s3_bucket_name: Optional[str] = None, s3_service_endpoint: Optional[str] = None, + custom_dic: Optional[dict] = None, + **_ ): - default_cell_type = "sgn" - 
default_component_list = [1] - default_iterations_erode = None - default_max_edge_distance = 30 - default_min_length = 50 - default_min_size = 1000 - default_seg_channel = "SGN_v2" - default_threshold_erode = None - - with open(ddict, "r") as myfile: - data = myfile.read() - param_dicts = json.loads(data) - - for dic in param_dicts: - cochlea = dic["cochlea"] - print(f"\n{cochlea}") - - cell_type = dic.get("cell_type", default_cell_type) - component_list = dic.get("component_list", default_component_list) - iterations_erode = dic.get("iterations_erode", default_iterations_erode) - max_edge_distance = dic.get("max_edge_distance", default_max_edge_distance) - min_component_length = dic.get("min_component_length", default_min_length) - min_size = dic.get("min_size", default_min_size) - table_name = dic.get("segmentation_channel", default_seg_channel) - threshold_erode = dic.get("threshold_erode", default_threshold_erode) - - s3_path = os.path.join(f"{cochlea}", "tables", table_name, "default.tsv") - tsv_path, fs = get_s3_path(s3_path, bucket_name=s3_bucket_name, + """Process a single cochlea using one set of parameters or a custom_dic. 
+ """ + if s3: + tsv_path, fs = get_s3_path(table_path, bucket_name=s3_bucket_name, service_endpoint=s3_service_endpoint, credential_file=s3_credentials) - with fs.open(tsv_path, "r") as f: - table = pd.read_csv(f, sep="\t") - - if "custom_dic" in list(dic.keys()): - print(len(table[table["component_labels"] == 1])) - tsv_table = label_custom_components(table, dic["custom_dic"]) + with fs.open(tsv_path, "r") as f: + table = pd.read_csv(f, sep="\t") + + if custom_dic is not None: + tsv_table = label_custom_components(table, custom_dic) + else: + if cell_type == "sgn": + tsv_table = label_components_sgn(table, min_size=min_size, + min_component_length=min_component_length, + max_edge_distance=max_edge_distance) + elif cell_type == "ihc": + tsv_table = label_components_ihc(table, min_size=min_size, + min_component_length=min_component_length, + max_edge_distance=max_edge_distance) else: - if cell_type == "sgn": - tsv_table = label_components_sgn(table, min_size=min_size, - threshold_erode=threshold_erode, - min_component_length=min_component_length, - max_edge_distance=max_edge_distance, - iterations_erode=iterations_erode) - elif cell_type == "ihc": - tsv_table = label_components_ihc(table, min_size=min_size, - min_component_length=min_component_length, - max_edge_distance=max_edge_distance) - else: - raise ValueError("Choose a supported cell type. Either 'sgn' or 'ihc'.") + raise ValueError("Choose a supported cell type. 
Either 'sgn' or 'ihc'.") - custom_comp = len(tsv_table[tsv_table["component_labels"].isin(component_list)]) - print(f"Total {cell_type.upper()}s: {len(tsv_table)}") - if component_list == [1]: - print(f"Largest component has {custom_comp} {cell_type.upper()}s.") - else: - for comp in component_list: - print(f"Component {comp} has {len(tsv_table[tsv_table["component_labels"] == comp])} instances.") - print(f"Custom component(s) have {custom_comp} {cell_type.upper()}s.") + custom_comp = len(tsv_table[tsv_table["component_labels"].isin(component_list)]) + print(f"Total {cell_type.upper()}s: {len(tsv_table)}") + if component_list == [1]: + print(f"Largest component has {custom_comp} {cell_type.upper()}s.") + else: + for comp in component_list: + num_instances = len(tsv_table[tsv_table["component_labels"] == comp]) + print(f"Component {comp} has {num_instances} instances.") + print(f"Custom component(s) have {custom_comp} {cell_type.upper()}s.") - cochlea_str = "-".join(cochlea.split("_")) - table_str = "-".join(table_name.split("_")) - os.makedirs(output_dir, exist_ok=True) - out_path = os.path.join(output_dir, "_".join([cochlea_str, f"{table_str}.tsv"])) + tsv_table.to_csv(out_path, sep="\t", index=False) - tsv_table.to_csv(out_path, sep="\t", index=False) + +def repro_label_components( + output_path: str, + table_path: Optional[str] = None, + ddict: Optional[str] = None, + **kwargs +): + """Wrapper function for labeling connected components using a segmentation table. + The function is used to distinguish between a passed parameter dictionary in JSON format + and the explicit setting of parameters. 
+ """ + if ddict is None: + label_components_single(table_path, output_path, **kwargs) + else: + param_dicts = _load_json_as_list(ddict) + for params in param_dicts: + + cochlea = params["cochlea"] + print(f"\n{cochlea}") + seg_channel = params["segmentation_channel"] + table_path = os.path.join(f"{cochlea}", "tables", seg_channel, "default.tsv") + + if os.path.isdir(output_path): + cochlea_str = "-".join(cochlea.split("_")) + table_str = "-".join(seg_channel.split("_")) + save_path = os.path.join(output_path, "_".join([cochlea_str, f"{table_str}.tsv"])) + else: + save_path = output_path + label_components_single(table_path=table_path, out_path=save_path, **params) def main(): parser = argparse.ArgumentParser( description="Script to label segmentation using a segmentation table and graph connected components.") - parser.add_argument("-i", "--input", type=str, required=True, help="Input JSON dictionary.") - parser.add_argument("-o", "--output", type=str, required=True, help="Output directory.") + parser.add_argument("-o", "--output", type=str, required=True, + help="Output path. Either directory or specific file.") + + parser.add_argument("-i", "--input", type=str, default=None, help="Input path to segmentation table.") + parser.add_argument("-j", "--json", type=str, default=None, help="Input JSON dictionary.") + + parser.add_argument("--cell_type", type=str, default="sgn", + help="Cell type of segmentation. 
Either 'sgn' or 'ihc'.") + + # options for post-processing + parser.add_argument("--min_size", type=int, default=1000, + help="Minimal number of pixels for filtering small instances.") + parser.add_argument("--min_component_length", type=int, default=50, + help="Minimal length for filtering out connected components.") + parser.add_argument("--max_edge_distance", type=float, default=30, + help="Maximal distance in micrometer between points to create edges for connected components.") + parser.add_argument("-c", "--components", type=str, nargs="+", default=[1], help="List of connected components.") + # options for S3 bucket + parser.add_argument("--s3", action="store_true", help="Flag for using S3 bucket.") parser.add_argument("--s3_credentials", type=str, default=None, help="Input file containing S3 credentials. " "Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.") @@ -158,8 +189,18 @@ def main(): args = parser.parse_args() repro_label_components( - args.input, args.output, - args.s3_credentials, args.s3_bucket_name, args.s3_service_endpoint, + output_path=args.output, + table_path=args.input, + ddict=args.json, + cell_type=args.cell_type, + component_list=args.components, + max_edge_distance=args.max_edge_distance, + min_component_length=args.min_component_length, + min_size=args.min_size, + s3=args.s3, + s3_credentials=args.s3_credentials, + s3_bucket_name=args.s3_bucket_name, + s3_service_endpoint=args.s3_service_endpoint, ) From f6e2183c10a167992a3268eb77ad3435bf4ce374 Mon Sep 17 00:00:00 2001 From: Martin Schilling Date: Mon, 24 Nov 2025 15:20:49 +0100 Subject: [PATCH 2/5] Re-ordering post-processing functions --- flamingo_tools/measurements.py | 2 +- .../chreef_utils.py | 0 .../cochlea_mapping.py | 2 +- .../label_components.py} | 0 .../sgn_subtype_utils.py | 0 .../postprocessing/synapse_per_ihc_utils.py | 50 +++++++++++++++++++ flamingo_tools/test_data.py | 2 +- .../repro_equidistant_centers.py | 2 +- .../repro_label_components.py | 2 +- 
.../repro_tonotopic_mapping.py | 2 +- scripts/assign_subtypes.py | 2 +- scripts/export_lower_resolution.py | 2 +- scripts/export_lower_resolution_subtypes.py | 2 +- scripts/figures/plot_SGNsub_thresholds.py | 2 +- scripts/figures/plot_fig3.py | 2 +- scripts/figures/plot_supp_fig3.py | 3 +- scripts/la-vision/segment_sgns.py | 2 +- .../evaluate_marker_annotations_subtype.py | 4 +- scripts/measurements/measure_synapses.py | 33 +++--------- scripts/measurements/sgn_subtypes.py | 2 +- scripts/measurements/subtype_overview.py | 2 +- scripts/prediction/expand_seg_table.py | 2 +- scripts/prediction/postprocess_seg.py | 4 +- scripts/prediction/tonotopic_mapping.py | 2 +- .../add_synapse_per_ihc.py | 24 ++++----- test/test_segmentation/test_postprocessing.py | 10 ++-- 26 files changed, 95 insertions(+), 65 deletions(-) rename flamingo_tools/{segmentation => postprocessing}/chreef_utils.py (100%) rename flamingo_tools/{segmentation => postprocessing}/cochlea_mapping.py (99%) rename flamingo_tools/{segmentation/postprocessing.py => postprocessing/label_components.py} (100%) rename flamingo_tools/{segmentation => postprocessing}/sgn_subtype_utils.py (100%) create mode 100644 flamingo_tools/postprocessing/synapse_per_ihc_utils.py diff --git a/flamingo_tools/measurements.py b/flamingo_tools/measurements.py index ab6fc9d..528de37 100644 --- a/flamingo_tools/measurements.py +++ b/flamingo_tools/measurements.py @@ -22,7 +22,7 @@ from tqdm import tqdm from .file_utils import read_image_data -from .segmentation.postprocessing import compute_table_on_the_fly +from .postprocessing.label_components import compute_table_on_the_fly import flamingo_tools.s3_utils as s3_utils diff --git a/flamingo_tools/segmentation/chreef_utils.py b/flamingo_tools/postprocessing/chreef_utils.py similarity index 100% rename from flamingo_tools/segmentation/chreef_utils.py rename to flamingo_tools/postprocessing/chreef_utils.py diff --git a/flamingo_tools/segmentation/cochlea_mapping.py 
b/flamingo_tools/postprocessing/cochlea_mapping.py similarity index 99% rename from flamingo_tools/segmentation/cochlea_mapping.py rename to flamingo_tools/postprocessing/cochlea_mapping.py index 4209015..b6df6ef 100644 --- a/flamingo_tools/segmentation/cochlea_mapping.py +++ b/flamingo_tools/postprocessing/cochlea_mapping.py @@ -7,7 +7,7 @@ from scipy.ndimage import distance_transform_edt, binary_dilation, binary_closing from scipy.interpolate import interp1d -from flamingo_tools.segmentation.postprocessing import downscaled_centroids +from flamingo_tools.postprocessing.label_components import downscaled_centroids def find_most_distant_nodes(G: nx.classes.graph.Graph, weight: str = 'weight') -> Tuple[float, float]: diff --git a/flamingo_tools/segmentation/postprocessing.py b/flamingo_tools/postprocessing/label_components.py similarity index 100% rename from flamingo_tools/segmentation/postprocessing.py rename to flamingo_tools/postprocessing/label_components.py diff --git a/flamingo_tools/segmentation/sgn_subtype_utils.py b/flamingo_tools/postprocessing/sgn_subtype_utils.py similarity index 100% rename from flamingo_tools/segmentation/sgn_subtype_utils.py rename to flamingo_tools/postprocessing/sgn_subtype_utils.py diff --git a/flamingo_tools/postprocessing/synapse_per_ihc_utils.py b/flamingo_tools/postprocessing/synapse_per_ihc_utils.py new file mode 100644 index 0000000..f301485 --- /dev/null +++ b/flamingo_tools/postprocessing/synapse_per_ihc_utils.py @@ -0,0 +1,50 @@ +SYNAPSE_DICT = { + "M_LR_000226_L": {"synapse_table_name": "synapse_v3_ihc_v4c", "ihc_table_name": "IHC_v4c"}, + "M_LR_000226_R": {"synapse_table_name": "synapse_v3_ihc_v4c", "ihc_table_name": "IHC_v4c"}, + "M_LR_000227_L": {"synapse_table_name": "synapse_v3_ihc_v4c", "ihc_table_name": "IHC_v4c"}, + "M_LR_000227_R": {"synapse_table_name": "synapse_v3_ihc_v4c", "ihc_table_name": "IHC_v4c"}, + "G_EK_000233_L": {"synapse_table_name": "synapse_v3_ihc_v6", "ihc_table_name": "IHC_v6"}, + 
"G_LR_000233_R": {"synapse_table_name": "synapse_v3_ihc_v6", "ihc_table_name": "IHC_v6"}, + + # fHC + "M_AMD_N139_L": {"synapse_table_name": "synapse_v3_ihc_v4b", "ihc_table_name": "IHC_v4b", + "component_list": [1, 4]}, + "M_AMD_N139_R": {"synapse_table_name": "synapse_v3_ihc_v4b", "ihc_table_name": "IHC_v4b", + "component_list": [2, 5, 3]}, + "M_AMD_N153_L": {"synapse_table_name": "synapse_v3_ihc_v4b", "ihc_table_name": "IHC_v4b"}, + "M_AMD_N153_R": {"synapse_table_name": "synapse_v3_ihc_v4b", "ihc_table_name": "IHC_v4b", + "component_list": [7, 1, 4]}, + + # PELCOfHC1 + "M_AMD_N140_L": {"synapse_table_name": "synapse_v3_ihc_v4b", "ihc_table_name": "IHC_v4b", + "component_list": [1, 5]}, + "M_AMD_N140_R": {"synapse_table_name": "synapse_v3_ihc_v4b", "ihc_table_name": "IHC_v4b"}, + "M_AMD_N142_L": {"synapse_table_name": "synapse_v3_ihc_v4b", "ihc_table_name": "IHC_v4b"}, + "M_AMD_N142_R": {"synapse_table_name": "synapse_v3_ihc_v4b", "ihc_table_name": "IHC_v4b", + "component_list": [2, 4, 3]}, + + # PELCOfHC2 + "M_AMD_N75_L": {"synapse_table_name": "synapse_v3_ihc_v4b", "ihc_table_name": "IHC_v4b", + "component_list": [1, 3]}, + "M_AMD_N75_R": {"synapse_table_name": "synapse_v3_ihc_v4b", "ihc_table_name": "IHC_v4b", + "component_list": [1, 5, 12, 4, 2, 13, 7]}, + "M_AMD_N129_L": {"synapse_table_name": "synapse_v3_ihc_v4b", "ihc_table_name": "IHC_v4b", + "component_list": [2, 5, 3, 6]}, + "M_AMD_N129_R": {"synapse_table_name": "synapse_v3_ihc_v4b", "ihc_table_name": "IHC_v4b", + "component_list": [4, 2, 6]}, + + # PELCODISCO2 + "M_AMD_N88_L": {"synapse_table_name": "synapse_v3_ihc_v4b", "ihc_table_name": "IHC_v4b"}, + "M_AMD_N89_R": {"synapse_table_name": "synapse_v3_ihc_v4b", "ihc_table_name": "IHC_v4b", + "component_list": [1, 2]}, + + # PELCODISCO2 + "M_AMD_N95_L": {"synapse_table_name": "synapse_v3_ihc_v4b", "ihc_table_name": "IHC_v4b", + "component_list": [1, 4, 3, 9, 17, 2]}, + "M_AMD_N95_R": {"synapse_table_name": "synapse_v3_ihc_v4b", "ihc_table_name": 
"IHC_v4b", + "component_list": [8, 4, 2, 29, 6, 9, 3]}, + "M_AMD_N97_L": {"synapse_table_name": "synapse_v3_ihc_v4b", "ihc_table_name": "IHC_v4b", + "component_list": [2, 1, 3]}, + "M_AMD_N97_R": {"synapse_table_name": "synapse_v3_ihc_v4b", "ihc_table_name": "IHC_v4b", + "component_list": [2, 5]}, +} \ No newline at end of file diff --git a/flamingo_tools/test_data.py b/flamingo_tools/test_data.py index ad43c22..61fa593 100644 --- a/flamingo_tools/test_data.py +++ b/flamingo_tools/test_data.py @@ -10,7 +10,7 @@ from skimage.measure import label from .file_utils import get_cache_dir -from .segmentation.postprocessing import compute_table_on_the_fly +from .postprocessing.label_components import compute_table_on_the_fly SEGMENTATION_URL = "https://owncloud.gwdg.de/index.php/s/kwoGRYiJRRrswgw/download" diff --git a/reproducibility/block_extraction/repro_equidistant_centers.py b/reproducibility/block_extraction/repro_equidistant_centers.py index 1f81ef5..036b7e8 100644 --- a/reproducibility/block_extraction/repro_equidistant_centers.py +++ b/reproducibility/block_extraction/repro_equidistant_centers.py @@ -5,7 +5,7 @@ import pandas as pd from flamingo_tools.s3_utils import get_s3_path -from flamingo_tools.segmentation.cochlea_mapping import equidistant_centers +from flamingo_tools.postprocessing.cochlea_mapping import equidistant_centers def repro_equidistant_centers( diff --git a/reproducibility/label_components/repro_label_components.py b/reproducibility/label_components/repro_label_components.py index 34912b0..e4f666d 100644 --- a/reproducibility/label_components/repro_label_components.py +++ b/reproducibility/label_components/repro_label_components.py @@ -5,7 +5,7 @@ import pandas as pd from flamingo_tools.s3_utils import get_s3_path -from flamingo_tools.segmentation.postprocessing import label_components_sgn, label_components_ihc +from flamingo_tools.postprocessing.label_components import label_components_sgn, label_components_ihc def 
label_custom_components(tsv_table, custom_dict): diff --git a/reproducibility/tonotopic_mapping/repro_tonotopic_mapping.py b/reproducibility/tonotopic_mapping/repro_tonotopic_mapping.py index ed0d16c..1700719 100644 --- a/reproducibility/tonotopic_mapping/repro_tonotopic_mapping.py +++ b/reproducibility/tonotopic_mapping/repro_tonotopic_mapping.py @@ -5,7 +5,7 @@ import pandas as pd from flamingo_tools.s3_utils import get_s3_path -from flamingo_tools.segmentation.cochlea_mapping import tonotopic_mapping +from flamingo_tools.postprocessing.cochlea_mapping import tonotopic_mapping def repro_tonotopic_mapping( diff --git a/scripts/assign_subtypes.py b/scripts/assign_subtypes.py index 8217647..4a364ec 100644 --- a/scripts/assign_subtypes.py +++ b/scripts/assign_subtypes.py @@ -4,7 +4,7 @@ import pandas as pd from flamingo_tools.s3_utils import get_s3_path, BUCKET_NAME, SERVICE_ENDPOINT -from flamingo_tools.segmentation.sgn_subtype_utils import STAIN_TO_TYPE, COCHLEAE +from flamingo_tools.postprocessing.sgn_subtype_utils import STAIN_TO_TYPE, COCHLEAE # from skimage.segmentation import relabel_sequential diff --git a/scripts/export_lower_resolution.py b/scripts/export_lower_resolution.py index 5783410..65d0076 100644 --- a/scripts/export_lower_resolution.py +++ b/scripts/export_lower_resolution.py @@ -10,7 +10,7 @@ from elf.parallel import isin from flamingo_tools.s3_utils import get_s3_path, BUCKET_NAME, SERVICE_ENDPOINT -from flamingo_tools.segmentation.postprocessing import filter_cochlea_volume, filter_cochlea_volume_single +from flamingo_tools.postprocessing.label_components import filter_cochlea_volume, filter_cochlea_volume_single # from skimage.segmentation import relabel_sequential diff --git a/scripts/export_lower_resolution_subtypes.py b/scripts/export_lower_resolution_subtypes.py index 96681ed..3a3d915 100644 --- a/scripts/export_lower_resolution_subtypes.py +++ b/scripts/export_lower_resolution_subtypes.py @@ -7,7 +7,7 @@ import zarr from 
flamingo_tools.s3_utils import get_s3_path, BUCKET_NAME, SERVICE_ENDPOINT -from flamingo_tools.segmentation.sgn_subtype_utils import STAIN_TO_TYPE, COCHLEAE +from flamingo_tools.postprocessing.sgn_subtype_utils import STAIN_TO_TYPE, COCHLEAE # from skimage.segmentation import relabel_sequential diff --git a/scripts/figures/plot_SGNsub_thresholds.py b/scripts/figures/plot_SGNsub_thresholds.py index 3f71ca2..0d3fb4b 100644 --- a/scripts/figures/plot_SGNsub_thresholds.py +++ b/scripts/figures/plot_SGNsub_thresholds.py @@ -8,7 +8,7 @@ import pandas as pd from flamingo_tools.s3_utils import get_s3_path -from flamingo_tools.segmentation.sgn_subtype_utils import CUSTOM_THRESHOLDS, COCHLEAE, ALIAS +from flamingo_tools.postprocessing.sgn_subtype_utils import CUSTOM_THRESHOLDS, COCHLEAE, ALIAS png_dpi = 300 diff --git a/scripts/figures/plot_fig3.py b/scripts/figures/plot_fig3.py index 966d369..3b01e19 100644 --- a/scripts/figures/plot_fig3.py +++ b/scripts/figures/plot_fig3.py @@ -18,7 +18,7 @@ from flamingo_tools.s3_utils import BUCKET_NAME, create_s3_target, get_s3_path from util import frequency_mapping, SYNAPSE_DIR_ROOT, custom_formatter_1, average_by_fraction from util import prism_style, prism_cleanup_axes, export_legend, get_marker_handle, get_flatline_handle -from flamingo_tools.segmentation.sgn_subtype_utils import stain_to_type, COCHLEAE, ALIAS +from flamingo_tools.postprocessing.sgn_subtype_utils import stain_to_type, COCHLEAE, ALIAS INPUT_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/frequency_mapping/M_LR_000227_R/scale3" SYNAPSE_DIR = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/synapses" diff --git a/scripts/figures/plot_supp_fig3.py b/scripts/figures/plot_supp_fig3.py index 75b39f4..3b0e814 100644 --- a/scripts/figures/plot_supp_fig3.py +++ b/scripts/figures/plot_supp_fig3.py @@ -10,9 +10,8 @@ from flamingo_tools.s3_utils import get_s3_path from sklearn.metrics import confusion_matrix, 
ConfusionMatrixDisplay -from flamingo_tools.segmentation.sgn_subtype_utils import COCHLEAE, CUSTOM_THRESHOLDS +from flamingo_tools.postprocessing.sgn_subtype_utils import ALIAS, COCHLEAE, CUSTOM_THRESHOLDS from util import prism_style, prism_cleanup_axes, export_legend, get_flatline_handle -from flamingo_tools.segmentation.sgn_subtype_utils import ALIAS png_dpi = 300 FILE_EXTENSION = "png" diff --git a/scripts/la-vision/segment_sgns.py b/scripts/la-vision/segment_sgns.py index 97c117a..cf7468c 100644 --- a/scripts/la-vision/segment_sgns.py +++ b/scripts/la-vision/segment_sgns.py @@ -3,7 +3,7 @@ import pandas as pd from flamingo_tools.segmentation import run_unet_prediction -from flamingo_tools.segmentation.postprocessing import label_components_sgn +from flamingo_tools.postprocessing.label_components import label_components_sgn from mobie import add_segmentation from mobie.metadata import add_remote_project_metadata diff --git a/scripts/measurements/evaluate_marker_annotations_subtype.py b/scripts/measurements/evaluate_marker_annotations_subtype.py index c79f317..3e1ebef 100644 --- a/scripts/measurements/evaluate_marker_annotations_subtype.py +++ b/scripts/measurements/evaluate_marker_annotations_subtype.py @@ -7,8 +7,8 @@ from flamingo_tools.s3_utils import get_s3_path from flamingo_tools.file_utils import read_image_data -from flamingo_tools.segmentation.chreef_utils import localize_median_intensities, find_annotations -from flamingo_tools.segmentation.sgn_subtype_utils import CUSTOM_THRESHOLDS, COCHLEAE +from flamingo_tools.postprocessing.chreef_utils import localize_median_intensities, find_annotations +from flamingo_tools.postprocessing.sgn_subtype_utils import CUSTOM_THRESHOLDS, COCHLEAE MARKER_DIR_SUBTYPE = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/SGN_subtypes" diff --git a/scripts/measurements/measure_synapses.py b/scripts/measurements/measure_synapses.py index 826f021..fae16c9 100644 --- a/scripts/measurements/measure_synapses.py +++ 
b/scripts/measurements/measure_synapses.py @@ -6,28 +6,10 @@ import pandas as pd from flamingo_tools.s3_utils import BUCKET_NAME, create_s3_target - -OUTPUT_FOLDER = "./ihc_counts" - -COCHLEAE = [ - "M_LR_000226_L", - "M_LR_000226_R", - "M_LR_000227_L", - "M_LR_000227_R", - "G_EK_000233_L", - "G_LR_000233_R", - -] +from flamingo_tools.postprocessing.synapse_per_ihc_utils import SYNAPSE_DICT -SYNAPSE_DICT = { - "M_LR_000226_L": {"synapse_table_name": "synapse_v3_ihc_v4c", "ihc_table_name": "IHC_v4c"}, - "M_LR_000226_R": {"synapse_table_name": "synapse_v3_ihc_v4c", "ihc_table_name": "IHC_v4c"}, - "M_LR_000227_L": {"synapse_table_name": "synapse_v3_ihc_v4c", "ihc_table_name": "IHC_v4c"}, - "M_LR_000227_R": {"synapse_table_name": "synapse_v3_ihc_v4c", "ihc_table_name": "IHC_v4c"}, - "G_EK_000233_L": {"synapse_table_name": "synapse_v3_ihc_v6", "ihc_table_name": "IHC_v6"}, - "G_LR_000233_R": {"synapse_table_name": "synapse_v3_ihc_v6", "ihc_table_name": "IHC_v6"}, -} +OUTPUT_FOLDER = "./ihc_counts" def check_project(cochleae, output_folder, plot=False, save_ihc_table=False, max_dist=None): @@ -37,12 +19,12 @@ def check_project(cochleae, output_folder, plot=False, save_ihc_table=False, max if cochlea in SYNAPSE_DICT.keys(): synapse_table_name = SYNAPSE_DICT[cochlea]["synapse_table_name"] ihc_table_name = SYNAPSE_DICT[cochlea]["ihc_table_name"] + component_id = SYNAPSE_DICT[cochlea].get("component_list", [1]) else: - synapse_table_name = "synapse_v3_ihc_v4c" - ihc_table_name = "IHC_v4c" - - component_id = [1] + synapse_table_name = "synapse_v3_ihc_v4b" + ihc_table_name = "IHC_v4b" + component_id = [1] if cochlea == "M_AMD_OTOF1_L": synapse_table_name = "synapse_v3_ihc_v4b" @@ -132,7 +114,8 @@ def main(): description="Assign each segmentation instance a marker based on annotation thresholds." 
) - parser.add_argument("-c", "--cochlea", type=str, nargs="+", default=COCHLEAE, help="Cochlea(e) to process.") + parser.add_argument("-c", "--cochlea", type=str, nargs="+", default=list(SYNAPSE_DICT.keys()), + help="Cochlea(e) to process.") parser.add_argument("-o", "--output", type=str, default=OUTPUT_FOLDER, help="Output directory.") args = parser.parse_args() diff --git a/scripts/measurements/sgn_subtypes.py b/scripts/measurements/sgn_subtypes.py index e9e0e9e..e2ffd9b 100644 --- a/scripts/measurements/sgn_subtypes.py +++ b/scripts/measurements/sgn_subtypes.py @@ -9,7 +9,7 @@ from flamingo_tools.s3_utils import BUCKET_NAME, create_s3_target from flamingo_tools.measurements import compute_object_measures -from flamingo_tools.segmentation.sgn_subtype_utils import COCHLEAE, stain_to_type +from flamingo_tools.postprocessing.sgn_subtype_utils import COCHLEAE, stain_to_type # Define the animal specific octave bands. diff --git a/scripts/measurements/subtype_overview.py b/scripts/measurements/subtype_overview.py index 20fe4c4..1b7a51f 100644 --- a/scripts/measurements/subtype_overview.py +++ b/scripts/measurements/subtype_overview.py @@ -4,7 +4,7 @@ import numpy as np import pandas as pd from flamingo_tools.s3_utils import BUCKET_NAME, create_s3_target -from sgn_subtypes import COCHLEAE +from flamingo_tools.postprocessing.sgn_subtype_utils import COCHLEAE def get_overview(cochlea, seg_name, component_ids): diff --git a/scripts/prediction/expand_seg_table.py b/scripts/prediction/expand_seg_table.py index b8a71ff..195e6d0 100644 --- a/scripts/prediction/expand_seg_table.py +++ b/scripts/prediction/expand_seg_table.py @@ -4,7 +4,7 @@ import pandas as pd -import flamingo_tools.segmentation.postprocessing as postprocessing +import flamingo_tools.postprocessing.label_components as postprocessing import flamingo_tools.s3_utils as s3_utils diff --git a/scripts/prediction/postprocess_seg.py b/scripts/prediction/postprocess_seg.py index 74a255f..5eb8541 100644 --- 
a/scripts/prediction/postprocess_seg.py +++ b/scripts/prediction/postprocess_seg.py @@ -6,8 +6,8 @@ import flamingo_tools.s3_utils as s3_utils from flamingo_tools.segmentation import filter_segmentation -from flamingo_tools.segmentation.postprocessing import nearest_neighbor_distance, local_ripleys_k, neighbors_in_radius -from flamingo_tools.segmentation.postprocessing import label_components_sgn +from flamingo_tools.postprocessing.label_components import nearest_neighbor_distance, local_ripleys_k, neighbors_in_radius +from flamingo_tools.postprocessing.label_components import label_components_sgn # TODO needs updates diff --git a/scripts/prediction/tonotopic_mapping.py b/scripts/prediction/tonotopic_mapping.py index f5d9f31..930ccd7 100644 --- a/scripts/prediction/tonotopic_mapping.py +++ b/scripts/prediction/tonotopic_mapping.py @@ -3,7 +3,7 @@ import pandas as pd import flamingo_tools.s3_utils as s3_utils -from flamingo_tools.segmentation.cochlea_mapping import tonotopic_mapping +from flamingo_tools.postprocessing.cochlea_mapping import tonotopic_mapping def main(): diff --git a/scripts/synapse_marker_detection/add_synapse_per_ihc.py b/scripts/synapse_marker_detection/add_synapse_per_ihc.py index 2ffa434..12065df 100644 --- a/scripts/synapse_marker_detection/add_synapse_per_ihc.py +++ b/scripts/synapse_marker_detection/add_synapse_per_ihc.py @@ -5,14 +5,7 @@ import pandas as pd from flamingo_tools.s3_utils import get_s3_path, BUCKET_NAME, SERVICE_ENDPOINT - -COCHLEAE = { - "M_LR_000226_L": {"seg_name": "IHC_v4c", "component_list": [1]}, - "M_LR_000226_R": {"seg_name": "IHC_v4c", "component_list": [1]}, - "M_LR_000227_L": {"seg_name": "IHC_v4c", "component_list": [1]}, - "M_LR_000227_R": {"seg_name": "IHC_v4c", "component_list": [1]}, - "M_AMD_OTOF1_L": {"seg_name": "IHC_v4b", "component_list": [3, 11]}, -} +from flamingo_tools.postprocessing.synapse_per_ihc_utils import SYNAPSE_DICT COCHLEA_DIR = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet" 
OUT_DIR = f"{COCHLEA_DIR}/mobie_project/cochlea-lightsheet/tables/syn_per_ihc" @@ -28,9 +21,9 @@ def add_syn_per_ihc(args): for cochlea in args.cochlea: if args.seg_version is None: - seg_version = COCHLEAE[cochlea]["seg_name"] + seg_version = SYNAPSE_DICT[cochlea]["ihc_table_name"] else: - seg_version = args.seg_version + seg_version = "IHC_v4b" print(f"Evaluating cochlea {cochlea}.") @@ -38,9 +31,9 @@ def add_syn_per_ihc(args): syn_per_ihc_dir = f"{COCHLEA_DIR}/predictions/synapses/ihc_counts_{ihc_version}" if args.component_list is None: - component_list = COCHLEAE[cochlea]["component_list"] + component_list = SYNAPSE_DICT[cochlea].get("component_list", [1]) else: - component_list = args.component_list + component_list = [1] s3_path = os.path.join(f"{cochlea}", "tables", seg_version, "default.tsv") tsv_path, fs = get_s3_path(s3_path, bucket_name=BUCKET_NAME, @@ -50,6 +43,10 @@ def add_syn_per_ihc(args): # synapse_table syn_path = os.path.join(syn_per_ihc_dir, f"ihc_count_{cochlea}.tsv") + if not os.path.isfile(syn_path): + print(f"Skipping cochlea {cochlea}. 
Synapse table {syn_path} does not exist.") + continue + with open(syn_path, 'r') as f: syn_table = pd.read_csv(f, sep="\t") @@ -81,7 +78,8 @@ def add_syn_per_ihc(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument("-c", "--cochlea", type=str, nargs="+", default=COCHLEAE, help="Cochlea(e) to process.") + parser.add_argument("-c", "--cochlea", type=str, nargs="+", default=list(SYNAPSE_DICT.keys()), + help="Cochlea(e) to process.") parser.add_argument("-o", "--output_folder", type=str, default=None, help="Path to output folder.") parser.add_argument("-s", "--seg_version", type=str, default=None, help="Path to output folder.") parser.add_argument("--ihc_syn", action="store_true", help="Consider only IHC with synapses.") diff --git a/test/test_segmentation/test_postprocessing.py b/test/test_segmentation/test_postprocessing.py index c8be572..8b5e8f5 100644 --- a/test/test_segmentation/test_postprocessing.py +++ b/test/test_segmentation/test_postprocessing.py @@ -17,7 +17,7 @@ def _create_example_seg(self, tmp_dir): return seg def _test_postprocessing(self, spatial_statistics, threshold, **spatial_statistics_kwargs): - from flamingo_tools.segmentation.postprocessing import filter_segmentation + from flamingo_tools.postprocessing.label_components import filter_segmentation with tempfile.TemporaryDirectory() as tmp_dir: example_seg = self._create_example_seg(tmp_dir) @@ -33,22 +33,22 @@ def _test_postprocessing(self, spatial_statistics, threshold, **spatial_statisti self.assertEqual(filtered_seg.shape, example_seg.shape) def test_nearest_neighbor_distance(self): - from flamingo_tools.segmentation.postprocessing import nearest_neighbor_distance + from flamingo_tools.postprocessing.label_components import nearest_neighbor_distance self._test_postprocessing(nearest_neighbor_distance, threshold=5) def test_local_ripleys_k(self): - from flamingo_tools.segmentation.postprocessing import local_ripleys_k + from flamingo_tools.postprocessing.label_components 
import local_ripleys_k self._test_postprocessing(local_ripleys_k, threshold=0.5) def test_neighbors_in_radius(self): - from flamingo_tools.segmentation.postprocessing import neighbors_in_radius + from flamingo_tools.postprocessing.label_components import neighbors_in_radius self._test_postprocessing(neighbors_in_radius, threshold=5) def test_compute_table_on_the_fly(self): - from flamingo_tools.segmentation.postprocessing import compute_table_on_the_fly + from flamingo_tools.postprocessing.label_components import compute_table_on_the_fly from flamingo_tools.test_data import get_test_volume_and_segmentation with tempfile.TemporaryDirectory() as tmp_dir: From e0e9484fdd80ccb4ee0a49e0c3e4db3e95390ca0 Mon Sep 17 00:00:00 2001 From: Martin Schilling Date: Tue, 25 Nov 2025 15:59:14 +0100 Subject: [PATCH 3/5] Re-factoring tonotopic mapping and calculating object measures --- flamingo_tools/measurements.py | 6 +- .../postprocessing/cochlea_mapping.py | 5 + .../postprocessing/label_components.py | 18 +- .../repro_label_components.py | 102 ++++--- .../object_measures/repro_object_measures.py | 258 ++++++++++++------ .../repro_tonotopic_mapping.py | 229 +++++++++++----- scripts/prediction/postprocess_seg.py | 156 ----------- scripts/prediction/tonotopic_mapping.py | 49 ---- 8 files changed, 420 insertions(+), 403 deletions(-) delete mode 100644 scripts/prediction/postprocess_seg.py delete mode 100644 scripts/prediction/tonotopic_mapping.py diff --git a/flamingo_tools/measurements.py b/flamingo_tools/measurements.py index 528de37..12ba885 100644 --- a/flamingo_tools/measurements.py +++ b/flamingo_tools/measurements.py @@ -28,7 +28,7 @@ def _measure_volume_and_surface(mask, resolution): # Use marching_cubes for 3D data - verts, faces, normals, _ = marching_cubes(mask, spacing=(resolution,) * 3) + verts, faces, normals, _ = marching_cubes(mask, spacing=resolution) mesh = trimesh.Trimesh(vertices=verts, faces=faces, vertex_normals=normals) surface = mesh.area @@ -166,6 +166,8 @@ 
def _default_object_features( # Do the volume and surface measurement. if not median_only: + if isinstance(resolution, float): + resolution = (resolution,) * 3 volume, surface = _measure_volume_and_surface(mask, resolution) measures["volume"] = volume measures["surface"] = surface @@ -181,6 +183,8 @@ def _morphology_features(seg_id, table, image, segmentation, resolution, **kwarg # Hard-coded value for LaVision cochleae. This is a hack for the wrong voxel size in MoBIE. # resolution = (3.0, 0.76, 0.76) + if isinstance(resolution, float): + resolution = (resolution,) * 3 volume, surface = _measure_volume_and_surface(mask, resolution) measures["volume"] = volume measures["surface"] = surface diff --git a/flamingo_tools/postprocessing/cochlea_mapping.py b/flamingo_tools/postprocessing/cochlea_mapping.py index b6df6ef..628a6be 100644 --- a/flamingo_tools/postprocessing/cochlea_mapping.py +++ b/flamingo_tools/postprocessing/cochlea_mapping.py @@ -528,6 +528,7 @@ def map_frequency(table: pd.DataFrame, animal: str = "mouse", otof: bool = False Args: table: Dataframe containing the segmentation. animal: Select the Greenwood function parameters specific to a species. Either "mouse" or "gerbil". + otof: Use mapping by *Mueller, Hearing Research 202 (2005) 63-73* for OTOF cochleae. Returns: Dataframe containing frequency in an additional column 'frequency[kHz]'. @@ -757,6 +758,10 @@ def tonotopic_mapping( component_label: List of component labels to evaluate. components_mapping: Components to use for tonotopic mapping. Ignore components torn parallel to main canal. cell_type: Cell type of segmentation. + animal: Animal specifier for species specific frequency mapping. Either "mouse" or "gerbil". + max_edge_distance: Maximal edge distance between graph nodes to create an edge between nodes. + apex_higher: Flag for identifying apex and base. Apex is set to node with higher y-value if True. + otof: Use mapping by *Mueller, Hearing Research 202 (2005) 63-73* for OTOF cochleae. 
Returns: Table with tonotopic label for cells. diff --git a/flamingo_tools/postprocessing/label_components.py b/flamingo_tools/postprocessing/label_components.py index a585315..56ed204 100644 --- a/flamingo_tools/postprocessing/label_components.py +++ b/flamingo_tools/postprocessing/label_components.py @@ -170,15 +170,15 @@ def filter_segmentation( In addition, objects smaller than a given size are filtered out. Args: - segmentation: Dataset containing the segmentation - output_path: Output path for postprocessed segmentation - spatial_statistics: Function to calculate density measure for elements of segmentation - threshold: Distance in micrometer to check for neighbors - min_size: Minimal number of pixels for filtering small instances - table: Dataframe of segmentation table - resolution: Resolution of segmentation in micrometer - output_key: Output key for postprocessed segmentation - spatial_statistics_kwargs: Arguments for spatial statistics function + segmentation: Dataset containing the segmentation. + output_path: Output path for postprocessed segmentation. + spatial_statistics: Function to calculate density measure for elements of segmentation. + threshold: Distance in micrometer to check for neighbors. + min_size: Minimal number of pixels for filtering small instances. + table: Dataframe of segmentation table. + resolution: Resolution of segmentation in micrometer. + output_key: Output key for postprocessed segmentation. + spatial_statistics_kwargs: Arguments for spatial statistics function. Returns: The number of objects before filtering. 
diff --git a/reproducibility/label_components/repro_label_components.py b/reproducibility/label_components/repro_label_components.py index e4f666d..602732a 100644 --- a/reproducibility/label_components/repro_label_components.py +++ b/reproducibility/label_components/repro_label_components.py @@ -77,6 +77,7 @@ def _load_json_as_list(ddict_path: str) -> List[dict]: def label_components_single( table_path: str, out_path: str, + force_overwrite: bool = False, cell_type: str = "sgn", component_list: List[int] = [1], max_edge_distance: float = 30, @@ -89,45 +90,81 @@ def label_components_single( custom_dic: Optional[dict] = None, **_ ): - """Process a single cochlea using one set of parameters or a custom_dic. + """Process a single cochlea using one set of parameters or a custom dictionary. + The cochlea is analyzed using graph-connected components + to label segmentation instances that are closer than a given maximal edge distance. + This process acts on an input segmentation table to which a "component_labels" column is added. + Each entry in this column refers to the index of a connected component. + The largest connected component has an index of 1; the others follow in decreasing order. + + Args: + table_path: File path to segmentation table. + out_path: Output path to segmentation table with new column "component_labels". + force_overwrite: Forcefully overwrite existing output path. + cell_type: Cell type of the segmentation. Currently supports "sgn" and "ihc". + component_list: List of components. Can be passed to obtain the number of instances within the component list. + max_edge_distance: Maximal edge distance between graph nodes to create an edge between nodes. + min_component_length: Minimal length of nodes of connected component. Filtered out if lower. + min_size: Minimal number of pixels for filtering small instances. + s3: Use S3 bucket. 
+ s3_credentials: + s3_bucket_name: + s3_service_endpoint: + custom_dic: Custom dictionary which allows multiple post-processing configurations and combines the + results into final components. """ + if os.path.isdir(out_path): + raise ValueError(f"Output path {out_path} is a directory. Provide a path to a single output file.") + if s3: tsv_path, fs = get_s3_path(table_path, bucket_name=s3_bucket_name, service_endpoint=s3_service_endpoint, credential_file=s3_credentials) - with fs.open(tsv_path, "r") as f: - table = pd.read_csv(f, sep="\t") + with fs.open(tsv_path, "r") as f: + table = pd.read_csv(f, sep="\t") + else: + table = pd.read_csv(table_path, sep="\t") + + # overwrite input file + if os.path.realpath(out_path) == os.path.realpath(table_path) and not s3: + force_overwrite = True + + if os.path.isfile(out_path) and not force_overwrite: + print(f"Skipping {out_path}. Table already exists.") - if custom_dic is not None: - tsv_table = label_custom_components(table, custom_dic) else: - if cell_type == "sgn": - tsv_table = label_components_sgn(table, min_size=min_size, - min_component_length=min_component_length, - max_edge_distance=max_edge_distance) - elif cell_type == "ihc": - tsv_table = label_components_ihc(table, min_size=min_size, - min_component_length=min_component_length, - max_edge_distance=max_edge_distance) + if custom_dic is not None: + # use multiple post-processing configurations + tsv_table = label_custom_components(table, custom_dic) else: - raise ValueError("Choose a supported cell type. Either 'sgn' or 'ihc'.") + if cell_type == "sgn": + tsv_table = label_components_sgn(table, min_size=min_size, + min_component_length=min_component_length, + max_edge_distance=max_edge_distance) + elif cell_type == "ihc": + tsv_table = label_components_ihc(table, min_size=min_size, + min_component_length=min_component_length, + max_edge_distance=max_edge_distance) + else: + raise ValueError("Choose a supported cell type. 
Either 'sgn' or 'ihc'.") - custom_comp = len(tsv_table[tsv_table["component_labels"].isin(component_list)]) - print(f"Total {cell_type.upper()}s: {len(tsv_table)}") - if component_list == [1]: - print(f"Largest component has {custom_comp} {cell_type.upper()}s.") - else: - for comp in component_list: - num_instances = len(tsv_table[tsv_table["component_labels"] == comp]) - print(f"Component {comp} has {num_instances} instances.") - print(f"Custom component(s) have {custom_comp} {cell_type.upper()}s.") + custom_comp = len(tsv_table[tsv_table["component_labels"].isin(component_list)]) + print(f"Total {cell_type.upper()}s: {len(tsv_table)}") + if component_list == [1]: + print(f"Largest component has {custom_comp} {cell_type.upper()}s.") + else: + for comp in component_list: + num_instances = len(tsv_table[tsv_table["component_labels"] == comp]) + print(f"Component {comp} has {num_instances} instances.") + print(f"Custom component(s) have {custom_comp} {cell_type.upper()}s.") - tsv_table.to_csv(out_path, sep="\t", index=False) + tsv_table.to_csv(out_path, sep="\t", index=False) -def repro_label_components( +def wrapper_label_components( output_path: str, table_path: Optional[str] = None, ddict: Optional[str] = None, + s3: bool = False, **kwargs ): """Wrapper function for labeling connected components using a segmentation table. @@ -135,7 +172,7 @@ def repro_label_components( and the explicit setting of parameters. 
""" if ddict is None: - label_components_single(table_path, output_path, **kwargs) + label_components_single(table_path, output_path, s3=s3, **kwargs) else: param_dicts = _load_json_as_list(ddict) for params in param_dicts: @@ -151,7 +188,8 @@ def repro_label_components( save_path = os.path.join(output_path, "_".join([cochlea_str, f"{table_str}.tsv"])) else: save_path = output_path - label_components_single(table_path=table_path, out_path=save_path, **params) + label_components_single(table_path=table_path, out_path=save_path, s3=s3, + **params) def main(): @@ -159,15 +197,14 @@ def main(): description="Script to label segmentation using a segmentation table and graph connected components.") parser.add_argument("-o", "--output", type=str, required=True, - help="Output path. Either directory or specific file.") - + help="Output path. Either directory (for --json) or specific file otherwise.") parser.add_argument("-i", "--input", type=str, default=None, help="Input path to segmentation table.") parser.add_argument("-j", "--json", type=str, default=None, help="Input JSON dictionary.") + parser.add_argument("--force", action="store_true", help="Forcefully overwrite output.") + # options for post-processing parser.add_argument("--cell_type", type=str, default="sgn", help="Cell type of segmentation. 
Either 'sgn' or 'ihc'.") - - # options for post-processing parser.add_argument("--min_size", type=int, default=1000, help="Minimal number of pixels for filtering small instances.") parser.add_argument("--min_component_length", type=int, default=50, @@ -188,7 +225,7 @@ def main(): args = parser.parse_args() - repro_label_components( + wrapper_label_components( output_path=args.output, table_path=args.input, ddict=args.json, @@ -197,6 +234,7 @@ def main(): max_edge_distance=args.max_edge_distance, min_component_length=args.min_component_length, min_size=args.min_size, + force_overwrite=args.force, s3=args.s3, s3_credentials=args.s3_credentials, s3_bucket_name=args.s3_bucket_name, diff --git a/reproducibility/object_measures/repro_object_measures.py b/reproducibility/object_measures/repro_object_measures.py index 614a254..26b8169 100644 --- a/reproducibility/object_measures/repro_object_measures.py +++ b/reproducibility/object_measures/repro_object_measures.py @@ -2,112 +2,195 @@ import json import os from multiprocessing import cpu_count -from typing import Optional +from typing import List, Optional import numpy as np import flamingo_tools.s3_utils as s3_utils +from flamingo_tools.s3_utils import MOBIE_FOLDER from flamingo_tools.measurements import compute_object_measures, compute_sgn_background_mask -def repro_object_measures( - json_file: str, - output_dir: str, +def _load_json_as_list(ddict_path: str) -> List[dict]: + with open(ddict_path, "r") as f: + data = json.loads(f.read()) + # ensure the result is always a list + return data if isinstance(data, list) else [data] + + +def object_measures_single( + table_path: str, + seg_path: str, + image_paths: List[str], + out_paths: List[str], force_overwrite: bool = False, + component_list: List[int] = [1], + background_mask: Optional[np.typing.ArrayLike] = None, + resolution: List[float] = [0.38, 0.38, 0.38], + s3: bool = False, s3_credentials: Optional[str] = None, s3_bucket_name: Optional[str] = None, 
s3_service_endpoint: Optional[str] = None, + **_ ): - s3_flag = True input_key = "s0" - default_resolution = 0.38 - default_component_list = [1] - default_bg_mask = None - - with open(json_file, "r") as myfile: - data = myfile.read() - param_dicts = json.loads(data) - - for dic in param_dicts: - cochlea = dic["cochlea"] - image_channels = dic["image_channel"] if isinstance(dic["image_channel"], list) else [dic["image_channel"]] - seg_channel = dic["segmentation_channel"] - resolution = tuple(dic["resolution"]) if "resolution" in dic else default_resolution - component_list = dic["component_list"] if "component_list" in dic else default_component_list - bg_mask = dic["background_mask"] if "background_mask" in dic else default_bg_mask - print(f"Processing cochlea {cochlea}") - - if not isinstance(resolution, float): - assert len(resolution) == 3 - resolution = np.array(resolution)[::-1] - - for img_channel in image_channels: - - print(f"Processing image channel {img_channel}") - cochlea_str = "-".join(cochlea.split("_")) - img_str = "-".join(img_channel.split("_")) - seg_str = "-".join(seg_channel.split("_")) - output_table_path = os.path.join(output_dir, f"{cochlea_str}_{img_str}_{seg_str}_object-measures.tsv") - - img_s3 = f"{cochlea}/images/ome-zarr/{img_channel}.ome.zarr" - seg_s3 = f"{cochlea}/images/ome-zarr/{seg_channel}.ome.zarr" - seg_table_s3 = f"{cochlea}/tables/{seg_channel}/default.tsv" - - img_path, fs = s3_utils.get_s3_path(img_s3, bucket_name=s3_bucket_name, - service_endpoint=s3_service_endpoint, credential_file=s3_credentials) - seg_path, fs = s3_utils.get_s3_path(seg_s3, bucket_name=s3_bucket_name, - service_endpoint=s3_service_endpoint, credential_file=s3_credentials) - - n_threads = int(os.environ.get("SLURM_CPUS_ON_NODE", cpu_count())) - if os.path.isfile(output_table_path) and not force_overwrite: - print(f"Skipping creation of {output_table_path}. 
File already exists.") + if not isinstance(resolution, float): + if len(resolution) == 1: + resolution = resolution * 3 + assert len(resolution) == 3 + resolution = np.array(resolution)[::-1] + else: + resolution = (resolution,) * 3 + + for (img_path, out_path) in zip(image_paths, out_paths): + n_threads = int(os.environ.get("SLURM_CPUS_ON_NODE", cpu_count())) + + # overwrite input file + if os.path.realpath(out_path) == os.path.realpath(table_path) and not s3: + force_overwrite = True + + if os.path.isfile(out_path) and not force_overwrite: + print(f"Skipping {out_path}. Table already exists.") + + else: + if background_mask is None: + feature_set = "default" + dilation = None + median_only = False else: - if bg_mask is None: - feature_set = "default" - dilation = None - median_only = False - else: - print("Using background mask for calculating object measures.") - feature_set = "default_background_subtract" - dilation = 4 - median_only = True - mask_cache_path = os.path.join(output_dir, f"{cochlea_str}_{img_str}_{seg_str}_bg-mask.zarr") - bg_mask = compute_sgn_background_mask( - image_path=img_path, - segmentation_path=seg_path, - image_key=input_key, - segmentation_key=input_key, - n_threads=n_threads, - cache_path=mask_cache_path, - ) - - compute_object_measures( - image_path=img_s3, - segmentation_path=seg_s3, - segmentation_table_path=seg_table_s3, - output_table_path=output_table_path, + print("Using background mask for calculating object measures.") + feature_set = "default_background_subtract" + dilation = 4 + median_only = True + + if s3: + img_path, fs = s3_utils.get_s3_path(img_path, bucket_name=s3_bucket_name, + service_endpoint=s3_service_endpoint, + credential_file=s3_credentials) + seg_path, fs = s3_utils.get_s3_path(seg_path, bucket_name=s3_bucket_name, + service_endpoint=s3_service_endpoint, + credential_file=s3_credentials) + + mask_cache_path = os.path.join(os.path.dirname(out_path), "bg-mask.zarr") + background_mask = 
compute_sgn_background_mask( + image_path=img_path, + segmentation_path=seg_path, image_key=input_key, segmentation_key=input_key, - feature_set=feature_set, - s3_flag=s3_flag, - component_list=component_list, - dilation=dilation, - median_only=median_only, - background_mask=bg_mask, n_threads=n_threads, - resolution=resolution, + cache_path=mask_cache_path, ) + compute_object_measures( + image_path=img_path, + segmentation_path=seg_path, + segmentation_table_path=table_path, + output_table_path=out_path, + image_key=input_key, + segmentation_key=input_key, + feature_set=feature_set, + s3_flag=s3, + component_list=component_list, + dilation=dilation, + median_only=median_only, + background_mask=background_mask, + n_threads=n_threads, + resolution=resolution, + ) + + +def wrapper_object_measures( + out_paths: List[str], + image_paths: Optional[List[str]] = None, + table_path: Optional[str] = None, + seg_path: Optional[str] = None, + ddict: Optional[str] = None, + force_overwrite: bool = False, + s3: bool = False, + **kwargs +): + """Wrapper function for calculationg object measures for different image channels using a segmentation table. + The function is used to distinguish between a passed parameter dictionary in JSON format + and the explicit setting of parameters. + + Args: + output_paths: Output path(s) to table containing object measures. + image_paths: Path(s) to one or multiple image channels in ome.zarr format. + table_path: File path to segmentation table. + seg_path: Input path to segmentation channel in ome.zarr format. + ddict: Data dictionary containing parameters for tonotopic mapping. + force_overwrite: Forcefully overwrite existing output path. + s3: Use S3 bucket. 
+ """ + out_paths = [os.path.realpath(o) for o in out_paths] + if ddict is None: + object_measures_single(table_path, seg_path, image_paths, out_paths, force_overwrite=force_overwrite, + s3=s3, **kwargs) + + else: + param_dicts = _load_json_as_list(ddict) + for num, params in enumerate(param_dicts): + cochlea = params["cochlea"] + print(f"\n{cochlea}") + seg_channel = params["segmentation_channel"] + image_channels = params["image_channel"] + table_path = os.path.join(f"{cochlea}", "tables", seg_channel, "default.tsv") + if len(out_paths) == 1 and os.path.isdir(out_paths[0]): + + c_str = "-".join(cochlea.split("_")) + s_str = "-".join(seg_channel.split("_")) + out_paths_tmp = [] + for img_channel in image_channels: + i_str = "-".join(img_channel.split("_")) + out_paths_tmp.append(os.path.join(out_paths[0], f"{c_str}_{i_str}_{s_str}_object-measures.tsv")) -def main(): - parser = argparse.ArgumentParser( - description="Script to extract region of interest (ROI) block around center coordinate.") + else: + assert len(image_channels) == len(out_paths) + out_paths_tmp = out_paths.copy() + + if s3: + image_paths = [f"{cochlea}/images/ome-zarr/{ch}.ome.zarr" for ch in image_channels] + seg_path = f"{cochlea}/images/ome-zarr/{seg_channel}.ome.zarr" + seg_table = f"{cochlea}/tables/{seg_channel}/default.tsv" + else: + image_paths = [f"{MOBIE_FOLDER}/{cochlea}/images/ome-zarr/{ch}.ome.zarr" for ch in image_channels] + seg_path = f"{MOBIE_FOLDER}/{cochlea}/images/ome-zarr/{seg_channel}.ome.zarr" + seg_table = f"{MOBIE_FOLDER}/{cochlea}/tables/{seg_channel}/default.tsv" + + object_measures_single( + table_path=seg_table, + seg_path=seg_path, + image_paths=image_paths, + out_paths=out_paths_tmp, + force_overwrite=force_overwrite, + s3=s3, + **params, + ) - parser.add_argument('-i', '--input', type=str, required=True, help="Input JSON dictionary.") - parser.add_argument('-o', "--output", type=str, required=True, help="Output directory.") +def main(): + parser = 
argparse.ArgumentParser( + description="Script to compute object measures for different stainings.") + + parser.add_argument("-o", "--output", type=str, nargs="+", required=True, + help="Output path(s). Either directory or specific file(s).") + parser.add_argument("-i", "--image_paths", type=str, nargs="+", default=None, + help="Input path to one or multiple image channels in ome.zarr format.") + parser.add_argument("-t", "--seg_table", type=str, default=None, + help="Input path to segmentation table.") + parser.add_argument("-s", "--seg_path", type=str, default=None, + help="Input path to segmentation channel in ome.zarr format.") + parser.add_argument("-j", "--json", type=str, default=None, help="Input JSON dictionary.") parser.add_argument("--force", action="store_true", help="Forcefully overwrite output.") + # options for object measures + parser.add_argument("-c", "--components", type=str, nargs="+", default=[1], help="List of components.") + parser.add_argument("-r", "--resolution", type=float, nargs="+", default=[0.38, 0.38, 0.38], + help="Resolution of input in micrometer.") + parser.add_argument("--bg_mask", action="store_true", help="Use background mask for calculating object measures.") + + # options for S3 bucket + parser.add_argument("--s3", action="store_true", help="Flag for using S3 bucket.") parser.add_argument("--s3_credentials", type=str, default=None, help="Input file containing S3 credentials. 
" "Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.") @@ -118,9 +201,14 @@ def main(): args = parser.parse_args() - repro_object_measures( - args.input, args.output, args.force, - args.s3_credentials, args.s3_bucket_name, args.s3_service_endpoint, + wrapper_object_measures( + out_paths=args.output, + image_paths=args.image_paths, + table_path=args.seg_table, + seg_path=args.seg_path, + ddict=args.json, + force_overwrite=args.force, + s3=args.s3, ) diff --git a/reproducibility/tonotopic_mapping/repro_tonotopic_mapping.py b/reproducibility/tonotopic_mapping/repro_tonotopic_mapping.py index 1700719..0f39a5e 100644 --- a/reproducibility/tonotopic_mapping/repro_tonotopic_mapping.py +++ b/reproducibility/tonotopic_mapping/repro_tonotopic_mapping.py @@ -1,98 +1,174 @@ import argparse import json import os -from typing import Optional +from typing import List, Optional import pandas as pd from flamingo_tools.s3_utils import get_s3_path from flamingo_tools.postprocessing.cochlea_mapping import tonotopic_mapping -def repro_tonotopic_mapping( - ddict: dict, - output_dir: str, +def _load_json_as_list(ddict_path: str) -> List[dict]: + with open(ddict_path, "r") as f: + data = json.loads(f.read()) + # ensure the result is always a list + return data if isinstance(data, list) else [data] + + +def tonotopic_mapping_single( + table_path: str, + out_path: str, + force_overwrite: bool = False, + cell_type: str = "sgn", + animal: str = "mouse", + otof: bool = False, + apex_position: str = "apex_higher", + component_list: List[int] = [1], + component_mapping: Optional[List[int]] = None, + max_edge_distance: float = 30, + s3: bool = False, s3_credentials: Optional[str] = None, s3_bucket_name: Optional[str] = None, s3_service_endpoint: Optional[str] = None, - force_overwrite: Optional[bool] = None, + **_ ): - default_cell_type = "ihc" - default_apex_position = "apex_higher" - default_max_edge_distance = 30 - default_component_list = [1] - - remove_columns = 
["tonotopic_label", - "tonotopic_value[kHz]", - "distance_to_path[µm]", - "length_fraction", - "run_length[µm]", - "centrality"] - - with open(ddict, 'r') as myfile: - data = myfile.read() - param_dicts = json.loads(data) - - for dic in param_dicts: - cochlea = dic["cochlea"] - seg_channel = dic["segmentation_channel"] - if "OTOF" in cochlea: - otof = True - else: - otof = False - - if cochlea[0] in ["M", "m"]: - animal = "mouse" - elif cochlea[0] in ["G", "g"]: - animal = "gerbil" - else: - animal = "mouse" - # raise ValueError("Cochlea does not have expected name format 'M_[...]' or 'G_[...]'.") - - cochlea_str = "-".join(cochlea.split("_")) - seg_str = "-".join(seg_channel.split("_")) - os.makedirs(output_dir, exist_ok=True) - output_table_path = os.path.join(output_dir, f"{cochlea_str}_{seg_str}.tsv") - - s3_path = os.path.join(f"{cochlea}", "tables", f"{seg_channel}", "default.tsv") - print(f"Tonotopic mapping for {cochlea}.") - - tsv_path, fs = get_s3_path(s3_path, bucket_name=s3_bucket_name, + """Tonotopic mapping of a single cochlea. + Each segmentation instance within a given component list is assigned a frequency[kHz], a run length and an offset. + The components used for the mapping itself can be a subset of the component list to adapt to broken components + along the Rosenthal's canal. + If the cochlea is broken in the direction of the Rosenthal's canal, the components have to be provided in a + continuous order which reflects the positioning within 3D. + The frequency is calculated using the Greenwood function using animal specific parameters. + The orientation of the mapping can be reversed using the apex position in reference to the y-coordinate. + + Args: + table_path: File path to segmentation table. + out_path: Output path to segmentation table with new column "component_labels". + force_overwrite: Forcefully overwrite existing output path. + cell_type: Cell type of the segmentation. Currently supports "sgn" and "ihc". 
+ animal: Animal for species specific frequency mapping. Either "mouse" or "gerbil". + otof: Use mapping by *Mueller, Hearing Research 202 (2005) 63-73* for OTOF cochleae. + apex_position: Identify position of apex and base. Apex is set to node with higher y-value per default. + component_list: List of components. Can be passed to obtain the number of instances within the component list. + components_mapping: Components to use for tonotopic mapping. Ignore components torn parallel to main canal. + max_edge_distance: Maximal edge distance between graph nodes to create an edge between nodes. + s3: Use S3 bucket. + s3_credentials: + s3_bucket_name: + s3_service_endpoint: + """ + if os.path.isdir(out_path): + raise ValueError(f"Output path {out_path} is a directory. Provide a path to a single output file.") + + if s3: + tsv_path, fs = get_s3_path(table_path, bucket_name=s3_bucket_name, service_endpoint=s3_service_endpoint, credential_file=s3_credentials) - with fs.open(tsv_path, 'r') as f: + with fs.open(tsv_path, "r") as f: table = pd.read_csv(f, sep="\t") + else: + table = pd.read_csv(table_path, sep="\t") - cell_type = dic["type"] if "type" in dic else default_cell_type - component_list = dic["component_list"] if "component_list" in dic else default_component_list - component_mapping = dic["component_mapping"] if "component_mapping" in dic else component_list - apex_position = dic["apex_position"] if "apex_position" in dic else default_apex_position - max_edge_distance = dic["max_edge_distance"] if "max_edge_distance" in dic else default_max_edge_distance + apex_higher = (apex_position == "apex_higher") - apex_higher = (apex_position == "apex_higher") + # overwrite input file + if os.path.realpath(out_path) == os.path.realpath(table_path) and not s3: + force_overwrite = True - for column in remove_columns: - if column in list(table.columns): - table = table.drop(column, axis=1) + if os.path.isfile(out_path) and not force_overwrite: + print(f"Skipping {out_path}. 
Table already exists.") - if not os.path.isfile(output_table_path) or force_overwrite: - table = tonotopic_mapping(table, component_label=component_list, animal=animal, - cell_type=cell_type, component_mapping=component_mapping, - apex_higher=apex_higher, max_edge_distance=max_edge_distance, - otof=otof) + else: + table = tonotopic_mapping(table, component_label=component_list, animal=animal, + cell_type=cell_type, component_mapping=component_mapping, + apex_higher=apex_higher, max_edge_distance=max_edge_distance, + otof=otof) - table.to_csv(output_table_path, sep="\t", index=False) + table.to_csv(out_path, sep="\t", index=False) - else: - print(f"Skipping {output_table_path}. Table already exists.") + +def wrapper_tonotopic_mapping( + output_path: str, + table_path: Optional[str] = None, + ddict: Optional[str] = None, + force_overwrite: bool = False, + animal: str = "mouse", + otof: bool = False, + s3: bool = False, + **kwargs +): + """Wrapper function for tonotopic mapping using a segmentation table. + The function is used to distinguish between a passed parameter dictionary in JSON format + and the explicit setting of parameters. + + Args: + output_path: Output path to segmentation table with new column "component_labels". + table_path: File path to segmentation table. + ddict: Data dictionary containing parameters for tonotopic mapping. + force_overwrite: Forcefully overwrite existing output path. + animal: Animal specifier for species specific frequency mapping. Either "mouse" or "gerbil". + otof: Use mapping by *Mueller, Hearing Research 202 (2005) 63-73* for OTOF cochleae. 
+ """ + if ddict is None: + tonotopic_mapping_single(table_path, output_path, animal=animal, otof=otof, force_overwrite=force_overwrite, + s3=s3, **kwargs) + else: + param_dicts = _load_json_as_list(ddict) + for params in param_dicts: + + cochlea = params["cochlea"] + print(f"\n{cochlea}") + seg_channel = params["segmentation_channel"] + table_path = os.path.join(f"{cochlea}", "tables", seg_channel, "default.tsv") + + if "OTOF" in cochlea: + otof = True + else: + otof = False + + if cochlea[0] in ["M", "m"]: + animal = "mouse" + elif cochlea[0] in ["G", "g"]: + animal = "gerbil" + else: + animal = "mouse" + + if os.path.isdir(output_path): + cochlea_str = "-".join(cochlea.split("_")) + table_str = "-".join(seg_channel.split("_")) + save_path = os.path.join(output_path, "_".join([cochlea_str, f"{table_str}.tsv"])) + else: + save_path = output_path + + tonotopic_mapping_single(table_path=table_path, out_path=save_path, animal=animal, otof=otof, + force_overwrite=force_overwrite, s3=s3, **params) def main(): parser = argparse.ArgumentParser( description="Script to extract region of interest (ROI) block around center coordinate.") - parser.add_argument('-i', '--input', type=str, required=True, help="Input JSON dictionary.") - parser.add_argument('-o', "--output", type=str, required=True, help="Output directory.") - + parser.add_argument("-o", "--output", type=str, required=True, + help="Output path. Either directory or specific file.") + parser.add_argument("-i", "--input", type=str, default=None, help="Input path to segmentation table.") + parser.add_argument("-j", "--json", type=str, default=None, help="Input JSON dictionary.") parser.add_argument("--force", action="store_true", help="Forcefully overwrite output.") + + # options for tonotopic mapping + parser.add_argument("--animal", type=str, default="mouse", + help="Animal type to be used for frequency mapping. 
Either 'mouse' or 'gerbil'.") + parser.add_argument("--otof", action="store_true", help="Use frequency mapping for OTOF cochleae.") + parser.add_argument("--apex_position", type=str, default="apex_higher", + help="Position of the apex. 'apex_higher' (default) sets the apex to the node with the higher y-value.") + + # options for post-processing + parser.add_argument("--cell_type", type=str, default="sgn", + help="Cell type of segmentation. Either 'sgn' or 'ihc'.") + parser.add_argument("--max_edge_distance", type=float, default=30, + help="Maximal distance in micrometer between points to create edges for connected components.") + parser.add_argument("-c", "--components", type=str, nargs="+", default=[1], help="List of connected components.") + + # options for S3 bucket + parser.add_argument("--s3", action="store_true", help="Flag for using S3 bucket.") parser.add_argument("--s3_credentials", type=str, default=None, help="Input file containing S3 credentials. " "Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.") @@ -103,10 +179,21 @@ def main(): args = parser.parse_args() - repro_tonotopic_mapping( - args.input, args.output, - args.s3_credentials, args.s3_bucket_name, args.s3_service_endpoint, - args.force, + wrapper_tonotopic_mapping( + output_path=args.output, + table_path=args.input, + ddict=args.json, + force_overwrite=args.force, + cell_type=args.cell_type, + animal=args.animal, + otof=args.otof, + max_edge_distance=args.max_edge_distance, + component_list=args.components, + apex_position=args.apex_position, + s3=args.s3, + s3_credentials=args.s3_credentials, + s3_bucket_name=args.s3_bucket_name, + s3_service_endpoint=args.s3_service_endpoint, ) diff --git a/scripts/prediction/postprocess_seg.py b/scripts/prediction/postprocess_seg.py deleted file mode 100644 index 5eb8541..0000000 --- a/scripts/prediction/postprocess_seg.py +++ /dev/null @@ -1,156 +0,0 @@ -import argparse -import os - -import pandas as pd -import zarr - -import flamingo_tools.s3_utils as s3_utils -from 
flamingo_tools.segmentation import filter_segmentation -from flamingo_tools.postprocessing.label_components import nearest_neighbor_distance, local_ripleys_k, neighbors_in_radius -from flamingo_tools.postprocessing.label_components import label_components_sgn - - -# TODO needs updates -def main(): - - parser = argparse.ArgumentParser( - description="Script for postprocessing segmentation data in zarr format. Either locally or on an S3 bucket.") - - parser.add_argument("-o", "--output_folder", type=str, default=None) - - parser.add_argument("-t", "--tsv", type=str, default=None, - help="TSV-file in MoBIE format which contains information about segmentation.") - parser.add_argument("--tsv_out", type=str, default=None, - help="File path to save post-processed dataframe. Default: default.tsv") - - parser.add_argument('-k', "--input_key", type=str, default="segmentation", - help="The key / internal path of the segmentation.") - parser.add_argument("--output_key", type=str, default="segmentation_postprocessed", - help="The key / internal path of the output.") - parser.add_argument('-r', "--resolution", type=float, default=0.38, - help="Resolution of segmentation in micrometer.") - - # options for post-processing - parser.add_argument("--min_size", type=int, default=1000, - help="Minimal number of pixels for filtering small instances.") - parser.add_argument("--threshold", type=float, default=None, - help="Threshold for spatial statistics.") - parser.add_argument("--min_component_length", type=int, default=50, - help="Minimal length for filtering out connected components.") - parser.add_argument("--max_edge_dist", type=float, default=30, - help="Maximal distance in micrometer between points to create edges for connected components.") - parser.add_argument("--iterations_erode", type=int, default=None, - help="Number of iterations for erosion, normally determined automatically.") - - # options for S3 bucket - parser.add_argument("--s3", action="store_true", help="Flag for 
using S3 bucket.") - parser.add_argument("--s3_credentials", type=str, default=None, - help="Input file containing S3 credentials. " - "Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.") - parser.add_argument("--s3_bucket_name", type=str, default=None, - help="S3 bucket name. Optional if BUCKET_NAME was exported.") - parser.add_argument("--s3_service_endpoint", type=str, default=None, - help="S3 service endpoint. Optional if SERVICE_ENDPOINT was exported.") - - # options for spatial statistics - parser.add_argument("--n_neighbors", type=int, default=None, - help="Value for calculating distance to 'n' nearest neighbors.") - parser.add_argument("--local_ripley_radius", type=int, default=None, - help="Value for radius for calculating local Ripley's K function.") - parser.add_argument("--r_neighbors", type=int, default=None, - help="Value for radius for calculating number of neighbors in range.") - - args = parser.parse_args() - - if args.output_folder is None and args.tsv is None: - raise ValueError("Either supply an output folder containing 'segmentation.zarr' or a TSV-file in MoBIE format.") - - # check output folder - if args.output_folder is not None: - seg_path = os.path.join(args.output_folder, "segmentation.zarr") - if args.s3: - s3_path, fs = s3_utils.get_s3_path(args.s3_input, bucket_name=args.s3_bucket_name, - service_endpoint=args.s3_service_endpoint, - credential_file=args.s3_credentials) - with zarr.open(s3_path, mode="r") as f: - segmentation = f[args.input_key] - else: - with zarr.open(seg_path, mode="r") as f: - segmentation = f[args.input_key] - else: - seg_path = None - - # check input for spatial statistics - postprocess_functions = [nearest_neighbor_distance, local_ripleys_k, neighbors_in_radius] - function_keywords = ["n_neighbors", "radius", "radius"] - postprocess_options = [args.n_neighbors, args.local_ripley_radius, args.r_neighbors] - default_thresholds = [args.threshold for _ in postprocess_functions] - - if seg_path is 
not None and args.threshold is None: - default_thresholds = [15, 20, 20] - - def create_spatial_statistics_dict(functions, keyword, options, threshold): - spatial_statistics_dict = [] - for f, o, k, t in zip(functions, keyword, options, threshold): - dic = {"function": f, "keyword": k, "argument": o, "threshold": t} - spatial_statistics_dict.append(dic) - return spatial_statistics_dict - - spatial_statistics_dict = create_spatial_statistics_dict(postprocess_functions, postprocess_options, - function_keywords, default_thresholds) - if seg_path is not None: - if sum(x["argument"] is not None for x in spatial_statistics_dict) == 0: - raise ValueError("Choose a postprocess function: 'n_neighbors, 'local_ripley_radius', or 'r_neighbors'.") - elif sum(x["argument"] is not None for x in spatial_statistics_dict) > 1: - raise ValueError("The script only supports a single postprocess function.") - else: - for d in spatial_statistics_dict: - if d["argument"] is not None: - spatial_statistics = d["function"] - spatial_statistics_kwargs = {d["keyword"]: d["argument"]} - threshold = d["threshold"] - - # check TSV-file containing data in MoBIE format - tsv_table = None - if args.tsv is not None: - if args.s3: - tsv_path, fs = s3_utils.get_s3_path(args.tsv, bucket_name=args.s3_bucket_name, - service_endpoint=args.s3_service_endpoint, - credential_file=args.s3_credentials) - with fs.open(tsv_path, 'r') as f: - tsv_table = pd.read_csv(f, sep="\t") - else: - with open(args.tsv, 'r') as f: - tsv_table = pd.read_csv(f, sep="\t") - - if seg_path is None: - post_table = label_components_sgn( - tsv_table.copy(), min_size=args.min_size, threshold_erode=args.threshold, - min_component_length=args.min_component_length, max_edge_distance=args.max_edge_dist, - iterations_erode=args.iterations_erode, - ) - - if args.tsv_out is None: - out_path = "default.tsv" - else: - out_path = args.tsv_out - post_table.to_csv(out_path, sep="\t", index=False) - - n_pre = len(tsv_table) - n_post = 
len(post_table["component_labels"][post_table["component_labels"] == 1]) - - print(f"Number of pre-filtered objects: {n_pre}\nNumber of objects in largest component: {n_post}") - - else: - n_pre, n_post = filter_segmentation(segmentation, output_path=seg_path, - spatial_statistics=spatial_statistics, - threshold=threshold, - min_size=args.min_size, table=tsv_table, - resolution=args.resolution, - output_key=args.output_key, **spatial_statistics_kwargs) - - print(f"Number of pre-filtered objects: {n_pre}\nNumber of post-filtered objects: {n_post}") - - -if __name__ == "__main__": - main() diff --git a/scripts/prediction/tonotopic_mapping.py b/scripts/prediction/tonotopic_mapping.py deleted file mode 100644 index 930ccd7..0000000 --- a/scripts/prediction/tonotopic_mapping.py +++ /dev/null @@ -1,49 +0,0 @@ -import argparse - -import pandas as pd - -import flamingo_tools.s3_utils as s3_utils -from flamingo_tools.postprocessing.cochlea_mapping import tonotopic_mapping - - -def main(): - - parser = argparse.ArgumentParser( - description="Script for the tonotopic mapping of IHCs and SGNs. " - "Either locally or on an S3 bucket.") - - parser.add_argument("-i", "--input", required=True, help="Input table with IHC segmentation.") - parser.add_argument("-o", "--output", required=True, help="Output path for json file with cropping parameters.") - - parser.add_argument("-t", "--type", type=str, default="sgn", help="Cell type of segmentation.") - - parser.add_argument("--s3", action="store_true", help="Flag for using S3 bucket.") - parser.add_argument("--s3_credentials", type=str, default=None, - help="Input file containing S3 credentials. " - "Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.") - parser.add_argument("--s3_bucket_name", type=str, default=None, - help="S3 bucket name. Optional if BUCKET_NAME was exported.") - parser.add_argument("--s3_service_endpoint", type=str, default=None, - help="S3 service endpoint. 
Optional if SERVICE_ENDPOINT was exported.") - - args = parser.parse_args() - - if args.s3: - tsv_path, fs = s3_utils.get_s3_path(args.input, bucket_name=args.s3_bucket_name, - service_endpoint=args.s3_service_endpoint, - credential_file=args.s3_credentials) - with fs.open(tsv_path, 'r') as f: - tsv_table = pd.read_csv(f, sep="\t") - else: - with open(args.input, 'r') as f: - tsv_table = pd.read_csv(f, sep="\t") - - table = tonotopic_mapping( - tsv_table, cell_type=args.type, - ) - - table.to_csv(args.output, sep="\t", index=False) - - -if __name__ == "__main__": - main() From e19f8a7c7c526d2f1c597160ca6fa423c72b7fb3 Mon Sep 17 00:00:00 2001 From: Martin Schilling Date: Wed, 26 Nov 2025 09:39:11 +0100 Subject: [PATCH 4/5] Re-locate functions --- flamingo_tools/measurements.py | 100 ++++++++++++ .../postprocessing/cochlea_mapping.py | 77 ++++++++- .../postprocessing/label_components.py | 147 +++++++++++++++++ .../postprocessing/synapse_per_ihc_utils.py | 2 +- .../repro_label_components.py | 149 +----------------- .../object_measures/repro_object_measures.py | 89 +---------- .../repro_tonotopic_mapping.py | 75 +-------- 7 files changed, 327 insertions(+), 312 deletions(-) diff --git a/flamingo_tools/measurements.py b/flamingo_tools/measurements.py index 12ba885..4b8251b 100644 --- a/flamingo_tools/measurements.py +++ b/flamingo_tools/measurements.py @@ -6,6 +6,7 @@ import warnings from concurrent import futures from functools import partial +from multiprocessing import cpu_count from typing import List, Optional, Tuple, Union import numpy as np @@ -502,3 +503,102 @@ def _compute_block(block_id): mask = ResizedVolume(low_res_mask, shape=original_shape, order=0) return mask + + +def object_measures_single( + table_path: str, + seg_path: str, + image_paths: List[str], + out_paths: List[str], + force_overwrite: bool = False, + component_list: List[int] = [1], + background_mask: Optional[np.typing.ArrayLike] = None, + resolution: List[float] = [0.38, 0.38, 0.38], + s3: 
bool = False, + s3_credentials: Optional[str] = None, + s3_bucket_name: Optional[str] = None, + s3_service_endpoint: Optional[str] = None, + **_ +): + """Compute object measures for a single or multiple image channels with respect to a single segmentation channel. + + Args: + table_path: File path to segmentation table. + seg_path: Path to segmentation channel in ome.zarr format. + image_paths: Path(s) to image channel(s) in ome.zarr format. + out_paths: Path(s) for calculated object measures. + force_overwrite: Forcefully overwrite existing files. + component_list: Only calculate object measures for specific components. + background_mask: Use background mask for calculating object measures. + resolution: Resolution of input in micrometer. + s3: Use S3 file paths. + s3_credentials: + s3_bucket_name: + s3_service_endpoint: + """ + input_key = "s0" + out_paths = [os.path.realpath(o) for o in out_paths] + + if not isinstance(resolution, float): + if len(resolution) == 1: + resolution = resolution * 3 + assert len(resolution) == 3 + resolution = np.array(resolution)[::-1] + else: + resolution = (resolution,) * 3 + + for (img_path, out_path) in zip(image_paths, out_paths): + n_threads = int(os.environ.get("SLURM_CPUS_ON_NODE", cpu_count())) + + # overwrite input file + if os.path.realpath(out_path) == os.path.realpath(table_path) and not s3: + force_overwrite = True + + if os.path.isfile(out_path) and not force_overwrite: + print(f"Skipping {out_path}. 
Table already exists.") + + else: + if background_mask is None: + feature_set = "default" + dilation = None + median_only = False + else: + print("Using background mask for calculating object measures.") + feature_set = "default_background_subtract" + dilation = 4 + median_only = True + + if s3: + img_path, fs = s3_utils.get_s3_path(img_path, bucket_name=s3_bucket_name, + service_endpoint=s3_service_endpoint, + credential_file=s3_credentials) + seg_path, fs = s3_utils.get_s3_path(seg_path, bucket_name=s3_bucket_name, + service_endpoint=s3_service_endpoint, + credential_file=s3_credentials) + + mask_cache_path = os.path.join(os.path.dirname(out_path), "bg-mask.zarr") + background_mask = compute_sgn_background_mask( + image_path=img_path, + segmentation_path=seg_path, + image_key=input_key, + segmentation_key=input_key, + n_threads=n_threads, + cache_path=mask_cache_path, + ) + + compute_object_measures( + image_path=img_path, + segmentation_path=seg_path, + segmentation_table_path=table_path, + output_table_path=out_path, + image_key=input_key, + segmentation_key=input_key, + feature_set=feature_set, + s3_flag=s3, + component_list=component_list, + dilation=dilation, + median_only=median_only, + background_mask=background_mask, + n_threads=n_threads, + resolution=resolution, + ) diff --git a/flamingo_tools/postprocessing/cochlea_mapping.py b/flamingo_tools/postprocessing/cochlea_mapping.py index 628a6be..d1056ec 100644 --- a/flamingo_tools/postprocessing/cochlea_mapping.py +++ b/flamingo_tools/postprocessing/cochlea_mapping.py @@ -1,4 +1,5 @@ import math +import os from typing import List, Optional, Tuple import networkx as nx @@ -8,6 +9,7 @@ from scipy.interpolate import interp1d from flamingo_tools.postprocessing.label_components import downscaled_centroids +from flamingo_tools.s3_utils import get_s3_path def find_most_distant_nodes(G: nx.classes.graph.Graph, weight: str = 'weight') -> Tuple[float, float]: @@ -750,8 +752,8 @@ def tonotopic_mapping( apex_higher: 
bool = True, otof: bool = False, ) -> pd.DataFrame: - """Tonotopic mapping of IHCs by supplying a table with component labels. - The mapping assigns a tonotopic label to each IHC according to the position along the length of the cochlea. + """Tonotopic mapping of SGNs or IHCs by supplying a table with component labels. + The mapping assigns a tonotopic label to each instance according to the position along the length of the cochlea. Args: table: Dataframe of segmentation table. @@ -816,3 +818,74 @@ def tonotopic_mapping( table = map_frequency(table, animal=animal, otof=otof) return table + + +def tonotopic_mapping_single( + table_path: str, + out_path: str, + force_overwrite: bool = False, + cell_type: str = "sgn", + animal: str = "mouse", + otof: bool = False, + apex_position: str = "apex_higher", + component_list: List[int] = [1], + component_mapping: Optional[List[int]] = None, + max_edge_distance: float = 30, + s3: bool = False, + s3_credentials: Optional[str] = None, + s3_bucket_name: Optional[str] = None, + s3_service_endpoint: Optional[str] = None, + **_ +): + """Tonotopic mapping of a single cochlea. + Each segmentation instance within a given component list is assigned a frequency[kHz], a run length and an offset. + The components used for the mapping itself can be a subset of the component list to adapt to broken components + along the Rosenthal's canal. + If the cochlea is broken in the direction of the Rosenthal's canal, the components have to be provided in a + continuous order which reflects the positioning within 3D. + The frequency is calculated using the Greenwood function using animal specific parameters. + The orientation of the mapping can be reversed using the apex position in reference to the y-coordinate. + + Args: + table_path: File path to segmentation table. + out_path: Output path to segmentation table with new column "component_labels". + force_overwrite: Forcefully overwrite existing output path. 
+ cell_type: Cell type of the segmentation. Currently supports "sgn" and "ihc". + animal: Animal for species specific frequency mapping. Either "mouse" or "gerbil". + otof: Use mapping by *Mueller, Hearing Research 202 (2005) 63-73* for OTOF cochleae. + apex_position: Identify position of apex and base. Apex is set to node with higher y-value per default. + component_list: List of components. Can be passed to obtain the number of instances within the component list. + component_mapping: Components to use for tonotopic mapping. Ignore components torn parallel to main canal. + max_edge_distance: Maximal edge distance between graph nodes to create an edge between nodes. + s3: Use S3 bucket. + s3_credentials: + s3_bucket_name: + s3_service_endpoint: + """ + if os.path.isdir(out_path): + raise ValueError(f"Output path {out_path} is a directory. Provide a path to a single output file.") + + if s3: + tsv_path, fs = get_s3_path(table_path, bucket_name=s3_bucket_name, + service_endpoint=s3_service_endpoint, credential_file=s3_credentials) + with fs.open(tsv_path, "r") as f: + table = pd.read_csv(f, sep="\t") + else: + table = pd.read_csv(table_path, sep="\t") + + apex_higher = (apex_position == "apex_higher") + + # overwrite input file + if os.path.realpath(out_path) == os.path.realpath(table_path) and not s3: + force_overwrite = True + + if os.path.isfile(out_path) and not force_overwrite: + print(f"Skipping {out_path}. 
Table already exists.") + + else: + table = tonotopic_mapping(table, component_label=component_list, animal=animal, + cell_type=cell_type, component_mapping=component_mapping, + apex_higher=apex_higher, max_edge_distance=max_edge_distance, + otof=otof) + + table.to_csv(out_path, sep="\t", index=False) diff --git a/flamingo_tools/postprocessing/label_components.py b/flamingo_tools/postprocessing/label_components.py index 56ed204..21bc070 100644 --- a/flamingo_tools/postprocessing/label_components.py +++ b/flamingo_tools/postprocessing/label_components.py @@ -1,5 +1,6 @@ import math import multiprocessing as mp +import os from concurrent import futures from typing import Callable, List, Optional, Tuple @@ -10,6 +11,7 @@ import pandas as pd from elf.io import open_file +from flamingo_tools.s3_utils import get_s3_path from scipy.ndimage import distance_transform_edt, binary_dilation, binary_closing from scipy.sparse import csr_matrix from scipy.spatial import distance @@ -673,3 +675,148 @@ def filter_cochlea_volume( combined_dilated[combined_dilated > 0] = 1 return combined_dilated + + +def label_custom_components(tsv_table, custom_dict): + """Label IHC components using multiple post-processing configurations and combine the + results into final components. + The function applies successive post-processing steps defined in a `custom_dic` + configuration. Each entry under `label_dicts` specifies: + - `label_params`: a list of parameter sets. The segmentation is processed once for + each parameter set (e.g., {"min_size": 500, "max_edge_distance": 65, "min_component_length": 5}). + - `components`: lists of label IDs to extract from each corresponding post-processing run. + Label IDs collected from all runs are merged to form the final component (e.g., key "1"). + Global filtering is applied using `min_size_global`, and any `missing_ids` + (e.g., 4800 or 4832) are added explicitly to the final component. 
+ Example `custom_dic` structure: + { + "min_size_global": 500, + "missing_ids": [4800, 4832], + "label_dicts": { + "1": { + "label_params": [ + {"min_size": 500, "max_edge_distance": 65, "min_component_length": 5}, + {"min_size": 400, "max_edge_distance": 45, "min_component_length": 5} + ], + "components": [[18, 22], [1, 45, 83]] + } + } + } + + Args: + tsv_table: Pandas dataframe of the MoBIE segmentation table. + custom_dict: Custom dictionary featuring post-processing parameters. + + Returns: + Pandas dataframe featuring labeled components. + """ + min_size = custom_dict["min_size_global"] + component_labels = [0 for _ in range(len(tsv_table))] + tsv_table.loc[:, "component_labels"] = component_labels + for custom_comp, label_dict in custom_dict["label_dicts"].items(): + label_params = label_dict["label_params"] + label_components = label_dict["components"] + + combined_label_ids = [] + for comp, other_kwargs in zip(label_components, label_params): + tsv_table_tmp = label_components_ihc(tsv_table.copy(), **other_kwargs) + label_ids = list(tsv_table_tmp.loc[tsv_table_tmp["component_labels"].isin(comp), "label_id"]) + combined_label_ids.extend(label_ids) + print(f"{comp}", len(combined_label_ids)) + + combined_label_ids = list(set(combined_label_ids)) + + tsv_table.loc[tsv_table["label_id"].isin(combined_label_ids), "component_labels"] = int(custom_comp) + + tsv_table.loc[tsv_table["n_pixels"] < min_size, "component_labels"] = 0 + if "missing_ids" in list(custom_dict.keys()): + for m in custom_dict["missing_ids"]: + tsv_table.loc[tsv_table["label_id"] == m, "component_labels"] = 1 + + return tsv_table + + +def label_components_single( + table_path: str, + out_path: str, + force_overwrite: bool = False, + cell_type: str = "sgn", + component_list: List[int] = [1], + max_edge_distance: float = 30, + min_component_length: int = 50, + min_size: int = 1000, + s3: bool = False, + s3_credentials: Optional[str] = None, + s3_bucket_name: Optional[str] = None, + 
s3_service_endpoint: Optional[str] = None, + custom_dic: Optional[dict] = None, + **_ +): + """Process a single cochlea using one set of parameters or a custom dictionary. + The cochlea is analyzed using graph-connected components + to label segmentation instances that are closer than a given maximal edge distance. + This process acts on an input segmentation table to which a "component_labels" column is added. + Each entry in this column refers to the index of a connected component. + The largest connected component has an index of 1; the others follow in decreasing order. + + Args: + table_path: File path to segmentation table. + out_path: Output path to segmentation table with new column "component_labels". + force_overwrite: Forcefully overwrite existing output path. + cell_type: Cell type of the segmentation. Currently supports "sgn" and "ihc". + component_list: List of components. Can be passed to obtain the number of instances within the component list. + max_edge_distance: Maximal edge distance between graph nodes to create an edge between nodes. + min_component_length: Minimal length of nodes of connected component. Filtered out if lower. + min_size: Minimal number of pixels for filtering small instances. + s3: Use S3 bucket. + s3_credentials: + s3_bucket_name: + s3_service_endpoint: + custom_dic: Custom dictionary which allows multiple post-processing configurations and combines the + results into final components. + """ + if os.path.isdir(out_path): + raise ValueError(f"Output path {out_path} is a directory. 
Provide a path to a single output file.") + + if s3: + tsv_path, fs = get_s3_path(table_path, bucket_name=s3_bucket_name, + service_endpoint=s3_service_endpoint, credential_file=s3_credentials) + with fs.open(tsv_path, "r") as f: + table = pd.read_csv(f, sep="\t") + else: + table = pd.read_csv(table_path, sep="\t") + + # overwrite input file + if os.path.realpath(out_path) == os.path.realpath(table_path) and not s3: + force_overwrite = True + + if os.path.isfile(out_path) and not force_overwrite: + print(f"Skipping {out_path}. Table already exists.") + + else: + if custom_dic is not None: + # use multiple post-processing configurations + tsv_table = label_custom_components(table, custom_dic) + else: + if cell_type == "sgn": + tsv_table = label_components_sgn(table, min_size=min_size, + min_component_length=min_component_length, + max_edge_distance=max_edge_distance) + elif cell_type == "ihc": + tsv_table = label_components_ihc(table, min_size=min_size, + min_component_length=min_component_length, + max_edge_distance=max_edge_distance) + else: + raise ValueError("Choose a supported cell type. 
Either 'sgn' or 'ihc'.") + + custom_comp = len(tsv_table[tsv_table["component_labels"].isin(component_list)]) + print(f"Total {cell_type.upper()}s: {len(tsv_table)}") + if component_list == [1]: + print(f"Largest component has {custom_comp} {cell_type.upper()}s.") + else: + for comp in component_list: + num_instances = len(tsv_table[tsv_table["component_labels"] == comp]) + print(f"Component {comp} has {num_instances} instances.") + print(f"Custom component(s) have {custom_comp} {cell_type.upper()}s.") + + tsv_table.to_csv(out_path, sep="\t", index=False) diff --git a/flamingo_tools/postprocessing/synapse_per_ihc_utils.py b/flamingo_tools/postprocessing/synapse_per_ihc_utils.py index f301485..336a6ff 100644 --- a/flamingo_tools/postprocessing/synapse_per_ihc_utils.py +++ b/flamingo_tools/postprocessing/synapse_per_ihc_utils.py @@ -47,4 +47,4 @@ "component_list": [2, 1, 3]}, "M_AMD_N97_R": {"synapse_table_name": "synapse_v3_ihc_v4b", "ihc_table_name": "IHC_v4b", "component_list": [2, 5]}, -} \ No newline at end of file +} diff --git a/reproducibility/label_components/repro_label_components.py b/reproducibility/label_components/repro_label_components.py index 602732a..66d6193 100644 --- a/reproducibility/label_components/repro_label_components.py +++ b/reproducibility/label_components/repro_label_components.py @@ -3,68 +3,7 @@ import os from typing import List, Optional -import pandas as pd -from flamingo_tools.s3_utils import get_s3_path -from flamingo_tools.postprocessing.label_components import label_components_sgn, label_components_ihc - - -def label_custom_components(tsv_table, custom_dict): - """Label IHC components using multiple post-processing configurations and combine the - results into final components. - The function applies successive post-processing steps defined in a `custom_dic` - configuration. Each entry under `label_dicts` specifies: - - `label_params`: a list of parameter sets. 
The segmentation is processed once for - each parameter set (e.g., {"min_size": 500, "max_edge_distance": 65, "min_component_length": 5}). - - `components`: lists of label IDs to extract from each corresponding post-processing run. - Label IDs collected from all runs are merged to form the final component (e.g., key "1"). - Global filtering is applied using `min_size_global`, and any `missing_ids` - (e.g., 4800 or 4832) are added explicitly to the final component. - Example `custom_dic` structure: - { - "min_size_global": 500, - "missing_ids": [4800, 4832], - "label_dicts": { - "1": { - "label_params": [ - {"min_size": 500, "max_edge_distance": 65, "min_component_length": 5}, - {"min_size": 400, "max_edge_distance": 45, "min_component_length": 5} - ], - "components": [[18, 22], [1, 45, 83]] - } - } - } - - Args: - tsv_table: Pandas dataframe of the MoBIE segmentation table. - custom_dict: Custom dictionary featuring post-processing parameters. - - Returns: - Pandas dataframe featuring labeled components. 
- """ - min_size = custom_dict["min_size_global"] - component_labels = [0 for _ in range(len(tsv_table))] - tsv_table.loc[:, "component_labels"] = component_labels - for custom_comp, label_dict in custom_dict["label_dicts"].items(): - label_params = label_dict["label_params"] - label_components = label_dict["components"] - - combined_label_ids = [] - for comp, other_kwargs in zip(label_components, label_params): - tsv_table_tmp = label_components_ihc(tsv_table.copy(), **other_kwargs) - label_ids = list(tsv_table_tmp.loc[tsv_table_tmp["component_labels"].isin(comp), "label_id"]) - combined_label_ids.extend(label_ids) - print(f"{comp}", len(combined_label_ids)) - - combined_label_ids = list(set(combined_label_ids)) - - tsv_table.loc[tsv_table["label_id"].isin(combined_label_ids), "component_labels"] = int(custom_comp) - - tsv_table.loc[tsv_table["n_pixels"] < min_size, "component_labels"] = 0 - if "missing_ids" in list(custom_dict.keys()): - for m in custom_dict["missing_ids"]: - tsv_table.loc[tsv_table["label_id"] == m, "component_labels"] = 1 - - return tsv_table +from flamingo_tools.postprocessing.label_components import label_components_single def _load_json_as_list(ddict_path: str) -> List[dict]: @@ -74,92 +13,6 @@ def _load_json_as_list(ddict_path: str) -> List[dict]: return data if isinstance(data, list) else [data] -def label_components_single( - table_path: str, - out_path: str, - force_overwrite: bool = False, - cell_type: str = "sgn", - component_list: List[int] = [1], - max_edge_distance: float = 30, - min_component_length: int = 50, - min_size: int = 1000, - s3: bool = False, - s3_credentials: Optional[str] = None, - s3_bucket_name: Optional[str] = None, - s3_service_endpoint: Optional[str] = None, - custom_dic: Optional[dict] = None, - **_ -): - """Process a single cochlea using one set of parameters or a custom dictionary. 
- The cochlea is analyzed using graph-connected components - to label segmentation instances that are closer than a given maximal edge distance. - This process acts on an input segmentation table to which a "component_labels" column is added. - Each entry in this column refers to the index of a connected component. - The largest connected component has an index of 1; the others follow in decreasing order. - - Args: - table_path: File path to segmentation table. - out_path: Output path to segmentation table with new column "component_labels". - force_overwrite: Forcefully overwrite existing output path. - cell_type: Cell type of the segmentation. Currently supports "sgn" and "ihc". - component_list: List of components. Can be passed to obtain the number of instances within the component list. - max_edge_distance: Maximal edge distance between graph nodes to create an edge between nodes. - min_component_length: Minimal length of nodes of connected component. Filtered out if lower. - min_size: Minimal number of pixels for filtering small instances. - s3: Use S3 bucket. - s3_credentials: - s3_bucket_name: - s3_service_endpoint: - custom_dic: Custom dictionary which allows multiple post-processing configurations and combines the - results into final components. - """ - if os.path.isdir(out_path): - raise ValueError(f"Output path {out_path} is a directory. Provide a path to a single output file.") - - if s3: - tsv_path, fs = get_s3_path(table_path, bucket_name=s3_bucket_name, - service_endpoint=s3_service_endpoint, credential_file=s3_credentials) - with fs.open(tsv_path, "r") as f: - table = pd.read_csv(f, sep="\t") - else: - table = pd.read_csv(table_path, sep="\t") - - # overwrite input file - if os.path.realpath(out_path) == os.path.realpath(table_path) and not s3: - force_overwrite = True - - if os.path.isfile(out_path) and not force_overwrite: - print(f"Skipping {out_path}. 
Table already exists.") - - else: - if custom_dic is not None: - # use multiple post-processing configurations - tsv_table = label_custom_components(table, custom_dic) - else: - if cell_type == "sgn": - tsv_table = label_components_sgn(table, min_size=min_size, - min_component_length=min_component_length, - max_edge_distance=max_edge_distance) - elif cell_type == "ihc": - tsv_table = label_components_ihc(table, min_size=min_size, - min_component_length=min_component_length, - max_edge_distance=max_edge_distance) - else: - raise ValueError("Choose a supported cell type. Either 'sgn' or 'ihc'.") - - custom_comp = len(tsv_table[tsv_table["component_labels"].isin(component_list)]) - print(f"Total {cell_type.upper()}s: {len(tsv_table)}") - if component_list == [1]: - print(f"Largest component has {custom_comp} {cell_type.upper()}s.") - else: - for comp in component_list: - num_instances = len(tsv_table[tsv_table["component_labels"] == comp]) - print(f"Component {comp} has {num_instances} instances.") - print(f"Custom component(s) have {custom_comp} {cell_type.upper()}s.") - - tsv_table.to_csv(out_path, sep="\t", index=False) - - def wrapper_label_components( output_path: str, table_path: Optional[str] = None, diff --git a/reproducibility/object_measures/repro_object_measures.py b/reproducibility/object_measures/repro_object_measures.py index 26b8169..c9ee141 100644 --- a/reproducibility/object_measures/repro_object_measures.py +++ b/reproducibility/object_measures/repro_object_measures.py @@ -1,13 +1,10 @@ import argparse import json import os -from multiprocessing import cpu_count from typing import List, Optional -import numpy as np -import flamingo_tools.s3_utils as s3_utils from flamingo_tools.s3_utils import MOBIE_FOLDER -from flamingo_tools.measurements import compute_object_measures, compute_sgn_background_mask +from flamingo_tools.measurements import object_measures_single def _load_json_as_list(ddict_path: str) -> List[dict]: @@ -17,88 +14,6 @@ def 
_load_json_as_list(ddict_path: str) -> List[dict]: return data if isinstance(data, list) else [data] -def object_measures_single( - table_path: str, - seg_path: str, - image_paths: List[str], - out_paths: List[str], - force_overwrite: bool = False, - component_list: List[int] = [1], - background_mask: Optional[np.typing.ArrayLike] = None, - resolution: List[float] = [0.38, 0.38, 0.38], - s3: bool = False, - s3_credentials: Optional[str] = None, - s3_bucket_name: Optional[str] = None, - s3_service_endpoint: Optional[str] = None, - **_ -): - input_key = "s0" - - if not isinstance(resolution, float): - if len(resolution) == 1: - resolution = resolution * 3 - assert len(resolution) == 3 - resolution = np.array(resolution)[::-1] - else: - resolution = (resolution,) * 3 - - for (img_path, out_path) in zip(image_paths, out_paths): - n_threads = int(os.environ.get("SLURM_CPUS_ON_NODE", cpu_count())) - - # overwrite input file - if os.path.realpath(out_path) == os.path.realpath(table_path) and not s3: - force_overwrite = True - - if os.path.isfile(out_path) and not force_overwrite: - print(f"Skipping {out_path}. 
Table already exists.") - - else: - if background_mask is None: - feature_set = "default" - dilation = None - median_only = False - else: - print("Using background mask for calculating object measures.") - feature_set = "default_background_subtract" - dilation = 4 - median_only = True - - if s3: - img_path, fs = s3_utils.get_s3_path(img_path, bucket_name=s3_bucket_name, - service_endpoint=s3_service_endpoint, - credential_file=s3_credentials) - seg_path, fs = s3_utils.get_s3_path(seg_path, bucket_name=s3_bucket_name, - service_endpoint=s3_service_endpoint, - credential_file=s3_credentials) - - mask_cache_path = os.path.join(os.path.dirname(out_path), "bg-mask.zarr") - background_mask = compute_sgn_background_mask( - image_path=img_path, - segmentation_path=seg_path, - image_key=input_key, - segmentation_key=input_key, - n_threads=n_threads, - cache_path=mask_cache_path, - ) - - compute_object_measures( - image_path=img_path, - segmentation_path=seg_path, - segmentation_table_path=table_path, - output_table_path=out_path, - image_key=input_key, - segmentation_key=input_key, - feature_set=feature_set, - s3_flag=s3, - component_list=component_list, - dilation=dilation, - median_only=median_only, - background_mask=background_mask, - n_threads=n_threads, - resolution=resolution, - ) - - def wrapper_object_measures( out_paths: List[str], image_paths: Optional[List[str]] = None, @@ -122,12 +37,12 @@ def wrapper_object_measures( force_overwrite: Forcefully overwrite existing output path. s3: Use S3 bucket. 
""" - out_paths = [os.path.realpath(o) for o in out_paths] if ddict is None: object_measures_single(table_path, seg_path, image_paths, out_paths, force_overwrite=force_overwrite, s3=s3, **kwargs) else: + out_paths = [os.path.realpath(o) for o in out_paths] param_dicts = _load_json_as_list(ddict) for num, params in enumerate(param_dicts): cochlea = params["cochlea"] diff --git a/reproducibility/tonotopic_mapping/repro_tonotopic_mapping.py b/reproducibility/tonotopic_mapping/repro_tonotopic_mapping.py index 0f39a5e..335e0ff 100644 --- a/reproducibility/tonotopic_mapping/repro_tonotopic_mapping.py +++ b/reproducibility/tonotopic_mapping/repro_tonotopic_mapping.py @@ -3,9 +3,7 @@ import os from typing import List, Optional -import pandas as pd -from flamingo_tools.s3_utils import get_s3_path -from flamingo_tools.postprocessing.cochlea_mapping import tonotopic_mapping +from flamingo_tools.postprocessing.cochlea_mapping import tonotopic_mapping_single def _load_json_as_list(ddict_path: str) -> List[dict]: @@ -15,77 +13,6 @@ def _load_json_as_list(ddict_path: str) -> List[dict]: return data if isinstance(data, list) else [data] -def tonotopic_mapping_single( - table_path: str, - out_path: str, - force_overwrite: bool = False, - cell_type: str = "sgn", - animal: str = "mouse", - otof: bool = False, - apex_position: str = "apex_higher", - component_list: List[int] = [1], - component_mapping: Optional[List[int]] = None, - max_edge_distance: float = 30, - s3: bool = False, - s3_credentials: Optional[str] = None, - s3_bucket_name: Optional[str] = None, - s3_service_endpoint: Optional[str] = None, - **_ -): - """Tonotopic mapping of a single cochlea. - Each segmentation instance within a given component list is assigned a frequency[kHz], a run length and an offset. - The components used for the mapping itself can be a subset of the component list to adapt to broken components - along the Rosenthal's canal. 
- If the cochlea is broken in the direction of the Rosenthal's canal, the components have to be provided in a - continuous order which reflects the positioning within 3D. - The frequency is calculated using the Greenwood function using animal specific parameters. - The orientation of the mapping can be reversed using the apex position in reference to the y-coordinate. - - Args: - table_path: File path to segmentation table. - out_path: Output path to segmentation table with new column "component_labels". - force_overwrite: Forcefully overwrite existing output path. - cell_type: Cell type of the segmentation. Currently supports "sgn" and "ihc". - animal: Animal for species specific frequency mapping. Either "mouse" or "gerbil". - otof: Use mapping by *Mueller, Hearing Research 202 (2005) 63-73* for OTOF cochleae. - apex_position: Identify position of apex and base. Apex is set to node with higher y-value per default. - component_list: List of components. Can be passed to obtain the number of instances within the component list. - components_mapping: Components to use for tonotopic mapping. Ignore components torn parallel to main canal. - max_edge_distance: Maximal edge distance between graph nodes to create an edge between nodes. - s3: Use S3 bucket. - s3_credentials: - s3_bucket_name: - s3_service_endpoint: - """ - if os.path.isdir(out_path): - raise ValueError(f"Output path {out_path} is a directory. 
Provide a path to a single output file.") - - if s3: - tsv_path, fs = get_s3_path(table_path, bucket_name=s3_bucket_name, - service_endpoint=s3_service_endpoint, credential_file=s3_credentials) - with fs.open(tsv_path, "r") as f: - table = pd.read_csv(f, sep="\t") - else: - table = pd.read_csv(table_path, sep="\t") - - apex_higher = (apex_position == "apex_higher") - - # overwrite input file - if os.path.realpath(out_path) == os.path.realpath(table_path) and not s3: - force_overwrite = True - - if os.path.isfile(out_path) and not force_overwrite: - print(f"Skipping {out_path}. Table already exists.") - - else: - table = tonotopic_mapping(table, component_label=component_list, animal=animal, - cell_type=cell_type, component_mapping=component_mapping, - apex_higher=apex_higher, max_edge_distance=max_edge_distance, - otof=otof) - - table.to_csv(out_path, sep="\t", index=False) - - def wrapper_tonotopic_mapping( output_path: str, table_path: Optional[str] = None, From 6c6fa6c719c962855114081bb1aef625348e4983 Mon Sep 17 00:00:00 2001 From: Martin Schilling Date: Wed, 26 Nov 2025 12:01:24 +0100 Subject: [PATCH 5/5] CLI for post-processing functions --- flamingo_tools/postprocessing/cli.py | 149 +++++++++++++++++++++++++++ setup.py | 5 +- 2 files changed, 153 insertions(+), 1 deletion(-) create mode 100644 flamingo_tools/postprocessing/cli.py diff --git a/flamingo_tools/postprocessing/cli.py b/flamingo_tools/postprocessing/cli.py new file mode 100644 index 0000000..bd36eb5 --- /dev/null +++ b/flamingo_tools/postprocessing/cli.py @@ -0,0 +1,149 @@ +"""private +""" +import argparse + +from .label_components import label_components_single +from .cochlea_mapping import tonotopic_mapping_single +from flamingo_tools.measurements import object_measures_single + + +def label_components(): + parser = argparse.ArgumentParser( + description="Script to label segmentation using a segmentation table and graph connected components.") + + parser.add_argument("-i", "--input", type=str, 
required=True, help="Input path to segmentation table.") + parser.add_argument("-o", "--output", type=str, required=True, + help="Output path. Either directory (for --json) or specific file otherwise.") + parser.add_argument("--force", action="store_true", help="Forcefully overwrite output.") + + # options for post-processing + parser.add_argument("--cell_type", type=str, default="sgn", + help="Cell type of segmentation. Either 'sgn' or 'ihc'.") + parser.add_argument("--min_size", type=int, default=1000, + help="Minimal number of pixels for filtering small instances.") + parser.add_argument("--min_component_length", type=int, default=50, + help="Minimal length for filtering out connected components.") + parser.add_argument("--max_edge_distance", type=float, default=30, + help="Maximal distance in micrometer between points to create edges for connected components.") + parser.add_argument("-c", "--components", type=str, nargs="+", default=[1], help="List of connected components.") + + # options for S3 bucket + parser.add_argument("--s3", action="store_true", help="Flag for using S3 bucket.") + parser.add_argument("--s3_credentials", type=str, default=None, + help="Input file containing S3 credentials. " + "Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.") + parser.add_argument("--s3_bucket_name", type=str, default=None, + help="S3 bucket name. Optional if BUCKET_NAME was exported.") + parser.add_argument("--s3_service_endpoint", type=str, default=None, + help="S3 service endpoint. 
Optional if SERVICE_ENDPOINT was exported.") + + args = parser.parse_args() + + label_components_single( + table_path=args.input, + output_path=args.output, + cell_type=args.cell_type, + component_list=args.components, + max_edge_distance=args.max_edge_distance, + min_component_length=args.min_component_length, + min_size=args.min_size, + force_overwrite=args.force, + s3=args.s3, + s3_credentials=args.s3_credentials, + s3_bucket_name=args.s3_bucket_name, + s3_service_endpoint=args.s3_service_endpoint, + ) + + +def tonotopic_mapping(): + parser = argparse.ArgumentParser( + description="Script to assign tonotopic frequencies to segmentation instances using a segmentation table.") + + parser.add_argument("-i", "--input", type=str, required=True, help="Input path to segmentation table.") + parser.add_argument("-o", "--output", type=str, required=True, + help="Output path. Either directory or specific file.") + parser.add_argument("--force", action="store_true", help="Forcefully overwrite output.") + + # options for tonotopic mapping + parser.add_argument("--animal", type=str, default="mouse", + help="Animal type to be used for frequency mapping. Either 'mouse' or 'gerbil'.") + parser.add_argument("--otof", action="store_true", help="Use frequency mapping for OTOF cochleae.") + parser.add_argument("--apex_position", type=str, default="apex_higher", + help="Position of the apex. Apex is set to node with higher y-value per default ('apex_higher').") + + # options for post-processing + parser.add_argument("--cell_type", type=str, default="sgn", + help="Cell type of segmentation. 
Either 'sgn' or 'ihc'.") + parser.add_argument("--max_edge_distance", type=float, default=30, + help="Maximal distance in micrometer between points to create edges for connected components.") + parser.add_argument("-c", "--components", type=str, nargs="+", default=[1], help="List of connected components.") + + # options for S3 bucket + parser.add_argument("--s3", action="store_true", help="Flag for using S3 bucket.") + parser.add_argument("--s3_credentials", type=str, default=None, + help="Input file containing S3 credentials. " + "Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.") + parser.add_argument("--s3_bucket_name", type=str, default=None, + help="S3 bucket name. Optional if BUCKET_NAME was exported.") + parser.add_argument("--s3_service_endpoint", type=str, default=None, + help="S3 service endpoint. Optional if SERVICE_ENDPOINT was exported.") + + args = parser.parse_args() + + tonotopic_mapping_single( + table_path=args.input, + output_path=args.output, + force_overwrite=args.force, + animal=args.animal, + otof=args.otof, + apex_position=args.apex_position, + cell_type=args.cell_type, + max_edge_distance=args.max_edge_distance, + component_list=args.components, + s3=args.s3, + s3_credentials=args.s3_credentials, + s3_bucket_name=args.s3_bucket_name, + s3_service_endpoint=args.s3_service_endpoint, + ) + + +def object_measures(): + parser = argparse.ArgumentParser( + description="Script to compute object measures for different stainings.") + + parser.add_argument("-o", "--output", type=str, nargs="+", required=True, + help="Output path(s). 
Either directory or specific file(s).") + parser.add_argument("-i", "--image_paths", type=str, nargs="+", default=None, + help="Input path to one or multiple image channels in ome.zarr format.") + parser.add_argument("-t", "--seg_table", type=str, default=None, + help="Input path to segmentation table.") + parser.add_argument("-s", "--seg_path", type=str, default=None, + help="Input path to segmentation channel in ome.zarr format.") + parser.add_argument("--force", action="store_true", help="Forcefully overwrite output.") + + # options for object measures + parser.add_argument("-c", "--components", type=str, nargs="+", default=[1], help="List of components.") + parser.add_argument("-r", "--resolution", type=float, nargs="+", default=[0.38, 0.38, 0.38], + help="Resolution of input in micrometer.") + parser.add_argument("--bg_mask", action="store_true", help="Use background mask for calculating object measures.") + + # options for S3 bucket + parser.add_argument("--s3", action="store_true", help="Flag for using S3 bucket.") + parser.add_argument("--s3_credentials", type=str, default=None, + help="Input file containing S3 credentials. " + "Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.") + parser.add_argument("--s3_bucket_name", type=str, default=None, + help="S3 bucket name. Optional if BUCKET_NAME was exported.") + parser.add_argument("--s3_service_endpoint", type=str, default=None, + help="S3 service endpoint. 
Optional if SERVICE_ENDPOINT was exported.") + + args = parser.parse_args() + + object_measures_single( + out_paths=args.output, + image_paths=args.image_paths, + table_path=args.seg_table, + seg_path=args.seg_path, + force_overwrite=args.force, + s3=args.s3, + ) diff --git a/setup.py b/setup.py index 82d4222..4c3805f 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,10 @@ "flamingo_tools.convert_data = flamingo_tools.data_conversion:convert_lightsheet_to_bdv_cli", "flamingo_tools.run_segmentation = flamingo_tools.segmentation.cli:run_segmentation", "flamingo_tools.run_detection = flamingo_tools.segmentation.cli:run_detection", - # TODO: MoBIE conversion, tonotopic mapping + "flamingo_tools.label_components = flamingo_tools.postprocessing.cli:label_components", + "flamingo_tools.tonotopic_mapping = flamingo_tools.postprocessing.cli:tonotopic_mapping", + "flamingo_tools.object_measures = flamingo_tools.postprocessing.cli:object_measures", + # TODO: MoBIE conversion ], "napari.manifest": [ "cochlea_net = flamingo_tools:napari.yaml",