Changes from all commits (29 commits)
767441d  copy code from old PR (Jiaqi-Lv, Nov 6, 2025)
d2a9702  preliminiary testing (Jiaqi-Lv, Nov 7, 2025)
44c4994  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 7, 2025)
0f8d4fe  initial prototype (Jiaqi-Lv, Nov 10, 2025)
d42b78a  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 10, 2025)
b7f829c  clean up (Jiaqi-Lv, Nov 10, 2025)
cba5fd5  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 10, 2025)
f2cdcc4  update (Jiaqi-Lv, Nov 11, 2025)
dd99d97  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 11, 2025)
c468a8b  update pipeline (Jiaqi-Lv, Nov 12, 2025)
14f870a  update pipeline (Jiaqi-Lv, Nov 12, 2025)
6e65fba  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 12, 2025)
7eb916e  refactor code (Jiaqi-Lv, Nov 12, 2025)
8442ac2  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 12, 2025)
de83074  clean up (Jiaqi-Lv, Nov 12, 2025)
17e5422  Merge branch 'dev-define-engines-abc' into dev-define-nucleus-detecti… (shaneahmed, Nov 17, 2025)
f5b1885  update patch mode processing (Jiaqi-Lv, Nov 22, 2025)
551e43c  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 22, 2025)
12f985a  tidy up code (Jiaqi-Lv, Nov 22, 2025)
05b2c7d  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 22, 2025)
2afbf8c  fix precommit (Jiaqi-Lv, Nov 23, 2025)
367295d  update test (Jiaqi-Lv, Nov 23, 2025)
7912abe  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 23, 2025)
0a72e8b  improve tests (Jiaqi-Lv, Nov 24, 2025)
f8b4189  improve tests (Jiaqi-Lv, Nov 24, 2025)
228731c  precommit (Jiaqi-Lv, Nov 24, 2025)
6c26a0f  fix deepsource (Jiaqi-Lv, Nov 25, 2025)
7ffea5b  fix deepsource (Jiaqi-Lv, Nov 27, 2025)
79fc088  Merge branch 'dev-define-engines-abc' into dev-define-nucleus-detecti… (shaneahmed, Dec 1, 2025)
8 changes: 4 additions & 4 deletions docs/pretrained.rst
@@ -353,7 +353,7 @@ The input output configuration is as follows:
 ioconfig = IOPatchPredictorConfig(
     patch_input_shape=(31, 31),
     stride_shape=(8, 8),
-    input_resolutions=[{"resolution": 0.25, "units": "mpp"}]
+    input_resolutions=[{"resolution": 0.5, "units": "mpp"}]
 )


@@ -369,7 +369,7 @@ The input output configuration is as follows:
 ioconfig = IOPatchPredictorConfig(
     patch_input_shape=(252, 252),
     stride_shape=(150, 150),
-    input_resolutions=[{"resolution": 0.25, "units": "mpp"}]
+    input_resolutions=[{"resolution": 0.5, "units": "mpp"}]
 )


@@ -393,7 +393,7 @@ The input output configuration is as follows:
 ioconfig = IOPatchPredictorConfig(
     patch_input_shape=(31, 31),
     stride_shape=(8, 8),
-    input_resolutions=[{"resolution": 0.25, "units": "mpp"}]
+    input_resolutions=[{"resolution": 0.5, "units": "mpp"}]
 )


@@ -409,7 +409,7 @@ The input output configuration is as follows:
 ioconfig = IOPatchPredictorConfig(
     patch_input_shape=(252, 252),
     stride_shape=(150, 150),
-    input_resolutions=[{"resolution": 0.25, "units": "mpp"}]
+    input_resolutions=[{"resolution": 0.5, "units": "mpp"}]
 )


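The 0.5 mpp configurations above belong to the pretrained nucleus detection models that the new NucleusDetector engine in this PR consumes. Below is a minimal usage sketch mirroring the arguments exercised by tests/engines/test_nucleus_detection_engine.py later in this diff; the input path and output directory are placeholders, and further run options may exist.

from tiatoolbox.models.engine.nucleus_detector import NucleusDetector

# Detect nuclei in a whole-slide image with a pretrained MapDe model and
# save the detections as an AnnotationStore (.db) in the output directory.
detector = NucleusDetector(model="mapde-conic")
_ = detector.run(
    images=["sample_wsi.svs"],  # placeholder input path
    patch_mode=False,  # treat inputs as whole-slide images
    device="cpu",
    output_type="annotationstore",
    memory_threshold=50,
    save_dir="nucleus_detection_output",  # placeholder output directory
    overwrite=True,
)

The engine writes one <image name>.db store per input, which the tests below open with SQLiteStore.open.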
225 changes: 225 additions & 0 deletions tests/engines/test_nucleus_detection_engine.py
@@ -0,0 +1,225 @@
"""Tests for NucleusDetector."""

import pathlib
import shutil
from collections.abc import Callable

import dask.array as da
import numpy as np
import pandas as pd
import pytest

from tiatoolbox.annotation.storage import SQLiteStore
from tiatoolbox.models.engine.nucleus_detector import NucleusDetector
from tiatoolbox.utils import env_detection as toolbox_env
from tiatoolbox.utils.misc import imwrite
from tiatoolbox.wsicore.wsireader import WSIReader

device = "cuda" if toolbox_env.has_gpu() else "cpu"


def _rm_dir(path: pathlib.Path) -> None:
"""Helper func to remove directory."""
if pathlib.Path(path).exists():
shutil.rmtree(path, ignore_errors=True)


def check_output(path: pathlib.Path) -> None:
"""Check NucleusDetector output."""


def test_nucleus_detection_nms_empty_dataframe() -> None:
"""nucleus_detection_nms should return a copy for empty inputs."""
df = pd.DataFrame(columns=["x", "y", "type", "prob"])

result = NucleusDetector.nucleus_detection_nms(df, radius=3)

assert result.empty
assert result is not df
assert list(result.columns) == ["x", "y", "type", "prob"]


def test_nucleus_detection_nms_invalid_radius() -> None:
    """Radius must be strictly positive."""
    df = pd.DataFrame({"x": [0], "y": [0], "type": [1], "prob": [0.9]})

    with pytest.raises(ValueError, match="radius must be > 0"):
        NucleusDetector.nucleus_detection_nms(df, radius=0)


def test_nucleus_detection_nms_invalid_overlap_threshold() -> None:
    """overlap_threshold must lie in (0, 1]."""
    df = pd.DataFrame({"x": [0], "y": [0], "type": [1], "prob": [0.9]})

    message = r"overlap_threshold must be in \(0\.0, 1\.0\], got 0"
    with pytest.raises(ValueError, match=message):
        NucleusDetector.nucleus_detection_nms(df, radius=1, overlap_threshold=0)


def test_nucleus_detection_nms_suppresses_overlapping_detections() -> None:
    """Lower-probability overlapping detections are removed."""
    df = pd.DataFrame(
        {
            "x": [2, 0, 20],
            "y": [1, 0, 20],
            "type": [1, 1, 2],
            "prob": [0.6, 0.9, 0.7],
        }
    )

    result = NucleusDetector.nucleus_detection_nms(df, radius=5)

    expected = pd.DataFrame(
        {"x": [0, 20], "y": [0, 20], "type": [1, 2], "prob": [0.9, 0.7]}
    )
    pd.testing.assert_frame_equal(result.reset_index(drop=True), expected)


def test_nucleus_detection_nms_suppresses_across_types() -> None:
    """Overlapping detections of different types are also suppressed."""
    df = pd.DataFrame(
        {
            "x": [0, 0, 20],
            "y": [0, 0, 0],
            "type": [1, 2, 1],
            "prob": [0.6, 0.95, 0.4],
        }
    )

    result = NucleusDetector.nucleus_detection_nms(df, radius=5)

    expected = pd.DataFrame(
        {"x": [0, 20], "y": [0, 0], "type": [2, 1], "prob": [0.95, 0.4]}
    )
    pd.testing.assert_frame_equal(result.reset_index(drop=True), expected)


def test_nucleus_detection_nms_retains_non_overlapping_candidates() -> None:
    """Detections with IoU below the threshold are preserved."""
    df = pd.DataFrame(
        {
            "x": [0, 10],
            "y": [0, 0],
            "type": [1, 1],
            "prob": [0.8, 0.5],
        }
    )

    result = NucleusDetector.nucleus_detection_nms(df, radius=5, overlap_threshold=0.5)

    expected = pd.DataFrame(
        {"x": [0, 10], "y": [0, 0], "type": [1, 1], "prob": [0.8, 0.5]}
    )
    pd.testing.assert_frame_equal(result.reset_index(drop=True), expected)


def test_nucleus_detector_wsi(remote_sample: Callable, tmp_path: pathlib.Path) -> None:
    """Test for nucleus detection engine."""
    mini_wsi_svs = pathlib.Path(remote_sample("wsi4_512_512_svs"))

    pretrained_model = "mapde-conic"

    save_dir = tmp_path

    nucleus_detector = NucleusDetector(model=pretrained_model)
    _ = nucleus_detector.run(
        patch_mode=False,
        device=device,
        output_type="annotationstore",
        memory_threshold=50,
        images=[mini_wsi_svs],
        save_dir=save_dir,
        overwrite=True,
    )

    store = SQLiteStore.open(save_dir / "wsi4_512_512.db")
    assert len(store.values()) == 281
    store.close()

    _rm_dir(save_dir)


def test_nucleus_detector_patch(
    remote_sample: Callable, tmp_path: pathlib.Path
) -> None:
    """Test for nucleus detection engine in patch mode."""
    mini_wsi_svs = pathlib.Path(remote_sample("wsi4_512_512_svs"))

    wsi_reader = WSIReader.open(mini_wsi_svs)
    patch_1 = wsi_reader.read_rect((0, 0), (252, 252), resolution=0.5, units="mpp")
    patch_2 = wsi_reader.read_rect((252, 252), (252, 252), resolution=0.5, units="mpp")

    pretrained_model = "mapde-conic"

    save_dir = tmp_path

    nucleus_detector = NucleusDetector(model=pretrained_model)
    _ = nucleus_detector.run(
        patch_mode=True,
        device=device,
        output_type="annotationstore",
        memory_threshold=50,
        images=[patch_1, patch_2],
        save_dir=save_dir,
        overwrite=True,
        class_dict=None,
    )

    store_1 = SQLiteStore.open(save_dir / "0.db")
    assert len(store_1.values()) == 270
    store_1.close()

    store_2 = SQLiteStore.open(save_dir / "1.db")
    assert len(store_2.values()) == 52
    store_2.close()

    imwrite(save_dir / "patch_0.png", patch_1)
    imwrite(save_dir / "patch_1.png", patch_2)
    _ = nucleus_detector.run(
        patch_mode=True,
        device=device,
        output_type="zarr",
        memory_threshold=50,
        images=[save_dir / "patch_0.png", save_dir / "patch_1.png"],
        save_dir=save_dir,
        overwrite=True,
    )

    store_1 = SQLiteStore.open(save_dir / "patch_0.db")
    assert len(store_1.values()) == 270
    store_1.close()

    store_2 = SQLiteStore.open(save_dir / "patch_1.db")
    assert len(store_2.values()) == 52
    store_2.close()

    _rm_dir(save_dir)


def test_nucleus_detector_write_centroid_maps(tmp_path: pathlib.Path) -> None:
    """Test for write_centroid_maps_to_store."""
    detection_maps = np.zeros((20, 20, 1), dtype=np.uint8)
    detection_maps = da.from_array(detection_maps, chunks=(20, 20, 1))

    store = NucleusDetector.write_centroid_maps_to_store(
        detection_maps=detection_maps, class_dict=None
    )
    assert len(store.values()) == 0
    store.close()

    detection_maps = np.zeros((20, 20, 1), dtype=np.uint8)
    detection_maps[10, 10, 0] = 1
    detection_maps = da.from_array(detection_maps, chunks=(20, 20, 1))
    _ = NucleusDetector.write_centroid_maps_to_store(
        detection_maps=detection_maps,
        save_path=tmp_path / "test.db",
        class_dict={0: "nucleus"},
    )
    store = SQLiteStore.open(tmp_path / "test.db")
    assert len(store.values()) == 1
    annotation = next(iter(store.values()))
    assert annotation.properties["type"] == "nucleus"
    assert annotation.geometry.centroid.x == 10.0
    assert annotation.geometry.centroid.y == 10.0
    store.close()
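The nucleus_detection_nms tests in the file above pin down the expected behaviour: invalid parameters raise ValueError, an empty frame comes back as a copy, and overlapping detections are suppressed in favour of the highest probability regardless of type. A minimal, self-contained sketch that reproduces this behaviour is shown below; it is not the toolbox implementation, and the greedy circle-IoU formulation and the default overlap_threshold of 0.5 are assumptions.

import numpy as np
import pandas as pd


def greedy_circle_nms(
    detections: pd.DataFrame, radius: float, overlap_threshold: float = 0.5
) -> pd.DataFrame:
    """Keep the highest-probability detection among overlapping circles."""
    if radius <= 0:
        msg = "radius must be > 0"
        raise ValueError(msg)
    if not 0.0 < overlap_threshold <= 1.0:
        msg = f"overlap_threshold must be in (0.0, 1.0], got {overlap_threshold}"
        raise ValueError(msg)
    if detections.empty:
        return detections.copy()

    # Consider candidates in order of decreasing probability.
    ordered = detections.sort_values("prob", ascending=False)
    xy = ordered[["x", "y"]].to_numpy(dtype=float)
    keep: list[int] = []
    for i in range(len(ordered)):
        suppressed = False
        for j in keep:
            dist = float(np.hypot(*(xy[i] - xy[j])))
            if dist >= 2 * radius:
                continue  # circles of radius `radius` do not overlap
            # Intersection-over-union of two equal-radius circles.
            inter = 2 * radius**2 * np.arccos(dist / (2 * radius))
            inter -= (dist / 2) * np.sqrt(4 * radius**2 - dist**2)
            union = 2 * np.pi * radius**2 - inter
            if inter / union > overlap_threshold:
                suppressed = True  # a stronger detection already covers this one
                break
        if not suppressed:
            keep.append(i)
    return ordered.iloc[keep]

With radius=5 this keeps (0, 0) over (2, 1) in the first suppression test and keeps both points at distance 10 in the retention test, matching the expected frames asserted above.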
31 changes: 30 additions & 1 deletion tests/models/test_arch_mapde.py
@@ -8,6 +8,7 @@

 from tiatoolbox.models import MapDe
 from tiatoolbox.models.architecture import fetch_pretrained_weights
+from tiatoolbox.models.engine.nucleus_detector import NucleusDetector
 from tiatoolbox.utils import env_detection as toolbox_env
 from tiatoolbox.utils.misc import select_device
 from tiatoolbox.wsicore.wsireader import WSIReader
@@ -48,7 +49,35 @@ def test_functionality(remote_sample: Callable) -> None:
     batch = torch.from_numpy(patch)[None]
     output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
     output = model.postproc(output[0])
-    assert np.all(output[0:2] == [[19, 171], [53, 89]])
+    xs, ys, _, _ = NucleusDetector._centroid_maps_to_detection_records(output, None)
+
+    np.testing.assert_array_equal(xs[0:2], np.array([242, 192]))
+    np.testing.assert_array_equal(ys[0:2], np.array([10, 13]))
+
+    patch = reader.read_bounds(
+        (0, 0, 252, 252),
+        resolution=0.50,
+        units="mpp",
+        coord_space="resolution",
+    )
+
+    model, weights_path = _load_mapde(name="mapde-conic")
+    patch = model.preproc(patch)
+    batch = torch.from_numpy(patch)[None]
+    output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
+    block_info = {
+        0: {
+            "array-location": [
+                [0, 1],
+                [0, 1],
+            ],  # dummy block to test no valid detections
+        }
+    }
+    output = model.postproc(output[0], block_info=block_info)
+    xs, ys, _, _ = NucleusDetector._centroid_maps_to_detection_records(output, None)
+    np.testing.assert_array_equal(xs, np.array([]))
+    np.testing.assert_array_equal(ys, np.array([]))

     Path(weights_path).unlink()


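In the updated tests above, model.postproc now returns a centroid/detection map rather than a coordinate list, and NucleusDetector._centroid_maps_to_detection_records flattens that map into per-detection arrays. A rough sketch of that conversion is given below, consistent with the assertions above and with the write_centroid_maps_to_store test earlier in this diff; the helper name, the (xs, ys, types, probs) return order and the use of pixel values as probabilities are assumptions.

from __future__ import annotations

import numpy as np


def centroid_map_to_records(
    detection_map: np.ndarray,  # (H, W, C) map; nonzero pixels mark centroids
    class_dict: dict[int, str] | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Flatten a centroid map into xs, ys, types and probs arrays."""
    rows, cols, channels = np.nonzero(detection_map)
    xs = cols.astype(float)  # column index -> x coordinate
    ys = rows.astype(float)  # row index -> y coordinate
    probs = detection_map[rows, cols, channels].astype(float)
    if class_dict is not None:
        # Map the channel index of each detection to its class label.
        types = np.array([class_dict[int(c)] for c in channels], dtype=object)
    else:
        types = channels.astype(int)
    return xs, ys, types, probs

With the single nonzero pixel at row 10, column 10, channel 0 and class_dict={0: "nucleus"} from the store test, this yields xs == [10.0], ys == [10.0] and types == ["nucleus"], matching the centroid and type assertions there.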
36 changes: 33 additions & 3 deletions tests/models/test_arch_sccnn.py
@@ -7,6 +7,7 @@

 from tiatoolbox.models import SCCNN
 from tiatoolbox.models.architecture import fetch_pretrained_weights
+from tiatoolbox.models.engine.nucleus_detector import NucleusDetector
 from tiatoolbox.utils import env_detection
 from tiatoolbox.utils.misc import select_device
 from tiatoolbox.wsicore.wsireader import WSIReader
@@ -48,13 +49,42 @@ def test_functionality(remote_sample: Callable) -> None:
         device=select_device(on_gpu=env_detection.has_gpu()),
     )
     output = model.postproc(output[0])
-    np.testing.assert_array_equal(output, np.array([[8, 7]]))
+    xs, ys, _, _ = NucleusDetector._centroid_maps_to_detection_records(output, None)
+
+    np.testing.assert_array_equal(xs, np.array([8]))
+    np.testing.assert_array_equal(ys, np.array([7]))

     model = _load_sccnn(name="sccnn-conic")
     output = model.infer_batch(
         model,
         batch,
         device=select_device(on_gpu=env_detection.has_gpu()),
     )
-    output = model.postproc(output[0])
-    np.testing.assert_array_equal(output, np.array([[7, 8]]))
+    block_info = {
+        0: {
+            "array-location": [[0, 31], [0, 31]],
+        }
+    }
+    output = model.postproc(output[0], block_info=block_info)
+    xs, ys, _, _ = NucleusDetector._centroid_maps_to_detection_records(output, None)
+    np.testing.assert_array_equal(xs, np.array([7]))
+    np.testing.assert_array_equal(ys, np.array([8]))
+
+    model = _load_sccnn(name="sccnn-conic")
+    output = model.infer_batch(
+        model,
+        batch,
+        device=select_device(on_gpu=env_detection.has_gpu()),
+    )
+    block_info = {
+        0: {
+            "array-location": [
+                [0, 1],
+                [0, 1],
+            ],  # dummy block to test no valid detections
+        }
+    }
+    output = model.postproc(output[0], block_info=block_info)
+    xs, ys, _, _ = NucleusDetector._centroid_maps_to_detection_records(output, None)
+    np.testing.assert_array_equal(xs, np.array([]))
+    np.testing.assert_array_equal(ys, np.array([]))
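Finally, the annotationstore output used throughout these tests stores each detection as a point annotation in a SQLiteStore. Below is a minimal sketch, assuming the record arrays from the previous sketch, of how such a store could be populated; the real NucleusDetector.write_centroid_maps_to_store may differ in how it builds, batches and commits annotations.

from __future__ import annotations

from pathlib import Path

import numpy as np
from shapely.geometry import Point

from tiatoolbox.annotation.storage import Annotation, SQLiteStore


def detections_to_store(
    xs: np.ndarray,
    ys: np.ndarray,
    types: np.ndarray,
    probs: np.ndarray,
    save_path: Path | None = None,
) -> SQLiteStore:
    """Append one point annotation per detection to a SQLiteStore."""
    # A file-backed store when save_path is given, otherwise in-memory.
    store = SQLiteStore(save_path) if save_path is not None else SQLiteStore()
    annotations = [
        # `types` is assumed to hold JSON-serialisable labels (e.g. "nucleus").
        Annotation(Point(x, y), {"type": t, "prob": float(p)})
        for x, y, t, p in zip(xs, ys, types, probs)
    ]
    store.append_many(annotations)
    store.commit()  # persist file-backed stores to disk
    return store

Opening the resulting .db with SQLiteStore.open, as the tests do, then yields annotations whose geometry.centroid and properties["type"] match the original detections.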