Commit e24b9e8

Jiaqi-Lv, shaneahmed, and pre-commit-ci[bot] authored
🆕 Add GrandQC Tissue Segmentation Model (#965)
## 🚀 Summary

This PR introduces a new **[GrandQC Tissue Detection Model](https://github.com/cpath-ukk/grandqc/tree/main)** for digital pathology quality control and integrates an **EfficientNet-based encoder architecture** into the TIAToolbox framework.

---

## ✨ Key Changes

- **New Model Architecture**
  - Added `grandqc.py` implementing a UNet++ decoder with an EfficientNet encoder for tissue segmentation.
  - Includes preprocessing (JPEG compression + ImageNet normalization), postprocessing (argmin-based mask generation), and batch inference utilities.
- **EfficientNet Encoder**
  - Added `timm_efficientnet.py` providing configurable EfficientNet encoders with dilation support and custom input channels.
- **Pretrained Model Config**
  - Updated `pretrained_model.yaml` to register `grandqc_tissue_detection_mpp10` with its associated IO configuration.
  - Corrected `IOSegmentorConfig` references and adjusted resolutions for SCCNN models.
- **Testing**
  - Added comprehensive unit tests for:
    - `GrandQCModel` functionality, preprocessing/postprocessing, and decoder blocks.
    - EfficientNet encoder utilities and scaling logic.

## Impact

- Enables high-resolution tissue detection for WSI quality control using state-of-the-art architectures.
- Improves flexibility for segmentation tasks with EfficientNet encoders.
- Enhances code quality and consistency through updated linting and formatting tools.

## Tasks

- [x] Re-host GrandQC model weights on TIA Hugging Face
- [x] Update `pretrained_model.yaml`
- [x] Update `requirements.txt`
- [x] Define GrandQC model architecture
- [x] Add example usage
- [x] Remove segmentation-models-pytorch dependency
- [x] Wait for response from GrandQC authors
- [x] Add tests
- [x] Tidy up

---------

Co-authored-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
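As a quick orientation, the sketch below shows how the new model can be loaded and run on a batch of patches. It is a minimal example assembled from the calls exercised in `tests/models/test_arch_grandqc.py` in this commit, not separate documentation: the pretrained name `grandqc_tissue_detection`, the NHWC `(N, 512, 512, 3)` uint8 batch layout, and the `(N, 512, 512, 2)` output shape are all taken from those tests, and the random batch here is purely illustrative (in practice patches come from a WSI reader).

```python
import numpy as np
import torch

from tiatoolbox.models.architecture import get_pretrained_model
from tiatoolbox.utils import env_detection as toolbox_env

device = "cuda" if toolbox_env.has_gpu() else "cpu"

# Load the registered GrandQC tissue-detection model and its IO config.
model, ioconfig = get_pretrained_model("grandqc_tissue_detection")

# Dummy batch of two 512 x 512 RGB patches (NHWC, uint8), mirroring the
# batch shape used in the unit tests; real patches would be read from a WSI.
rng = np.random.default_rng(0)
batch = torch.from_numpy(
    rng.integers(0, 256, size=(2, 512, 512, 3), dtype=np.uint8),
)

# infer_batch is called with raw uint8 patches, exactly as in the tests,
# and returns per-pixel class scores of shape (2, 512, 512, 2).
scores = model.infer_batch(model, batch, device=device)

# postproc collapses the two channel scores into a per-pixel label mask
# (argmin-based mask generation per the summary above); shape (512, 512).
mask = model.postproc(scores[0])
```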
1 parent c535eab commit e24b9e8

File tree

5 files changed: +1831 -6 lines changed


tests/models/test_arch_grandqc.py

Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@
"""Unit test package for GrandQC Tissue Model."""

from collections.abc import Callable
from pathlib import Path

import numpy as np
import pytest
import torch
from torch import nn

from tiatoolbox.annotation.storage import SQLiteStore
from tiatoolbox.models.architecture import (
    fetch_pretrained_weights,
    get_pretrained_model,
)
from tiatoolbox.models.architecture.grandqc import (
    CenterBlock,
    GrandQCModel,
    SegmentationHead,
    UnetPlusPlusDecoder,
)
from tiatoolbox.models.engine.io_config import IOSegmentorConfig
from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor
from tiatoolbox.utils import env_detection as toolbox_env
from tiatoolbox.wsicore.wsireader import VirtualWSIReader

device = "cuda" if toolbox_env.has_gpu() else "cpu"


def test_functional_grandqc() -> None:
    """Test for GrandQC model."""
    # test fetch pretrained weights
    pretrained_weights = fetch_pretrained_weights("grandqc_tissue_detection")
    assert pretrained_weights is not None

    # test creation
    model = GrandQCModel(num_output_channels=2)
    assert model is not None

    # load pretrained weights
    pretrained = torch.load(pretrained_weights, map_location=device)
    model.load_state_dict(pretrained)

    # test get pretrained model
    model, ioconfig = get_pretrained_model("grandqc_tissue_detection")
    assert isinstance(model, GrandQCModel)
    assert isinstance(ioconfig, IOSegmentorConfig)
    assert model.num_output_channels == 2
    assert model.decoder_channels == (256, 128, 64, 32, 16)

    # test inference
    generator = np.random.default_rng(1337)
    test_image = generator.integers(0, 256, size=(2048, 2048, 3), dtype=np.uint8)
    reader = VirtualWSIReader.open(test_image)
    read_kwargs = {"resolution": 0, "units": "level", "coord_space": "resolution"}
    batch = np.array(
        [
            reader.read_bounds((0, 0, 512, 512), **read_kwargs),
            reader.read_bounds((512, 512, 1024, 1024), **read_kwargs),
        ],
    )
    batch = torch.from_numpy(batch)
    output = model.infer_batch(model, batch, device=device)
    assert output.shape == (2, 512, 512, 2)


def test_grandqc_preproc_postproc() -> None:
    """Test GrandQC preproc and postproc functions."""
    model = GrandQCModel(num_output_channels=2)

    generator = np.random.default_rng(1337)
    # test preproc
    dummy_image = generator.integers(0, 256, size=(512, 512, 3), dtype=np.uint8)
    preproc_image = model.preproc(dummy_image)
    assert preproc_image.shape == dummy_image.shape
    assert preproc_image.dtype == np.float64

    # test postproc
    dummy_output = generator.random(size=(512, 512, 2), dtype=np.float32)
    postproc_image = model.postproc(dummy_output)
    assert postproc_image.shape == (512, 512)
    assert postproc_image.dtype == np.int64


def test_grandqc_with_semantic_segmentor(
    remote_sample: Callable, track_tmp_path: Path
) -> None:
    """Test GrandQC tissue mask generation."""
    segmentor = SemanticSegmentor(model="grandqc_tissue_detection")

    sample_image = remote_sample("svs-1-small")
    inputs = [str(sample_image)]

    output = segmentor.run(
        images=inputs,
        device=device,
        patch_mode=False,
        output_type="annotationstore",
        save_dir=track_tmp_path / "grandqc_test_outputs",
        overwrite=True,
    )

    assert len(output) == 1
    assert Path(output[sample_image]).exists()

    store = SQLiteStore.open(output[sample_image])
    assert len(store) == 3

    tissue_area_px = 0.0
    for annotation in store.values():
        assert annotation.properties["type"] == "mask"
        tissue_area_px += annotation.geometry.area
    assert 2999000 < tissue_area_px < 3004000

    store.close()


def test_segmentation_head_behaviour() -> None:
    """Verify SegmentationHead defaults and upsampling."""
    head = SegmentationHead(3, 5, activation=None, upsampling=1)
    assert isinstance(head[1], nn.Identity)
    assert isinstance(head[2], nn.Identity)

    x = torch.randn(1, 3, 6, 8)
    out = head(x)
    assert out.shape == (1, 5, 6, 8)

    head = SegmentationHead(3, 2, activation=nn.Sigmoid(), upsampling=2)
    x = torch.ones(1, 3, 4, 4)
    out = head(x)
    assert out.shape == (1, 2, 8, 8)
    assert torch.all(out >= 0)
    assert torch.all(out <= 1)


def test_unetplusplus_decoder_forward_shapes() -> None:
    """Ensure UnetPlusPlusDecoder handles dense connections."""
    decoder = UnetPlusPlusDecoder(
        encoder_channels=[1, 2, 4, 8],
        decoder_channels=[8, 4, 2],
        n_blocks=3,
    )

    features = [
        torch.randn(1, 1, 32, 32),
        torch.randn(1, 2, 16, 16),
        torch.randn(1, 4, 8, 8),
        torch.randn(1, 8, 4, 4),
    ]

    output = decoder(features)
    assert output.shape == (1, 2, 32, 32)


def test_center_block_behavior() -> None:
    """Test CenterBlock behavior in UnetPlusPlusDecoder."""
    center_block = CenterBlock(in_channels=8, out_channels=8)

    x = torch.randn(1, 8, 4, 4)
    out = center_block(x)
    assert out.shape == (1, 8, 4, 4)


def test_unetpp_raises_value_error() -> None:
    """Test UnetPlusPlusDecoder raises ValueError."""
    with pytest.raises(
        ValueError, match=r".*depth is 4, but you provide `decoder_channels` for 3.*"
    ):
        _ = UnetPlusPlusDecoder(
            encoder_channels=[1, 2, 4, 8],
            decoder_channels=[8, 4, 2],
            n_blocks=4,
        )
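For reference, the whole-slide workflow exercised by `test_grandqc_with_semantic_segmentor` above condenses to the sketch below. The slide path and output directory are placeholders; everything else (the `grandqc_tissue_detection` model name, `patch_mode=False`, the `annotationstore` output type, and reading the result back with `SQLiteStore`) follows the test directly.

```python
from pathlib import Path

from tiatoolbox.annotation.storage import SQLiteStore
from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor

# Placeholder paths for illustration only.
wsi_path = Path("sample.svs")
save_dir = Path("grandqc_outputs")

segmentor = SemanticSegmentor(model="grandqc_tissue_detection")
output = segmentor.run(
    images=[str(wsi_path)],
    device="cpu",  # or "cuda" when a GPU is available
    patch_mode=False,
    output_type="annotationstore",
    save_dir=save_dir,
    overwrite=True,
)

# run() returns a mapping from input path to the saved annotation store;
# each stored annotation of type "mask" is one detected tissue region.
store = SQLiteStore.open(output[wsi_path])
tissue_area_px = sum(ann.geometry.area for ann in store.values())
store.close()
```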

0 commit comments