Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
b18b98f
add grandqc tissue model
Jiaqi-Lv Oct 25, 2025
899d6cb
add example
Jiaqi-Lv Oct 25, 2025
8a7295d
fix tests
Jiaqi-Lv Oct 25, 2025
5c5bfc4
fix error
Jiaqi-Lv Oct 25, 2025
fd692da
update docstring
Jiaqi-Lv Oct 28, 2025
d82cc3d
improve test coverage
Jiaqi-Lv Oct 28, 2025
93a24a1
add unet++ model
Jiaqi-Lv Nov 6, 2025
2d076c0
Merge branch 'dev-define-engines-abc' into dev-add-grandQC
shaneahmed Nov 17, 2025
283b888
Merge branch 'dev-add-grandQC' of https://github.com/TissueImageAnaly…
Jiaqi-Lv Nov 18, 2025
94c43ee
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2025
98cef83
remove smp dependency
Jiaqi-Lv Nov 18, 2025
d47fa0a
refactor code
Jiaqi-Lv Nov 18, 2025
d2a66ca
add tests
Jiaqi-Lv Nov 21, 2025
19cca90
address comments
Jiaqi-Lv Nov 21, 2025
1895e38
:memo: Update docstring for grandqc.py and timm_efficientnet.py
shaneahmed Nov 25, 2025
3ade99a
:bug: Fix docstring
shaneahmed Nov 25, 2025
5f0202f
:memo: Remove duplicate docstring for classses.
shaneahmed Nov 25, 2025
6b8eb90
address comments
Jiaqi-Lv Nov 25, 2025
2ce379f
update test
Jiaqi-Lv Nov 25, 2025
9c62b72
:white_check_mark: Add test to improve coverage
shaneahmed Nov 26, 2025
d1ce4a0
improve test coverage
Jiaqi-Lv Nov 27, 2025
717b0ff
improve test coverage
Jiaqi-Lv Nov 27, 2025
3cc5924
address comments
Jiaqi-Lv Nov 28, 2025
1ab6728
:fire: Remove unnecessary checks
shaneahmed Dec 1, 2025
3a27ed6
:bug: Fix incorrect input for bias
shaneahmed Dec 1, 2025
cc4499a
:art: Improve structure of the code.
shaneahmed Dec 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 173 additions & 0 deletions tests/models/test_arch_grandqc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
"""Unit test package for GrandQC Tissue Model."""

from collections.abc import Callable
from pathlib import Path

import numpy as np
import pytest
import torch
from torch import nn

from tiatoolbox.annotation.storage import SQLiteStore
from tiatoolbox.models.architecture import (
fetch_pretrained_weights,
get_pretrained_model,
)
from tiatoolbox.models.architecture.grandqc import (
CenterBlock,
GrandQCModel,
SegmentationHead,
UnetPlusPlusDecoder,
)
from tiatoolbox.models.engine.io_config import IOSegmentorConfig
from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor
from tiatoolbox.utils import env_detection as toolbox_env
from tiatoolbox.wsicore.wsireader import VirtualWSIReader

device = "cuda" if toolbox_env.has_gpu() else "cpu"


def test_functional_grandqc() -> None:
"""Test for GrandQC model."""
# test fetch pretrained weights
pretrained_weights = fetch_pretrained_weights("grandqc_tissue_detection")
assert pretrained_weights is not None

# test creation
model = GrandQCModel(num_output_channels=2)
assert model is not None

# load pretrained weights
pretrained = torch.load(pretrained_weights, map_location=device)
model.load_state_dict(pretrained)

# test get pretrained model
model, ioconfig = get_pretrained_model("grandqc_tissue_detection")
assert isinstance(model, GrandQCModel)
assert isinstance(ioconfig, IOSegmentorConfig)
assert model.num_output_channels == 2
assert model.decoder_channels == (256, 128, 64, 32, 16)

