Merged
26 commits:

b18b98f add grandqc tissue model (Jiaqi-Lv, Oct 25, 2025)
899d6cb add example (Jiaqi-Lv, Oct 25, 2025)
8a7295d fix tests (Jiaqi-Lv, Oct 25, 2025)
5c5bfc4 fix error (Jiaqi-Lv, Oct 25, 2025)
fd692da update docstring (Jiaqi-Lv, Oct 28, 2025)
d82cc3d improve test coverage (Jiaqi-Lv, Oct 28, 2025)
93a24a1 add unet++ model (Jiaqi-Lv, Nov 6, 2025)
2d076c0 Merge branch 'dev-define-engines-abc' into dev-add-grandQC (shaneahmed, Nov 17, 2025)
283b888 Merge branch 'dev-add-grandQC' of https://github.com/TissueImageAnaly… (Jiaqi-Lv, Nov 18, 2025)
94c43ee [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 18, 2025)
98cef83 remove smp dependency (Jiaqi-Lv, Nov 18, 2025)
d47fa0a refactor code (Jiaqi-Lv, Nov 18, 2025)
d2a66ca add tests (Jiaqi-Lv, Nov 21, 2025)
19cca90 address comments (Jiaqi-Lv, Nov 21, 2025)
1895e38 :memo: Update docstring for grandqc.py and timm_efficientnet.py (shaneahmed, Nov 25, 2025)
3ade99a :bug: Fix docstring (shaneahmed, Nov 25, 2025)
5f0202f :memo: Remove duplicate docstring for classses. (shaneahmed, Nov 25, 2025)
6b8eb90 address comments (Jiaqi-Lv, Nov 25, 2025)
2ce379f update test (Jiaqi-Lv, Nov 25, 2025)
9c62b72 :white_check_mark: Add test to improve coverage (shaneahmed, Nov 26, 2025)
d1ce4a0 improve test coverage (Jiaqi-Lv, Nov 27, 2025)
717b0ff improve test coverage (Jiaqi-Lv, Nov 27, 2025)
3cc5924 address comments (Jiaqi-Lv, Nov 28, 2025)
1ab6728 :fire: Remove unnecessary checks (shaneahmed, Dec 1, 2025)
3a27ed6 :bug: Fix incorrect input for bias (shaneahmed, Dec 1, 2025)
cc4499a :art: Improve structure of the code. (shaneahmed, Dec 1, 2025)
122 changes: 122 additions & 0 deletions tests/models/test_arch_grandqc.py
@@ -0,0 +1,122 @@
"""Unit test package for GrandQC Tissue Model."""

import numpy as np
import torch
from torch import nn

from tiatoolbox.models.architecture import (
fetch_pretrained_weights,
get_pretrained_model,
)
from tiatoolbox.models.architecture.grandqc import (
CenterBlock,
GrandQCModel,
SegmentationHead,
UnetPlusPlusDecoder,
)
from tiatoolbox.models.engine.io_config import IOSegmentorConfig
from tiatoolbox.utils.misc import select_device
from tiatoolbox.wsicore.wsireader import VirtualWSIReader

ON_GPU = False


def test_functional_grandqc() -> None:
"""Test for GrandQC model."""
# test fetch pretrained weights
pretrained_weights = fetch_pretrained_weights("grandqc_tissue_detection_mpp10")
assert pretrained_weights is not None

# test creation
model = GrandQCModel(num_output_channels=2)
assert model is not None

# load pretrained weights
pretrained = torch.load(pretrained_weights, map_location="cpu")
model.load_state_dict(pretrained)

# test get pretrained model
model, ioconfig = get_pretrained_model("grandqc_tissue_detection_mpp10")
assert isinstance(model, GrandQCModel)
assert isinstance(ioconfig, IOSegmentorConfig)
assert model.num_output_channels == 2
assert model.decoder_channels == (256, 128, 64, 32, 16)

# test inference
generator = np.random.default_rng(1337)
test_image = generator.integers(0, 256, size=(2048, 2048, 3), dtype=np.uint8)
reader = VirtualWSIReader.open(test_image)
read_kwargs = {"resolution": 0, "units": "level", "coord_space": "resolution"}
batch = np.array(
[
reader.read_bounds((0, 0, 512, 512), **read_kwargs),
reader.read_bounds((512, 512, 1024, 1024), **read_kwargs),
],
)
batch = torch.from_numpy(batch)
output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
assert output.shape == (2, 512, 512, 2)


def test_grandqc_preproc_postproc() -> None:
"""Test GrandQC preproc and postproc functions."""
model = GrandQCModel(num_output_channels=2)

generator = np.random.default_rng(1337)
# test preproc
dummy_image = generator.integers(0, 256, size=(512, 512, 3), dtype=np.uint8)
preproc_image = model.preproc(dummy_image)
assert preproc_image.shape == dummy_image.shape
assert preproc_image.dtype == np.float64

# test postproc
dummy_output = generator.random(size=(512, 512, 2), dtype=np.float32)
postproc_image = model.postproc(dummy_output)
assert postproc_image.shape == (512, 512)
assert postproc_image.dtype == np.int64


def test_segmentation_head_behaviour() -> None:
"""Verify SegmentationHead defaults and upsampling."""
head = SegmentationHead(3, 5, activation=None, upsampling=1)
assert isinstance(head[1], nn.Identity)
assert isinstance(head[2], nn.Identity)

x = torch.randn(1, 3, 6, 8)
out = head(x)
assert out.shape == (1, 5, 6, 8)

head = SegmentationHead(3, 2, activation=nn.Sigmoid(), upsampling=2)
x = torch.ones(1, 3, 4, 4)
out = head(x)
assert out.shape == (1, 2, 8, 8)
assert torch.all(out >= 0)
assert torch.all(out <= 1)


def test_unetplusplus_decoder_forward_shapes() -> None:
"""Ensure UnetPlusPlusDecoder handles dense connections."""
decoder = UnetPlusPlusDecoder(
encoder_channels=[1, 2, 4, 8],
decoder_channels=[8, 4, 2],
n_blocks=3,
)

features = [
torch.randn(1, 1, 32, 32),
torch.randn(1, 2, 16, 16),
torch.randn(1, 4, 8, 8),
torch.randn(1, 8, 4, 4),
]

output = decoder(features)
assert output.shape == (1, 2, 32, 32)


def test_center_block_behavior() -> None:
"""Test CenterBlock behavior in UnetPlusPlusDecoder."""
center_block = CenterBlock(in_channels=8, out_channels=8)

x = torch.randn(1, 8, 4, 4)
out = center_block(x)
assert out.shape == (1, 8, 4, 4)
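
For reference, the calls exercised in this test file can be chained into a single patch-level prediction. This is a minimal sketch based only on the assertions above; it assumes, as the shape checks suggest, that infer_batch returns per-class probabilities which postproc reduces to a label mask:

import numpy as np
import torch

from tiatoolbox.models.architecture import get_pretrained_model

# Load the registered weights and IO configuration
# (see the pretrained_model.yaml entry further down this diff).
model, ioconfig = get_pretrained_model("grandqc_tissue_detection_mpp10")

# A single 512 x 512 RGB patch, matching patch_input_shape.
rng = np.random.default_rng(0)
patch = rng.integers(0, 256, size=(512, 512, 3), dtype=np.uint8)
batch = torch.from_numpy(patch[np.newaxis, ...])

# infer_batch is called as a static method, as in test_functional_grandqc.
probs = model.infer_batch(model, batch, device="cpu")  # expected shape (1, 512, 512, 2)
mask = model.postproc(np.asarray(probs)[0])            # expected shape (512, 512) label map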
177 changes: 177 additions & 0 deletions tests/models/test_arch_timm_efficientnet.py
@@ -0,0 +1,177 @@
"""Unit tests for timm EfficientNet encoder helpers."""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from collections.abc import Sequence

import pytest
import torch
from torch import nn

from tiatoolbox.models.architecture import timm_efficientnet as effnet_mod
from tiatoolbox.models.architecture.timm_efficientnet import (
DEFAULT_IN_CHANNELS,
EfficientNetEncoder,
EncoderMixin,
replace_strides_with_dilation,
)


class DummyEncoder(nn.Module, EncoderMixin):
"""Lightweight encoder for testing mixin behavior."""

def __init__(self) -> None:
"""Initialize EncoderMixin for testing."""
nn.Module.__init__(self)
EncoderMixin.__init__(self)
self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
self.conv32 = nn.Conv2d(4, 4, 3)
self._out_channels = [DEFAULT_IN_CHANNELS, 4, 8]
self._depth = 2

def get_stages(self) -> dict[int, Sequence[torch.nn.Module]]:
"""Get stages for dilation modification.

Returns:
Dictionary with keys as output stride and values as list of modules.
"""
return {16: [self.conv], 32: [self.conv32]}


def test_patch_first_conv() -> None:
"""patch_first_conv should reduce or expand correctly."""
# create simple conv
model = nn.Sequential(nn.Conv2d(3, 2, kernel_size=1, bias=False))
conv = model[0]

# collapsing 3 channels into 1
effnet_mod.patch_first_conv(model, new_in_channels=1, pretrained=True)
assert conv.in_channels == 1

# expanding to 5 channels
model = nn.Sequential(nn.Conv2d(3, 2, kernel_size=1, bias=False))
conv = model[0]

effnet_mod.patch_first_conv(model, new_in_channels=5, pretrained=True)
assert conv.in_channels == 5


def test_patch_first_conv_reset_weights_when_not_pretrained() -> None:
"""Ensure random reinit happens when pretrained flag is False."""
# start from known weights
model = nn.Sequential(nn.Conv2d(3, 1, kernel_size=1, bias=False))
original = model[0].weight.clone()
# changing channel count without pretrained should reinit parameters
effnet_mod.patch_first_conv(model, new_in_channels=4, pretrained=False)
assert model[0].in_channels == 4
assert model[0].weight.shape[1] == 4
# Almost surely changed due to reset_parameters
assert not torch.equal(original, model[0].weight[:1, :3])


def test_patch_first_conv_no_matching_layer_is_safe() -> None:
"""The function should silently exit when no suitable conv exists."""
model = nn.Sequential(nn.Conv2d(5, 1, kernel_size=1))
original = model[0].weight.clone()
# no conv with default channel count, so weights stay unchanged
effnet_mod.patch_first_conv(model, new_in_channels=3, pretrained=True)
assert torch.equal(original, model[0].weight)


def test_replace_strides_with_dilation_applies_to_nested_convs() -> None:
"""Strides become dilation and static padding gets removed."""
module = nn.Sequential(
nn.Conv2d(1, 1, kernel_size=3, stride=2, padding=1),
)
# attach static_padding to mirror EfficientNet convs
module[0].static_padding = nn.Conv2d(1, 1, 1)

# applying dilation should also strip static padding
replace_strides_with_dilation(module, dilation_rate=3)
conv = module[0]
assert conv.stride == (1, 1)
assert conv.dilation == (3, 3)
assert conv.padding == (3, 3)
assert isinstance(conv.static_padding, nn.Identity)


def test_encoder_mixin_properties_and_set_in_channels() -> None:
"""EncoderMixin should expose out_channels/output_stride and patch convs."""
# use dummy encoder to check property logic
encoder = DummyEncoder()
assert encoder.out_channels == [3, 4, 8]
# adjust internals to check min logic in output_stride
encoder._output_stride = 4
encoder._depth = 3
assert encoder.output_stride == 4 # min(output_stride, 2**depth)

# calling set_in_channels should patch first conv and update bookkeeping
encoder.set_in_channels(5, pretrained=False)
assert encoder._in_channels == 5
assert encoder.out_channels[0] == 5
assert encoder.conv.in_channels == 5


def test_encoder_mixin_make_dilated_and_validation() -> None:
"""make_dilated should error on invalid stride and patch convs otherwise."""
encoder = DummyEncoder()

# invalid stride raises
with pytest.raises(ValueError, match="Output stride should be 16 or 8"):
encoder.make_dilated(output_stride=4)

# valid stride should touch both stage groups
encoder.make_dilated(output_stride=8)
conv16, conv32 = encoder.get_stages()[16][0], encoder.get_stages()[32][0]
assert conv16.stride == (1, 1)
assert conv16.dilation == (2, 2)
assert conv32.stride == (1, 1)
assert conv32.dilation == (4, 4)


def test_get_efficientnet_kwargs_shapes_and_values() -> None:
"""get_efficientnet_kwargs should produce expected keys and scaling."""
# confirm output contains decoded blocks and scaled channels
kwargs = effnet_mod.get_efficientnet_kwargs(
channel_multiplier=1.2, depth_multiplier=1.4, drop_rate=0.3
)
assert kwargs.get("block_args")
assert kwargs["num_features"] == effnet_mod.round_channels(1280, 1.2, 8, None)
assert kwargs["drop_rate"] == 0.3


def test_efficientnet_encoder_depth_validation_and_forward() -> None:
"""EfficientNetEncoder should validate depth and run forward returning features."""
# invalid depth should fail fast
with pytest.raises(
ValueError, match=r"EfficientNetEncoder depth should be in range\s+\[1, 5\]"
):
EfficientNetEncoder(
stage_idxs=[2, 3, 5],
out_channels=[3, 32, 24, 40, 112, 320],
depth=6,
)

# build shallow encoder and run a forward pass
encoder = EfficientNetEncoder(
stage_idxs=[2, 3, 5],
out_channels=[3, 32, 24, 40, 112, 320],
depth=3,
channel_multiplier=0.5,
depth_multiplier=0.5,
)
x = torch.randn(1, 3, 32, 32)
features = encoder(x)
assert len(features) == encoder._depth + 1
assert torch.equal(features[0], x)

# ensure classifier keys are dropped before loading into the model
extended_state = dict(encoder.state_dict())
extended_state["classifier.bias"] = torch.tensor([1.0])
extended_state["classifier.weight"] = torch.tensor([[1.0]])
load_result = encoder.load_state_dict(extended_state, strict=True)
assert not load_result.missing_keys
assert not load_result.unexpected_keys
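
The tests above exercise the encoder, decoder, and segmentation head in isolation. The sketch below shows how a UNet++-style model could wire them together; the exact composition inside GrandQCModel is not part of this diff, so the wiring and channel choices (taken from the defaults asserted in test_functional_grandqc) are assumptions that mirror the usual segmentation-models-pytorch pattern:

import torch

from tiatoolbox.models.architecture.grandqc import SegmentationHead, UnetPlusPlusDecoder
from tiatoolbox.models.architecture.timm_efficientnet import EfficientNetEncoder

# EfficientNet-B0-like backbone: stage_idxs and out_channels as used in the encoder test.
encoder = EfficientNetEncoder(
    stage_idxs=[2, 3, 5],
    out_channels=[3, 32, 24, 40, 112, 320],
    depth=5,
)
# Decoder channels match GrandQCModel's asserted default (256, 128, 64, 32, 16).
decoder = UnetPlusPlusDecoder(
    encoder_channels=[3, 32, 24, 40, 112, 320],
    decoder_channels=[256, 128, 64, 32, 16],
    n_blocks=5,
)
head = SegmentationHead(16, 2, activation=None, upsampling=1)

x = torch.randn(1, 3, 512, 512)
features = encoder(x)             # depth + 1 feature maps, features[0] is x itself
logits = head(decoder(features))  # (1, 2, 512, 512) under the UNet++ wiring assumed here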
30 changes: 24 additions & 6 deletions tiatoolbox/data/pretrained_model.yaml
@@ -815,7 +815,7 @@ mapde-crchisto:
threshold_abs: 250
num_classes: 1
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
class: io_config.IOSegmentorConfig
kwargs:
input_resolutions:
- { "units": "mpp", "resolution": 0.5 }
@@ -837,7 +837,7 @@ mapde-conic:
threshold_abs: 205
num_classes: 1
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
class: io_config.IOSegmentorConfig
kwargs:
input_resolutions:
- { "units": "mpp", "resolution": 0.5 }
@@ -860,7 +860,7 @@ sccnn-crchisto:
threshold_abs: 0.20
patch_output_shape: [ 13, 13 ]
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
class: io_config.IOSegmentorConfig
kwargs:
input_resolutions:
- { "units": "mpp", "resolution": 0.25 }
@@ -883,7 +883,7 @@ sccnn-conic:
threshold_abs: 0.05
patch_output_shape: [ 13, 13 ]
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
class: io_config.IOSegmentorConfig
kwargs:
input_resolutions:
- { "units": "mpp", "resolution": 0.25 }
@@ -903,7 +903,7 @@ nuclick_original-pannuke:
num_input_channels: 5
num_output_channels: 1
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
class: io_config.IOSegmentorConfig
kwargs:
input_resolutions:
- {'units': 'baseline', 'resolution': 0.25}
@@ -925,7 +925,7 @@ nuclick_light-pannuke:
decoder_block: [3,3]
skip_type: "add"
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
class: io_config.IOSegmentorConfig
kwargs:
input_resolutions:
- {'units': 'baseline', 'resolution': 0.25}
@@ -934,3 +934,21 @@ nuclick_light-pannuke:
patch_input_shape: [128, 128]
patch_output_shape: [128, 128]
save_resolution: {'units': 'baseline', 'resolution': 1.0}

grandqc_tissue_detection_mpp10:
hf_repo_id: TIACentre/GrandQC_Tissue_Detection
architecture:
class: grandqc.GrandQCModel
kwargs:
num_output_channels: 2
ioconfig:
class: io_config.IOSegmentorConfig
kwargs:
input_resolutions:
- {'units': 'mpp', 'resolution': 10.0}
output_resolutions:
- {'units': 'mpp', 'resolution': 10.0}
patch_input_shape: [512, 512]
patch_output_shape: [512, 512]
stride_shape: [256, 256]
save_resolution: {'units': 'mpp', 'resolution': 10.0}
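
The same IO configuration can also be built directly in Python. A minimal sketch mirroring the kwargs above, assuming (as the ioconfig assertion in test_functional_grandqc suggests) that the YAML loader forwards these kwargs to io_config.IOSegmentorConfig:

from tiatoolbox.models.engine.io_config import IOSegmentorConfig

# Values copied verbatim from the grandqc_tissue_detection_mpp10 entry above.
ioconfig = IOSegmentorConfig(
    input_resolutions=[{"units": "mpp", "resolution": 10.0}],
    output_resolutions=[{"units": "mpp", "resolution": 10.0}],
    patch_input_shape=[512, 512],
    patch_output_shape=[512, 512],
    stride_shape=[256, 256],
    save_resolution={"units": "mpp", "resolution": 10.0},
)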