diff --git a/docs/pretrained.rst b/docs/pretrained.rst index 310fc83bc..1a2a53faf 100644 --- a/docs/pretrained.rst +++ b/docs/pretrained.rst @@ -353,7 +353,7 @@ The input output configuration is as follows: ioconfig = IOPatchPredictorConfig( patch_input_shape=(31, 31), stride_shape=(8, 8), - input_resolutions=[{"resolution": 0.25, "units": "mpp"}] + input_resolutions=[{"resolution": 0.5, "units": "mpp"}] ) @@ -369,7 +369,7 @@ The input output configuration is as follows: ioconfig = IOPatchPredictorConfig( patch_input_shape=(252, 252), stride_shape=(150, 150), - input_resolutions=[{"resolution": 0.25, "units": "mpp"}] + input_resolutions=[{"resolution": 0.5, "units": "mpp"}] ) @@ -393,7 +393,7 @@ The input output configuration is as follows: ioconfig = IOPatchPredictorConfig( patch_input_shape=(31, 31), stride_shape=(8, 8), - input_resolutions=[{"resolution": 0.25, "units": "mpp"}] + input_resolutions=[{"resolution": 0.5, "units": "mpp"}] ) @@ -409,7 +409,7 @@ The input output configuration is as follows: ioconfig = IOPatchPredictorConfig( patch_input_shape=(252, 252), stride_shape=(150, 150), - input_resolutions=[{"resolution": 0.25, "units": "mpp"}] + input_resolutions=[{"resolution": 0.5, "units": "mpp"}] ) diff --git a/tests/engines/test_nucleus_detection_engine.py b/tests/engines/test_nucleus_detection_engine.py new file mode 100644 index 000000000..9cda077d1 --- /dev/null +++ b/tests/engines/test_nucleus_detection_engine.py @@ -0,0 +1,225 @@ +"""Tests for NucleusDetector.""" + +import pathlib +import shutil +from collections.abc import Callable + +import dask.array as da +import numpy as np +import pandas as pd +import pytest + +from tiatoolbox.annotation.storage import SQLiteStore +from tiatoolbox.models.engine.nucleus_detector import NucleusDetector +from tiatoolbox.utils import env_detection as toolbox_env +from tiatoolbox.utils.misc import imwrite +from tiatoolbox.wsicore.wsireader import WSIReader + +device = "cuda" if toolbox_env.has_gpu() else "cpu" + + +def _rm_dir(path: pathlib.Path) -> None: + """Helper func to remove directory.""" + if pathlib.Path(path).exists(): + shutil.rmtree(path, ignore_errors=True) + + +def check_output(path: pathlib.Path) -> None: + """Check NucleusDetector output.""" + + +def test_nucleus_detection_nms_empty_dataframe() -> None: + """nucleus_detection_nms should return a copy for empty inputs.""" + df = pd.DataFrame(columns=["x", "y", "type", "prob"]) + + result = NucleusDetector.nucleus_detection_nms(df, radius=3) + + assert result.empty + assert result is not df + assert list(result.columns) == ["x", "y", "type", "prob"] + + +def test_nucleus_detection_nms_invalid_radius() -> None: + """Radius must be strictly positive.""" + df = pd.DataFrame({"x": [0], "y": [0], "type": [1], "prob": [0.9]}) + + with pytest.raises(ValueError, match="radius must be > 0"): + NucleusDetector.nucleus_detection_nms(df, radius=0) + + +def test_nucleus_detection_nms_invalid_overlap_threshold() -> None: + """overlap_threshold must lie in (0, 1].""" + df = pd.DataFrame({"x": [0], "y": [0], "type": [1], "prob": [0.9]}) + + message = r"overlap_threshold must be in \(0\.0, 1\.0\], got 0" + with pytest.raises(ValueError, match=message): + NucleusDetector.nucleus_detection_nms(df, radius=1, overlap_threshold=0) + + +def test_nucleus_detection_nms_suppresses_overlapping_detections() -> None: + """Lower-probability overlapping detections are removed.""" + df = pd.DataFrame( + { + "x": [2, 0, 20], + "y": [1, 0, 20], + "type": [1, 1, 2], + "prob": [0.6, 0.9, 0.7], + } + ) + + 
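+    # With radius=5, the points (2, 1) and (0, 0) overlap heavily, so the
+    # lower-probability detection at (2, 1) should be suppressed in favour of
+    # (0, 0); the detection at (20, 20) is isolated and should be kept.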
result = NucleusDetector.nucleus_detection_nms(df, radius=5) + + expected = pd.DataFrame( + {"x": [0, 20], "y": [0, 20], "type": [1, 2], "prob": [0.9, 0.7]} + ) + pd.testing.assert_frame_equal(result.reset_index(drop=True), expected) + + +def test_nucleus_detection_nms_suppresses_across_types() -> None: + """Overlapping detections of different types are also suppressed.""" + df = pd.DataFrame( + { + "x": [0, 0, 20], + "y": [0, 0, 0], + "type": [1, 2, 1], + "prob": [0.6, 0.95, 0.4], + } + ) + + result = NucleusDetector.nucleus_detection_nms(df, radius=5) + + expected = pd.DataFrame( + {"x": [0, 20], "y": [0, 0], "type": [2, 1], "prob": [0.95, 0.4]} + ) + pd.testing.assert_frame_equal(result.reset_index(drop=True), expected) + + +def test_nucleus_detection_nms_retains_non_overlapping_candidates() -> None: + """Detections with IoU below the threshold are preserved.""" + df = pd.DataFrame( + { + "x": [0, 10], + "y": [0, 0], + "type": [1, 1], + "prob": [0.8, 0.5], + } + ) + + result = NucleusDetector.nucleus_detection_nms(df, radius=5, overlap_threshold=0.5) + + expected = pd.DataFrame( + {"x": [0, 10], "y": [0, 0], "type": [1, 1], "prob": [0.8, 0.5]} + ) + pd.testing.assert_frame_equal(result.reset_index(drop=True), expected) + + +def test_nucleus_detector_wsi(remote_sample: Callable, tmp_path: pathlib.Path) -> None: + """Test for nucleus detection engine.""" + mini_wsi_svs = pathlib.Path(remote_sample("wsi4_512_512_svs")) + + pretrained_model = "mapde-conic" + + save_dir = tmp_path + + nucleus_detector = NucleusDetector(model=pretrained_model) + _ = nucleus_detector.run( + patch_mode=False, + device=device, + output_type="annotationstore", + memory_threshold=50, + images=[mini_wsi_svs], + save_dir=save_dir, + overwrite=True, + ) + + store = SQLiteStore.open(save_dir / "wsi4_512_512.db") + assert len(store.values()) == 281 + store.close() + + _rm_dir(save_dir) + + +def test_nucleus_detector_patch( + remote_sample: Callable, tmp_path: pathlib.Path +) -> None: + """Test for nucleus detection engine in patch mode.""" + mini_wsi_svs = pathlib.Path(remote_sample("wsi4_512_512_svs")) + + wsi_reader = WSIReader.open(mini_wsi_svs) + patch_1 = wsi_reader.read_rect((0, 0), (252, 252), resolution=0.5, units="mpp") + patch_2 = wsi_reader.read_rect((252, 252), (252, 252), resolution=0.5, units="mpp") + + pretrained_model = "mapde-conic" + + save_dir = tmp_path + + nucleus_detector = NucleusDetector(model=pretrained_model) + _ = nucleus_detector.run( + patch_mode=True, + device=device, + output_type="annotationstore", + memory_threshold=50, + images=[patch_1, patch_2], + save_dir=save_dir, + overwrite=True, + class_dict=None, + ) + + store_1 = SQLiteStore.open(save_dir / "0.db") + assert len(store_1.values()) == 270 + store_1.close() + + store_2 = SQLiteStore.open(save_dir / "1.db") + assert len(store_2.values()) == 52 + store_2.close() + + imwrite(save_dir / "patch_0.png", patch_1) + imwrite(save_dir / "patch_1.png", patch_2) + _ = nucleus_detector.run( + patch_mode=True, + device=device, + output_type="zarr", + memory_threshold=50, + images=[save_dir / "patch_0.png", save_dir / "patch_1.png"], + save_dir=save_dir, + overwrite=True, + ) + + store_1 = SQLiteStore.open(save_dir / "patch_0.db") + assert len(store_1.values()) == 270 + store_1.close() + + store_2 = SQLiteStore.open(save_dir / "patch_1.db") + assert len(store_2.values()) == 52 + store_2.close() + + _rm_dir(save_dir) + + +def test_nucleus_detector_write_centroid_maps(tmp_path: pathlib.Path) -> None: + """Test for _write_centroid_maps function.""" 
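+    # An all-zero centroid map should produce an empty in-memory store; a
+    # single non-zero pixel should produce exactly one point annotation whose
+    # centroid is (x=col, y=row) and whose "type" comes from class_dict.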
+ detection_maps = np.zeros((20, 20, 1), dtype=np.uint8) + detection_maps = da.from_array(detection_maps, chunks=(20, 20, 1)) + + store = NucleusDetector.write_centroid_maps_to_store( + detection_maps=detection_maps, class_dict=None + ) + assert len(store.values()) == 0 + store.close() + + detection_maps = np.zeros((20, 20, 1), dtype=np.uint8) + detection_maps[10, 10, 0] = 1 + detection_maps = da.from_array(detection_maps, chunks=(20, 20, 1)) + _ = NucleusDetector.write_centroid_maps_to_store( + detection_maps=detection_maps, + save_path=tmp_path / "test.db", + class_dict={0: "nucleus"}, + ) + store = SQLiteStore.open(tmp_path / "test.db") + assert len(store.values()) == 1 + annotation = next(iter(store.values())) + print(annotation) + assert annotation.properties["type"] == "nucleus" + assert annotation.geometry.centroid.x == 10.0 + assert annotation.geometry.centroid.y == 10.0 + store.close() diff --git a/tests/models/test_arch_mapde.py b/tests/models/test_arch_mapde.py index 19163f593..22a354ad4 100644 --- a/tests/models/test_arch_mapde.py +++ b/tests/models/test_arch_mapde.py @@ -8,6 +8,7 @@ from tiatoolbox.models import MapDe from tiatoolbox.models.architecture import fetch_pretrained_weights +from tiatoolbox.models.engine.nucleus_detector import NucleusDetector from tiatoolbox.utils import env_detection as toolbox_env from tiatoolbox.utils.misc import select_device from tiatoolbox.wsicore.wsireader import WSIReader @@ -48,7 +49,35 @@ def test_functionality(remote_sample: Callable) -> None: batch = torch.from_numpy(patch)[None] output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) output = model.postproc(output[0]) - assert np.all(output[0:2] == [[19, 171], [53, 89]]) + xs, ys, _, _ = NucleusDetector._centroid_maps_to_detection_records(output, None) + + np.testing.assert_array_equal(xs[0:2], np.array([242, 192])) + np.testing.assert_array_equal(ys[0:2], np.array([10, 13])) + + patch = reader.read_bounds( + (0, 0, 252, 252), + resolution=0.50, + units="mpp", + coord_space="resolution", + ) + + model, weights_path = _load_mapde(name="mapde-conic") + patch = model.preproc(patch) + batch = torch.from_numpy(patch)[None] + output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) + block_info = { + 0: { + "array-location": [ + [0, 1], + [0, 1], + ], # dummy block to test no valid detections + } + } + output = model.postproc(output[0], block_info=block_info) + xs, ys, _, _ = NucleusDetector._centroid_maps_to_detection_records(output, None) + np.testing.assert_array_equal(xs, np.array([])) + np.testing.assert_array_equal(ys, np.array([])) + Path(weights_path).unlink() diff --git a/tests/models/test_arch_sccnn.py b/tests/models/test_arch_sccnn.py index a456faff5..4f8364854 100644 --- a/tests/models/test_arch_sccnn.py +++ b/tests/models/test_arch_sccnn.py @@ -7,6 +7,7 @@ from tiatoolbox.models import SCCNN from tiatoolbox.models.architecture import fetch_pretrained_weights +from tiatoolbox.models.engine.nucleus_detector import NucleusDetector from tiatoolbox.utils import env_detection from tiatoolbox.utils.misc import select_device from tiatoolbox.wsicore.wsireader import WSIReader @@ -48,7 +49,10 @@ def test_functionality(remote_sample: Callable) -> None: device=select_device(on_gpu=env_detection.has_gpu()), ) output = model.postproc(output[0]) - np.testing.assert_array_equal(output, np.array([[8, 7]])) + xs, ys, _, _ = NucleusDetector._centroid_maps_to_detection_records(output, None) + + np.testing.assert_array_equal(xs, np.array([8])) + 
np.testing.assert_array_equal(ys, np.array([7])) model = _load_sccnn(name="sccnn-conic") output = model.infer_batch( @@ -56,5 +60,31 @@ def test_functionality(remote_sample: Callable) -> None: batch, device=select_device(on_gpu=env_detection.has_gpu()), ) - output = model.postproc(output[0]) - np.testing.assert_array_equal(output, np.array([[7, 8]])) + block_info = { + 0: { + "array-location": [[0, 31], [0, 31]], + } + } + output = model.postproc(output[0], block_info=block_info) + xs, ys, _, _ = NucleusDetector._centroid_maps_to_detection_records(output, None) + np.testing.assert_array_equal(xs, np.array([7])) + np.testing.assert_array_equal(ys, np.array([8])) + + model = _load_sccnn(name="sccnn-conic") + output = model.infer_batch( + model, + batch, + device=select_device(on_gpu=env_detection.has_gpu()), + ) + block_info = { + 0: { + "array-location": [ + [0, 1], + [0, 1], + ], # dummy block to test no valid detections + } + } + output = model.postproc(output[0], block_info=block_info) + xs, ys, _, _ = NucleusDetector._centroid_maps_to_detection_records(output, None) + np.testing.assert_array_equal(xs, np.array([])) + np.testing.assert_array_equal(ys, np.array([])) diff --git a/tiatoolbox/data/pretrained_model.yaml b/tiatoolbox/data/pretrained_model.yaml index 3a4ccab9b..40f308ea1 100644 --- a/tiatoolbox/data/pretrained_model.yaml +++ b/tiatoolbox/data/pretrained_model.yaml @@ -814,6 +814,10 @@ mapde-crchisto: min_distance: 4 threshold_abs: 250 num_classes: 1 + postproc_tile_shape: [ 2048, 2048 ] + output_class_dict: { + 0: "nucleus" + } ioconfig: class: io_config.IOSegmentorConfig kwargs: @@ -821,7 +825,6 @@ mapde-crchisto: - { "units": "mpp", "resolution": 0.5 } output_resolutions: - { "units": "mpp", "resolution": 0.5 } - tile_shape: [ 2048, 2048 ] patch_input_shape: [ 252, 252 ] patch_output_shape: [ 252, 252 ] stride_shape: [ 150, 150 ] @@ -836,6 +839,10 @@ mapde-conic: min_distance: 3 threshold_abs: 205 num_classes: 1 + postproc_tile_shape: [ 2048, 2048 ] + output_class_dict: { + 0: "nucleus" + } ioconfig: class: io_config.IOSegmentorConfig kwargs: @@ -843,7 +850,6 @@ mapde-conic: - { "units": "mpp", "resolution": 0.5 } output_resolutions: - { "units": "mpp", "resolution": 0.5 } - tile_shape: [ 2048, 2048 ] patch_input_shape: [ 252, 252 ] patch_output_shape: [ 252, 252 ] stride_shape: [ 150, 150 ] @@ -859,6 +865,10 @@ sccnn-crchisto: min_distance: 6 threshold_abs: 0.20 patch_output_shape: [ 13, 13 ] + postproc_tile_shape: [ 2048, 2048 ] + output_class_dict: { + 0: "nucleus" + } ioconfig: class: io_config.IOSegmentorConfig kwargs: @@ -866,7 +876,6 @@ sccnn-crchisto: - { "units": "mpp", "resolution": 0.25 } output_resolutions: - { "units": "mpp", "resolution": 0.25 } - tile_shape: [ 2048, 2048 ] patch_input_shape: [ 31, 31 ] patch_output_shape: [ 13, 13 ] stride_shape: [ 8, 8 ] @@ -882,6 +891,10 @@ sccnn-conic: min_distance: 5 threshold_abs: 0.05 patch_output_shape: [ 13, 13 ] + postproc_tile_shape: [ 2048, 2048 ] + output_class_dict: { + 0: "nucleus" + } ioconfig: class: io_config.IOSegmentorConfig kwargs: @@ -889,7 +902,6 @@ sccnn-conic: - { "units": "mpp", "resolution": 0.25 } output_resolutions: - { "units": "mpp", "resolution": 0.25 } - tile_shape: [ 2048, 2048 ] patch_input_shape: [ 31, 31 ] patch_output_shape: [ 13, 13 ] stride_shape: [ 8, 8 ] diff --git a/tiatoolbox/models/architecture/mapde.py b/tiatoolbox/models/architecture/mapde.py index 0900aa6fd..a623bc160 100644 --- a/tiatoolbox/models/architecture/mapde.py +++ b/tiatoolbox/models/architecture/mapde.py @@ -78,6 +78,8 
@@ def __init__( min_distance: int = 4, threshold_abs: float = 250, num_classes: int = 1, + postproc_tile_shape: tuple[int, int] = (2048, 2048), + output_class_dict: dict[int, str] | None = None, ) -> None: """Initialize :class:`MapDe`.""" super().__init__( @@ -85,6 +87,8 @@ def __init__( num_input_channels=num_input_channels, out_activation="relu", ) + self.output_class_dict = output_class_dict + self.postproc_tile_shape = postproc_tile_shape dist_filter = np.array( [ @@ -233,28 +237,71 @@ def forward(self: MapDe, input_tensor: torch.Tensor) -> torch.Tensor: return F.relu(out) # skipcq: PYL-W0221 # noqa: ERA001 - def postproc(self: MapDe, prediction_map: np.ndarray) -> np.ndarray: - """Post-processing script for MicroNet. + def postproc( + self: MapDe, + block: np.ndarray, + block_info: dict | None = None, + depth_h: int = 0, + depth_w: int = 0, + ) -> np.ndarray: + """MapDe post-processing function. + + Builds a processed mask per input channel, runs peak_local_max then + writes 1.0 at peak pixels. - Performs peak detection and extracts coordinates in x, y format. + Can be called inside Dask.da.map_overlap on a padded NumPy block: + (h_pad, w_pad, C) to process large prediction maps in chunks. + Keeps only centroids whose (row,col) lie in the interior window: + rows [depth_h : depth_h + core_h), cols [depth_w : depth_w + core_w) + + Returns same spatial shape as the input block Args: - prediction_map (ndarray): - Input image of type numpy array. + block: NumPy array (H, W, C). + block_info: Dask block info dict. + Only used when called inside dask.array.map_overlap. + depth_h: Halo size in pixels for height (rows). + Only used when it's called inside dask.array.map_overlap. + depth_w: Halo size in pixels for width (cols). + Only used when it's called inside dask.array.map_overlap. Returns: - :class:`numpy.ndarray`: - Pixel-wise nuclear instance segmentation - prediction. - + out: NumPy array (H, W, C) with 1.0 at peaks, 0 elsewhere. """ - coordinates = peak_local_max( - np.squeeze(prediction_map[0], axis=2), - min_distance=self.min_distance, - threshold_abs=self.threshold_abs, - exclude_border=False, - ) - return np.fliplr(coordinates) + block_height, block_width, block_channels = block.shape + + # --- derive core (pre-overlap) size for THIS block --- + if block_info is None: + core_h = block_height - 2 * depth_h + core_w = block_width - 2 * depth_w + else: + info = block_info[0] + locs = info[ + "array-location" + ] # a list of (start, stop) coordinates per axis + core_h = int(locs[0][1] - locs[0][0]) # r1 - r0 + core_w = int(locs[1][1] - locs[1][0]) + + rmin, rmax = depth_h, depth_h + core_h + cmin, cmax = depth_w, depth_w + core_w + + out = np.zeros((block_height, block_width, block_channels), dtype=np.float32) + + for ch in range(block_channels): + img = np.asarray(block[..., ch]) # NumPy 2D view + + coords = peak_local_max( + img, + min_distance=self.min_distance, + threshold_abs=self.threshold_abs, + exclude_border=False, + ) + + for r, c in coords: + if (rmin <= r < rmax) and (cmin <= c < cmax): + out[r, c, ch] = 1.0 + + return out @staticmethod def infer_batch( @@ -262,7 +309,7 @@ def infer_batch( batch_data: torch.Tensor, *, device: str, - ) -> list[np.ndarray]: + ) -> np.ndarray: """Run inference on an input batch. 
This contains logic for forward operation as well as batch I/O @@ -293,8 +340,4 @@ def infer_batch( pred = model(patch_imgs_gpu) pred = pred.permute(0, 2, 3, 1).contiguous() - pred = pred.cpu().numpy() - - return [ - pred, - ] + return pred.cpu().numpy() diff --git a/tiatoolbox/models/architecture/sccnn.py b/tiatoolbox/models/architecture/sccnn.py index 2c47f9d12..0f4fc945a 100644 --- a/tiatoolbox/models/architecture/sccnn.py +++ b/tiatoolbox/models/architecture/sccnn.py @@ -91,6 +91,8 @@ def __init__( radius: int = 12, min_distance: int = 6, threshold_abs: float = 0.20, + postproc_tile_shape: tuple[int, int] = (2048, 2048), + output_class_dict: dict[int, str] | None = None, ) -> None: """Initialize :class:`SCCNN`.""" super().__init__() @@ -99,6 +101,8 @@ def __init__( self.in_ch = num_input_channels self.out_height = out_height self.out_width = out_width + self.postproc_tile_shape = postproc_tile_shape + self.output_class_dict = output_class_dict # Create mesh grid and convert to 3D vector x, y = torch.meshgrid( @@ -325,35 +329,79 @@ def spatially_constrained_layer1( return self.spatially_constrained_layer2(s1_sigmoid0, s1_sigmoid1, s1_sigmoid2) # skipcq: PYL-W0221 # noqa: ERA001 - def postproc(self: SCCNN, prediction_map: np.ndarray) -> np.ndarray: - """Post-processing script for MicroNet. + def postproc( + self: SCCNN, + block: np.ndarray, + block_info: dict | None = None, + depth_h: int = 0, + depth_w: int = 0, + ) -> np.ndarray: + """SCCNN post-processing function. + + Builds a processed mask per input channel, runs peak_local_max then + writes 1.0 at peak pixels. + + Can be called inside Dask.da.map_overlap on a padded NumPy block: + (h_pad, w_pad, C) to process large prediction maps in chunks. + Keeps only centroids whose (row,col) lie in the interior window: + rows [depth_h : depth_h + core_h), cols [depth_w : depth_w + core_w) - Performs peak detection and extracts coordinates in x, y format. + Returns same spatial shape as the input block Args: - prediction_map (ndarray): - Input image of type numpy array. + block: NumPy array (H, W, C). + block_info: Dask block info dict. Only used when called inside + dask.array.map_overlap. + depth_h: Halo size in pixels for height (rows). + Only used when it's called inside dask.array.map_overlap. + depth_w: Halo size in pixels for width (cols). + Only used when it's called inside dask.array.map_overlap. Returns: - :class:`numpy.ndarray`: - Pixel-wise nuclear instance segmentation - prediction. - + out: NumPy array (H, W, C) with 1.0 at peaks, 0 elsewhere. 
""" - coordinates = peak_local_max( - np.squeeze(prediction_map[0], axis=2), - min_distance=self.min_distance, - threshold_abs=self.threshold_abs, - exclude_border=False, - ) - return np.fliplr(coordinates) + block_height, block_width, block_channels = block.shape + + # --- derive core (pre-overlap) size for THIS block --- + if block_info is None: + core_h = block_height - 2 * depth_h + core_w = block_width - 2 * depth_w + else: + info = block_info[0] + locs = info[ + "array-location" + ] # a list of (start, stop) coordinates per axis + core_h = int(locs[0][1] - locs[0][0]) # r1 - r0 + core_w = int(locs[1][1] - locs[1][0]) + + rmin, rmax = depth_h, depth_h + core_h + cmin, cmax = depth_w, depth_w + core_w + + out = np.zeros((block_height, block_width, block_channels), dtype=np.float32) + + for ch in range(block_channels): + img = np.asarray(block[..., ch]) # NumPy 2D view + + coords = peak_local_max( + img, + min_distance=self.min_distance, + threshold_abs=self.threshold_abs, + exclude_border=False, + ) + + for r, c in coords: + if (rmin <= r < rmax) and (cmin <= c < cmax): + out[r, c, ch] = 1.0 + + return out @staticmethod def infer_batch( model: nn.Module, - batch_data: np.ndarray | torch.Tensor, + batch_data: torch.Tensor, + *, device: str, - ) -> list[np.ndarray]: + ) -> np.ndarray: """Run inference on an input batch. This contains logic for forward operation as well as batch I/O @@ -386,8 +434,4 @@ def infer_batch( pred = model(patch_imgs_gpu) pred = pred.permute(0, 2, 3, 1).contiguous() - pred = pred.cpu().numpy() - - return [ - pred, - ] + return pred.cpu().numpy() diff --git a/tiatoolbox/models/engine/__init__.py b/tiatoolbox/models/engine/__init__.py index 9c00ac4a2..ff65892a2 100644 --- a/tiatoolbox/models/engine/__init__.py +++ b/tiatoolbox/models/engine/__init__.py @@ -2,6 +2,7 @@ from . import ( engine_abc, + nucleus_detector, nucleus_instance_segmentor, patch_predictor, semantic_segmentor, @@ -9,6 +10,7 @@ __all__ = [ "engine_abc", + "nucleus_detector", "nucleus_instance_segmentor", "patch_predictor", "semantic_segmentor", diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 73b4ca1c1..0a6a8e127 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -45,7 +45,7 @@ import torch import zarr from dask import compute -from dask.diagnostics import ProgressBar +from dask.diagnostics.progress import ProgressBar from torch import nn from typing_extensions import Unpack diff --git a/tiatoolbox/models/engine/nucleus_detector.py b/tiatoolbox/models/engine/nucleus_detector.py new file mode 100644 index 000000000..d1aeba676 --- /dev/null +++ b/tiatoolbox/models/engine/nucleus_detector.py @@ -0,0 +1,476 @@ +"""This module implements nucleus detection engine.""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +import dask +import dask.array as da +import numpy as np +from dask.diagnostics.progress import ProgressBar +from shapely.geometry import Point + +from tiatoolbox import logger +from tiatoolbox.annotation.storage import Annotation, SQLiteStore +from tiatoolbox.models.engine.semantic_segmentor import ( + SemanticSegmentor, + SemanticSegmentorRunParams, +) + +if TYPE_CHECKING: # pragma: no cover + from typing import Unpack + + import pandas as pd + + from tiatoolbox.annotation import AnnotationStore + + +class NucleusDetector(SemanticSegmentor): + r"""Nucleus detection engine. 
+ + Args: + model (str or nn.Module): + Defined PyTorch model or name of the existing models support by + tiatoolbox for processing the data e.g., mapde-conic, mapde-crchisto. + For a full list of pretrained models, please refer to the `docs + `. + By default, the corresponding pretrained weights will also + be downloaded. However, you can override with your own set + of weights via the `weights` argument. Argument is case insensitive. + batch_size (int): + Number of images fed into the model each time. + num_workers (int): + Number of workers used in torch.utils.data.DataLoader. + weights (str or pathlib.Path, optional): + Pretrained weights file path or name of the existing weights + supported by tiatoolbox. If ``None``, and `model` is a string, + the default pretrained weights for the specified model will be used. + If `model` is a nn.Module, no weights will be loaded + unless specified here. + device (str): + Device to run the model on, e.g., 'cpu' or 'cuda:0'. + verbose (bool): + Whether to output logging information. + + Supported TIAToolBox Pre-trained Models: + - `mapde-conic` + - `mapde-crchisto` + + + Examples: + >>> model_name = "mapde-conic" + >>> detector = NucleusDetector(model=model_name, batch_size=16, num_workers=8) + >>> detector.run( + ... images=[pathlib.Path("example_wsi.tiff")], + ... patch_mode=False, + ... device="cuda", + ... save_dir=pathlib.Path("output_directory/"), + ... overwrite=True, + ... output_type="annotationstore", + ... class_dict={0: "nucleus"}, + ... auto_get_mask=True, + ... memory_threshold=80 + ... ) + """ + + def post_process_patches( + self: NucleusDetector, + raw_predictions: list[da.Array], + prediction_shape: tuple[int, ...], + prediction_dtype: type, + **kwargs: Unpack[SemanticSegmentorRunParams], + ) -> list[np.ndarray]: + """Define how to post-process patch predictions. + + Args: + raw_predictions (da.Array): The raw predictions from the model. + prediction_shape (tuple[int, ...]): The shape of the predictions. + prediction_dtype (type): The data type of the predictions. + **kwargs (SemanticSegmentorRunParams): + Additional runtime parameters + + Returns: + A list of DataFrames containing the post-processed + predictions for each patch. + + """ + _ = kwargs.get("return_probabilities") + _ = prediction_shape + _ = prediction_dtype + + return [ + self.model.postproc_func(raw_predictions[i]) + for i in range(len(raw_predictions)) + ] + + def post_process_wsi( + self: NucleusDetector, + raw_predictions: da.Array, + prediction_shape: tuple[int, ...], + prediction_dtype: type, + **kwargs: Unpack[SemanticSegmentorRunParams], # noqa: ARG002 + ) -> da.Array: + """Define how to post-process WSI predictions. + + Processes the raw prediction dask array using map_overlap + to apply the model's post-processing function on each chunk + with appropriate overlaps on chunk boundaries. + + Args: + raw_predictions (da.Array): The raw predictions from the model. + prediction_shape (tuple[int, ...]): The shape of the predictions. + prediction_dtype (type): The data type of the predictions. + **kwargs (SemanticSegmentorRunParams): + Additional runtime parameters + + Returns: + Post-processed dask array of detections at the WSI level. + The array has the same shape and dtype as the input. + Each pixel indicates the presence of a detected nucleus + as a probability score. 
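+
+        Note:
+            Each tile is processed with a halo of model.min_distance pixels on
+            each side so that peaks close to chunk borders are not lost, and the
+            array is re-chunked to model.postproc_tile_shape before
+            post-processing.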
+ + """ + logger.info("Post processing WSI predictions in NucleusDetector") + logger.info("Raw probabilities shape: %s", prediction_shape) + logger.info("Raw probabilities dtype %s", prediction_dtype) + logger.info("Raw chunk size: %s", raw_predictions.chunks) + + # Add halo (overlap) around each block for post-processing + depth_h = self.model.min_distance + depth_w = self.model.min_distance + depth = {0: depth_h, 1: depth_w, 2: 0} + + # Re-chunk to post-processing tile shape for more efficient processing + rechunked_prediction_map = raw_predictions.rechunk( + (self.model.postproc_tile_shape[0], self.model.postproc_tile_shape[1], -1) + ) + logger.info("Post-processing tile size: %s", rechunked_prediction_map.chunks) + logger.info("Post-processing tiles overlap: (h=%d, w=%d)", depth_h, depth_w) + + return da.map_overlap( + rechunked_prediction_map, + self.model.postproc, + depth=depth, + boundary=0, + dtype=prediction_dtype, + block_info=True, + depth_h=depth_h, + depth_w=depth_w, + ) + + def save_predictions( + self: NucleusDetector, + processed_predictions: dict, + output_type: str, + save_path: Path | None = None, + **kwargs: Unpack[SemanticSegmentorRunParams], + ) -> AnnotationStore | Path: + """Save nucleus detections to disk or return them in memory. + + This method saves predictions in one of the supported formats: + - "annotationstore": converts predictions to an AnnotationStore (.db file). + + If `patch_mode` is True, predictions are saved per image. If False, + predictions are merged and saved as a single output. + + Args: + processed_predictions (dict): + Dictionary containing processed model predictions. + output_type (str): + "annotationstore". + save_path (Path | None): + Path to save the output file. + **kwargs (SemanticSegmentorRunParams): + Additional runtime parameters including: + - scale_factor (tuple[float, float]): For coordinate transformation. + - class_dict (dict): Mapping of class indices to names. + + Returns: + AnnotationStore | Path: + - returns AnnotationStore or path to .db file. + + """ + # Only "annotationstore" output type is supported for NucleusDetector + if output_type != "annotationstore": + logger.warning( + "Output type %s is not supported by NucleusDetector. " + "Defaulting to 'annotationstore'.", + output_type, + ) + output_type = "annotationstore" + + # scale_factor set from kwargs + scale_factor = kwargs.get("scale_factor", (1.0, 1.0)) + # class_dict set from kwargs + class_dict = kwargs.get("class_dict") + if class_dict is None: + class_dict = self.model.output_class_dict + + # Need to add support for zarr conversion. 
+ save_paths = [] + + logger.info("Saving predictions as AnnotationStore.") + + if self.patch_mode: + save_paths = [] + for i, predictions in enumerate(processed_predictions["predictions"]): + predictions_da = da.from_array(predictions, chunks=predictions.shape) + + if isinstance(self.images[i], Path): + output_path = save_path.parent / (self.images[i].stem + ".db") + else: + output_path = save_path.parent / (str(i) + ".db") + + out_file = self.write_centroid_maps_to_store( + predictions_da, + scale_factor=scale_factor, + class_dict=class_dict, + save_path=output_path, + ) + + save_paths.append(out_file) + return save_paths + return self.write_centroid_maps_to_store( + processed_predictions["predictions"], + scale_factor=scale_factor, + save_path=save_path, + class_dict=class_dict, + ) + + @staticmethod + def nucleus_detection_nms( + df: pd.DataFrame, radius: int, overlap_threshold: float = 0.5 + ) -> pd.DataFrame: + """Non-Maximum Suppression across ALL detections. + + Keeps the highest-prob detection, removes any other point + within 'radius' pixels > overlap_threshold. + Expects dataframe columns: ['x','y','type','prob']. + + Args: + df: pandas DataFrame of detections. + radius: radius in pixels for suppression. + overlap_threshold: float in [0,1], fraction of radius for suppression. + + Returns: + filtered DataFrame with same columns/dtypes. + """ + overlap_max = 1.0 + overlap_min = 0.0 + if df.empty: + return df.copy() + if radius <= 0: + msg = "radius must be > 0" + raise ValueError(msg) + if not overlap_min < overlap_threshold <= overlap_max: + msg = f"overlap_threshold must be in (0.0, 1.0], got {overlap_threshold}" + raise ValueError(msg) + + # Sort by descending probability (highest priority first) + sub = df.sort_values("prob", ascending=False).reset_index(drop=True) + + # Coordinates as float64 for distance math + + coords = sub[["x", "y"]].to_numpy(dtype=np.float64) + r = float(radius) + two_r = 2.0 * r + two_r2 = two_r * two_r # distance^2 cutoff for any overlap + + suppressed = np.zeros(len(sub), dtype=bool) + keep_idx = [] + + for i in range(len(sub)): + if suppressed[i]: + continue + + keep_idx.append(i) + + # Vectorised distances to all points + dx = coords[:, 0] - coords[i, 0] + dy = coords[:, 1] - coords[i, 1] + d2 = dx * dx + dy * dy + + # Only points with d < 2r can have nonzero overlap + cand = d2 <= two_r2 + cand[i] = False # don't suppress the kept point itself + if not np.any(cand): + continue + + d = np.sqrt(d2[cand]) + + # Safe cosine argument = (distance รท diameter) + # Clamp for numerical stability + u = np.clip(d / (2.0 * r), -1.0, 1.0) + # Exact intersection area of two equal-radius circles. + inter = 2.0 * (r * r) * np.arccos(u) - 0.5 * d * np.sqrt( + np.clip(4.0 * r * r - d * d, 0.0, None) + ) + + union = 2.0 * np.pi * (r * r) - inter + iou = inter / union + + # Suppress candidates whose IoU exceeds threshold + idx_cand = np.where(cand)[0] + to_suppress = idx_cand[iou >= overlap_threshold] + suppressed[to_suppress] = True + + return sub.iloc[keep_idx].copy() + + @staticmethod + def _centroid_maps_to_detection_records( + block: np.ndarray, block_info: dict | None = None + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Convert a block of centroid maps to detection records. + + Each block is a NumPy array of shape (h, w, C) containing detection + probabilities of each class c. This function finds non-zero detections + and returns their global coordinates, class IDs (channel), and probabilities. 
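+        Coordinates are returned in (x, y) = (column, row) order, offset by the
+        block's global array location when block_info is provided.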
+ + Args: + block: NumPy array (h, w, C) for this chunk. + block_info: Dask block info dict. + + Returns: + Tuple of ([x_coords], [y_coords], [class_ids], [probs]) + """ + # block: (h, w, C) NumPy chunk (post-stitching, no halos) + if block_info is not None: + info = block_info[0] + (r0, _), (c0, _), _ = info["array-location"] # global interior start/stop + else: + r0, c0 = 0, 0 + + # find the coordinates and channel indices of nonzeros + ys, xs, cs = np.nonzero(block) + + if ys.size == 0: + # return empty arrays + return ( + np.empty(0, dtype=np.uint32), + np.empty(0, dtype=np.uint32), + np.empty(0, dtype=np.uint32), + np.empty(0, dtype=np.float32), + ) + + x = xs.astype(np.uint32, copy=False) + int(c0) + y = ys.astype(np.uint32, copy=False) + int(r0) + t = cs.astype(np.uint32, copy=False) + + # read detection probabilities + p = block[ys, xs, cs].astype(np.float32, copy=False) + return (x, y, t, p) + + @staticmethod + def _write_detection_records_to_store( + recs: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], + store: SQLiteStore, + scale_factor: tuple[float, float], + class_dict: dict[int, str | int] | None, + batch_size: int = 5000, + ) -> int: + """Write detection records to AnnotationStore in batches. + + Args: + recs: Tuple of ([x_coords], [y_coords], [class_ids], [probs]) + store: SQLiteStore to write the detections to + scale_factor: Scaling factors for x and y coordinates + class_dict: Mapping from original class IDs to new class names + batch_size: Number of records to write in each batch + Returns: + Total number of records written + """ + x, y, t, p = recs + n = len(x) + if n == 0: + return 0 # nothing to write + + # scale coordinates + x = np.rint(x * scale_factor[0]).astype(np.uint32, copy=False) + y = np.rint(y * scale_factor[1]).astype(np.uint32, copy=False) + + # class mapping + if class_dict is None: + # identity over actually-present types + uniq = np.unique(t) + class_dict = {int(k): int(k) for k in uniq} + labels = np.array([class_dict.get(int(k), int(k)) for k in t], dtype=object) + + def make_points(xb: np.ndarray, yb: np.ndarray) -> list[Point]: + """Create Shapely Point geometries from coordinate arrays.""" + return [Point(int(xx), int(yy)) for xx, yy in zip(xb, yb, strict=True)] + + written = 0 + for i in range(0, n, batch_size): + j = min(i + batch_size, n) + pts = make_points(x[i:j], y[i:j]) + + anns = [ + Annotation( + geometry=pt, properties={"type": lbl, "probability": float(pp)} + ) + for pt, lbl, pp in zip(pts, labels[i:j], p[i:j], strict=True) + ] + store.append_many(anns) + written += j - i + return written + + @staticmethod + def write_centroid_maps_to_store( + detection_maps: da.Array, + scale_factor: tuple[float, float] = (1.0, 1.0), + class_dict: dict | None = None, + save_path: Path | None = None, + batch_size: int = 5000, + ) -> Path | SQLiteStore: + """Write post-processed detection maps to an AnnotationStore. + + This is done in chunks using Dask for efficiency and to handle large + detection maps at WSI level. + + Args: + detection_maps: Dask array (H, W, C) of detection scores. + scale_factor: Tuple (sx, sy) to scale coordinates before saving. + class_dict: Optional dict mapping class indices to names. + save_path: Optional Path to save the .db file. + If None, returns in-memory store. + batch_size: Number of records to write per batch. + + Returns: + Path to saved .db file if save_path is provided, else in-memory SQLiteStore. 
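+
+        Example:
+            A minimal sketch: a single non-zero pixel yields one point
+            annotation in an in-memory store.
+
+            >>> import dask.array as da
+            >>> import numpy as np
+            >>> maps = np.zeros((20, 20, 1), dtype=np.uint8)
+            >>> maps[10, 10, 0] = 1
+            >>> store = NucleusDetector.write_centroid_maps_to_store(
+            ...     da.from_array(maps, chunks=(20, 20, 1)),
+            ...     class_dict={0: "nucleus"},
+            ... )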
+ """ + recs_delayed = ( # Convert each block to detection records + detection_maps.map_blocks( + NucleusDetector._centroid_maps_to_detection_records, + dtype=object, # we return Python tuples + block_info=True, + ) + .to_delayed() + .ravel() + ) + + # create annotation store + store = SQLiteStore() + + # one delayed writer per chunk (returns number of detections written) + writes = [ + dask.delayed(NucleusDetector._write_detection_records_to_store)( + recs, store, scale_factor, class_dict, batch_size + ) + for recs in recs_delayed + ] + + # IMPORTANT: SQLite is single-writer; run sequentially + with ProgressBar(): + total = dask.compute(*writes, scheduler="single-threaded") + logger.info("Total detections written to store: %s", sum(total)) + + # if a save directory is provided, then dump store into a file + if save_path: + save_path.parent.absolute().mkdir(parents=True, exist_ok=True) + save_path = save_path.parent.absolute() / (save_path.stem + ".db") + store.commit() + store.dump(save_path) + return save_path + + return store