# test inference
generator = np.random.default_rng(1337)
test_image = generator.integers(0, 256, size=(2048, 2048, 3), dtype=np.uint8)
reader = VirtualWSIReader.open(test_image)
read_kwargs = {"resolution": 0, "units": "level", "coord_space": "resolution"}
batch = np.array(
[
reader.read_bounds((0, 0, 512, 512), **read_kwargs),
reader.read_bounds((512, 512, 1024, 1024), **read_kwargs),
],
)
batch = torch.from_numpy(batch)
output = model.infer_batch(model, batch, device=device)
assert output.shape == (2, 512, 512, 2)


def test_grandqc_preproc_postproc() -> None:
"""Test GrandQC preproc and postproc functions."""
model = GrandQCModel(num_output_channels=2)

generator = np.random.default_rng(1337)
# test preproc
dummy_image = generator.integers(0, 256, size=(512, 512, 3), dtype=np.uint8)
preproc_image = model.preproc(dummy_image)
assert preproc_image.shape == dummy_image.shape
assert preproc_image.dtype == np.float64

# test postproc
dummy_output = generator.random(size=(512, 512, 2), dtype=np.float32)
postproc_image = model.postproc(dummy_output)
assert postproc_image.shape == (512, 512)
assert postproc_image.dtype == np.int64


def test_grandqc_with_semantic_segmentor(
remote_sample: Callable, track_tmp_path: Path
) -> None:
"""Test GrandQC tissue mask generation."""
segmentor = SemanticSegmentor(model="grandqc_tissue_detection")

sample_image = remote_sample("svs-1-small")
inputs = [str(sample_image)]

output = segmentor.run(
images=inputs,
device=device,
patch_mode=False,
output_type="annotationstore",
save_dir=track_tmp_path / "grandqc_test_outputs",
overwrite=True,
)

assert len(output) == 1
assert Path(output[sample_image]).exists()

store = SQLiteStore.open(output[sample_image])
assert len(store) == 3

tissue_area_px = 0.0
for annotation in store.values():
assert annotation.properties["type"] == "mask"
tissue_area_px += annotation.geometry.area
assert 2999000 < tissue_area_px < 3004000

store.close()


def test_segmentation_head_behaviour() -> None:
"""Verify SegmentationHead defaults and upsampling."""
head = SegmentationHead(3, 5, activation=None, upsampling=1)
assert isinstance(head[1], nn.Identity)
assert isinstance(head[2], nn.Identity)

x = torch.randn(1, 3, 6, 8)
out = head(x)
assert out.shape == (1, 5, 6, 8)

head = SegmentationHead(3, 2, activation=nn.Sigmoid(), upsampling=2)
x = torch.ones(1, 3, 4, 4)
out = head(x)
assert out.shape == (1, 2, 8, 8)
assert torch.all(out >= 0)
assert torch.all(out <= 1)


def test_unetplusplus_decoder_forward_shapes() -> None:
"""Ensure UnetPlusPlusDecoder handles dense connections."""
decoder = UnetPlusPlusDecoder(
encoder_channels=[1, 2, 4, 8],
decoder_channels=[8, 4, 2],
n_blocks=3,
)

features = [
torch.randn(1, 1, 32, 32),
torch.randn(1, 2, 16, 16),
torch.randn(1, 4, 8, 8),
torch.randn(1, 8, 4, 4),
]

output = decoder(features)
assert output.shape == (1, 2, 32, 32)


def test_center_block_behavior() -> None:
"""Test CenterBlock behavior in UnetPlusPlusDecoder."""
center_block = CenterBlock(in_channels=8, out_channels=8)

x = torch.randn(1, 8, 4, 4)
out = center_block(x)
assert out.shape == (1, 8, 4, 4)


def test_unetpp_raises_value_error() -> None:
"""Test UnetPlusPlusDecoder raises ValueError."""
with pytest.raises(
ValueError, match=r".*depth is 4, but you provide `decoder_channels` for 3.*"
):
_ = UnetPlusPlusDecoder(
encoder_channels=[1, 2, 4, 8],
decoder_channels=[8, 4, 2],
n_blocks=4,
)
Loading