diff --git a/tests/models/test_arch_grandqc.py b/tests/models/test_arch_grandqc.py new file mode 100644 index 000000000..3753e9335 --- /dev/null +++ b/tests/models/test_arch_grandqc.py @@ -0,0 +1,173 @@ +"""Unit test package for GrandQC Tissue Model.""" + +from collections.abc import Callable +from pathlib import Path + +import numpy as np +import pytest +import torch +from torch import nn + +from tiatoolbox.annotation.storage import SQLiteStore +from tiatoolbox.models.architecture import ( + fetch_pretrained_weights, + get_pretrained_model, +) +from tiatoolbox.models.architecture.grandqc import ( + CenterBlock, + GrandQCModel, + SegmentationHead, + UnetPlusPlusDecoder, +) +from tiatoolbox.models.engine.io_config import IOSegmentorConfig +from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor +from tiatoolbox.utils import env_detection as toolbox_env +from tiatoolbox.wsicore.wsireader import VirtualWSIReader + +device = "cuda" if toolbox_env.has_gpu() else "cpu" + + +def test_functional_grandqc() -> None: + """Test for GrandQC model.""" + # test fetch pretrained weights + pretrained_weights = fetch_pretrained_weights("grandqc_tissue_detection") + assert pretrained_weights is not None + + # test creation + model = GrandQCModel(num_output_channels=2) + assert model is not None + + # load pretrained weights + pretrained = torch.load(pretrained_weights, map_location=device) + model.load_state_dict(pretrained) + + # test get pretrained model + model, ioconfig = get_pretrained_model("grandqc_tissue_detection") + assert isinstance(model, GrandQCModel) + assert isinstance(ioconfig, IOSegmentorConfig) + assert model.num_output_channels == 2 + assert model.decoder_channels == (256, 128, 64, 32, 16) + + # test inference + generator = np.random.default_rng(1337) + test_image = generator.integers(0, 256, size=(2048, 2048, 3), dtype=np.uint8) + reader = VirtualWSIReader.open(test_image) + read_kwargs = {"resolution": 0, "units": "level", "coord_space": "resolution"} + batch = np.array( + [ + reader.read_bounds((0, 0, 512, 512), **read_kwargs), + reader.read_bounds((512, 512, 1024, 1024), **read_kwargs), + ], + ) + batch = torch.from_numpy(batch) + output = model.infer_batch(model, batch, device=device) + assert output.shape == (2, 512, 512, 2) + + +def test_grandqc_preproc_postproc() -> None: + """Test GrandQC preproc and postproc functions.""" + model = GrandQCModel(num_output_channels=2) + + generator = np.random.default_rng(1337) + # test preproc + dummy_image = generator.integers(0, 256, size=(512, 512, 3), dtype=np.uint8) + preproc_image = model.preproc(dummy_image) + assert preproc_image.shape == dummy_image.shape + assert preproc_image.dtype == np.float64 + + # test postproc + dummy_output = generator.random(size=(512, 512, 2), dtype=np.float32) + postproc_image = model.postproc(dummy_output) + assert postproc_image.shape == (512, 512) + assert postproc_image.dtype == np.int64 + + +def test_grandqc_with_semantic_segmentor( + remote_sample: Callable, track_tmp_path: Path +) -> None: + """Test GrandQC tissue mask generation.""" + segmentor = SemanticSegmentor(model="grandqc_tissue_detection") + + sample_image = remote_sample("svs-1-small") + inputs = [str(sample_image)] + + output = segmentor.run( + images=inputs, + device=device, + patch_mode=False, + output_type="annotationstore", + save_dir=track_tmp_path / "grandqc_test_outputs", + overwrite=True, + ) + + assert len(output) == 1 + assert Path(output[sample_image]).exists() + + store = SQLiteStore.open(output[sample_image]) + 
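+    # The pretrained tissue detector is expected to yield three mask annotations
+    # for this sample slide; the area bounds below are an empirical tolerance.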
assert len(store) == 3 + + tissue_area_px = 0.0 + for annotation in store.values(): + assert annotation.properties["type"] == "mask" + tissue_area_px += annotation.geometry.area + assert 2999000 < tissue_area_px < 3004000 + + store.close() + + +def test_segmentation_head_behaviour() -> None: + """Verify SegmentationHead defaults and upsampling.""" + head = SegmentationHead(3, 5, activation=None, upsampling=1) + assert isinstance(head[1], nn.Identity) + assert isinstance(head[2], nn.Identity) + + x = torch.randn(1, 3, 6, 8) + out = head(x) + assert out.shape == (1, 5, 6, 8) + + head = SegmentationHead(3, 2, activation=nn.Sigmoid(), upsampling=2) + x = torch.ones(1, 3, 4, 4) + out = head(x) + assert out.shape == (1, 2, 8, 8) + assert torch.all(out >= 0) + assert torch.all(out <= 1) + + +def test_unetplusplus_decoder_forward_shapes() -> None: + """Ensure UnetPlusPlusDecoder handles dense connections.""" + decoder = UnetPlusPlusDecoder( + encoder_channels=[1, 2, 4, 8], + decoder_channels=[8, 4, 2], + n_blocks=3, + ) + + features = [ + torch.randn(1, 1, 32, 32), + torch.randn(1, 2, 16, 16), + torch.randn(1, 4, 8, 8), + torch.randn(1, 8, 4, 4), + ] + + output = decoder(features) + assert output.shape == (1, 2, 32, 32) + + +def test_center_block_behavior() -> None: + """Test CenterBlock behavior in UnetPlusPlusDecoder.""" + center_block = CenterBlock(in_channels=8, out_channels=8) + + x = torch.randn(1, 8, 4, 4) + out = center_block(x) + assert out.shape == (1, 8, 4, 4) + + +def test_unetpp_raises_value_error() -> None: + """Test UnetPlusPlusDecoder raises ValueError.""" + with pytest.raises( + ValueError, match=r".*depth is 4, but you provide `decoder_channels` for 3.*" + ): + _ = UnetPlusPlusDecoder( + encoder_channels=[1, 2, 4, 8], + decoder_channels=[8, 4, 2], + n_blocks=4, + ) diff --git a/tests/models/test_arch_timm_efficientnet.py b/tests/models/test_arch_timm_efficientnet.py new file mode 100644 index 000000000..147e4acab --- /dev/null +++ b/tests/models/test_arch_timm_efficientnet.py @@ -0,0 +1,282 @@ +"""Unit tests for timm EfficientNet encoder helpers.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Sequence + +import pytest +import torch +from torch import nn + +from tiatoolbox.models.architecture import timm_efficientnet as effnet_mod +from tiatoolbox.models.architecture.timm_efficientnet import ( + DEFAULT_IN_CHANNELS, + EfficientNetEncoder, + EncoderMixin, + replace_strides_with_dilation, +) + + +class DummyEncoder(nn.Module, EncoderMixin): + """Lightweight encoder for testing mixin behavior.""" + + def __init__(self) -> None: + """Initialize EncoderMixin for testing.""" + nn.Module.__init__(self) + EncoderMixin.__init__(self) + self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1) + self.conv32 = nn.Conv2d(4, 4, 3) + self._out_channels = [DEFAULT_IN_CHANNELS, 4, 8] + self._depth = 2 + + def get_stages(self) -> dict[int, Sequence[torch.nn.Module]]: + """Get stages for dilation modification. + + Returns: + Dictionary with keys as output stride and values as list of modules. 
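+            For this dummy encoder the mapping is ``{16: [self.conv], 32: [self.conv32]}``.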
+ """ + return {16: [self.conv], 32: [self.conv32]} + + +def test_patch_first_conv() -> None: + """patch_first_conv should reduce or expand correctly.""" + # create simple conv + model = nn.Sequential(nn.Conv2d(3, 2, kernel_size=1, bias=False)) + conv = model[0] + + # collapsing 3 channels into 1 + effnet_mod.patch_first_conv(model, new_in_channels=1, pretrained=True) + assert conv.in_channels == 1 + + # expanding to 5 channels + model = nn.Sequential(nn.Conv2d(3, 2, kernel_size=1, bias=False)) + conv = model[0] + + effnet_mod.patch_first_conv(model, new_in_channels=5, pretrained=True) + assert conv.in_channels == 5 + + +def test_patch_first_conv_reset_weights_when_not_pretrained() -> None: + """Ensure random reinit happens when pretrained flag is False.""" + # start from known weights + model = nn.Sequential(nn.Conv2d(3, 1, kernel_size=1, bias=False)) + original = model[0].weight.clone() + # changing channel count without pretrained should reinit parameters + effnet_mod.patch_first_conv(model, new_in_channels=4, pretrained=False) + assert model[0].in_channels == 4 + assert model[0].weight.shape[1] == 4 + # Almost surely changed due to reset_parameters + assert not torch.equal(original, model[0].weight[:1, :3]) + + +def test_patch_first_conv_no_matching_layer_is_safe() -> None: + """The function should silently exit when no suitable conv exists.""" + model = nn.Sequential(nn.Conv2d(5, 1, kernel_size=1)) + original = model[0].weight.clone() + # no conv with default channel count, so weights stay unchanged + effnet_mod.patch_first_conv(model, new_in_channels=3, pretrained=True) + assert torch.equal(original, model[0].weight) + + +def test_replace_strides_with_dilation_applies_to_nested_convs() -> None: + """Strides become dilation and static padding gets removed.""" + module = nn.Sequential( + nn.Conv2d(1, 1, kernel_size=3, stride=2, padding=1), + ) + # attach static_padding to mirror EfficientNet convs + module[0].static_padding = nn.Conv2d(1, 1, 1) + + # applying dilation should also strip static padding + replace_strides_with_dilation(module, dilation_rate=3) + conv = module[0] + assert conv.stride == (1, 1) + assert conv.dilation == (3, 3) + assert conv.padding == (3, 3) + assert isinstance(conv.static_padding, nn.Identity) + + +def test_encoder_mixin_properties_and_set_in_channels() -> None: + """EncoderMixin should expose out_channels/output_stride and patch convs.""" + # use dummy encoder to check property logic + encoder = DummyEncoder() + assert encoder.out_channels == [3, 4, 8] + # adjust internals to check min logic in output_stride + encoder._output_stride = 4 + encoder._depth = 3 + assert encoder.output_stride == 4 # min(output_stride, 2**depth) + + # calling set_in_channels should patch first conv and update bookkeeping + encoder.set_in_channels(5, pretrained=False) + assert encoder._in_channels == 5 + assert encoder.out_channels[0] == 5 + assert encoder.conv.in_channels == 5 + + +def test_set_in_channels_noop_for_default() -> None: + """Calling with DEFAULT_IN_CHANNELS should skip patching.""" + encoder = DummyEncoder() + encoder.set_in_channels(DEFAULT_IN_CHANNELS, pretrained=True) + assert encoder._in_channels == DEFAULT_IN_CHANNELS + + +def test_set_in_channels_modify_out_channels() -> None: + """First output channels should change when in_channels is modified.""" + encoder = DummyEncoder() + encoder._out_channels[0] = DEFAULT_IN_CHANNELS + + encoder.set_in_channels(5, pretrained=False) + + assert encoder._out_channels[0] == 5 + assert encoder._in_channels == 5 + + +def 
test_set_in_channels_preserves_custom_out_channels() -> None: + """When first out_channels is customized, set_in_channels should not override.""" + encoder = DummyEncoder() + encoder._out_channels[0] = 7 + + encoder.set_in_channels(5, pretrained=False) + + assert encoder._out_channels[0] == 7 + assert encoder._in_channels == 5 + assert encoder.conv.in_channels == 5 + + +def test_encoder_mixin_make_dilated_and_validation() -> None: + """make_dilated should error on invalid stride and patch convs otherwise.""" + encoder = DummyEncoder() + + # invalid stride raises + with pytest.raises(ValueError, match="Output stride should be 16 or 8"): + encoder.make_dilated(output_stride=4) + + # valid stride should touch both stage groups + encoder.make_dilated(output_stride=8) + conv16, conv32 = encoder.get_stages()[16][0], encoder.get_stages()[32][0] + assert conv16.stride == (1, 1) + assert conv16.dilation == (2, 2) + assert conv32.stride == (1, 1) + assert conv32.dilation == (4, 4) + + +def test_make_dilated_skips_stages_below_output_stride() -> None: + """Stages at or below the target stride should be left untouched.""" + encoder = DummyEncoder() + encoder.conv.stride = (2, 2) # stage_stride == 16, so should be skipped + encoder.conv.dilation = (1, 1) + + encoder.make_dilated(output_stride=16) + + # stage at stride 16 skipped + assert encoder.conv.stride == (2, 2) + assert encoder.conv.dilation == (1, 1) + + # stage at stride 32 modified + conv32 = encoder.get_stages()[32][0] + assert conv32.dilation == (2, 2) + assert conv32.padding == (2, 2) + + +def test_efficientnet_encoder_get_stages_splits_blocks() -> None: + """Test get_stages for dilation modification.""" + encoder = EfficientNetEncoder( + stage_idxs=[1, 2, 4], + out_channels=[3, 8, 16, 32, 64, 128], + depth=3, + channel_multiplier=1.0, + depth_multiplier=1.0, + ) + stages = encoder.get_stages() + assert len(stages) == 2 + assert stages.keys() == {16, 32} + + +def test_get_efficientnet_kwargs_shapes_and_values() -> None: + """get_efficientnet_kwargs should produce expected keys and scaling.""" + # confirm output contains decoded blocks and scaled channels + kwargs = effnet_mod.get_efficientnet_kwargs( + channel_multiplier=1.2, depth_multiplier=1.4, drop_rate=0.3 + ) + assert kwargs.get("block_args") + assert kwargs["num_features"] == effnet_mod.round_channels(1280, 1.2, 8, None) + assert kwargs["drop_rate"] == 0.3 + + +def test_efficientnet_encoder_depth_validation_and_forward() -> None: + """EfficientNetEncoder should validate depth and run forward returning features.""" + # invalid depth should fail fast + with pytest.raises( + ValueError, match=r"EfficientNetEncoder depth should be in range\s+\[1, 5\]" + ): + EfficientNetEncoder( + stage_idxs=[2, 3, 5], + out_channels=[3, 32, 24, 40, 112, 320], + depth=6, + ) + + # build shallow encoder and run a forward pass + encoder = EfficientNetEncoder( + stage_idxs=[2, 3, 5], + out_channels=[3, 32, 24, 40, 112, 320], + depth=3, + channel_multiplier=1.0, + depth_multiplier=1.0, + ) + x = torch.randn(1, 3, 32, 32) + features = encoder(x) + assert len(features) == encoder._depth + 1 + assert torch.equal(features[0], x) + # cover depth-gated forward branches up to depth 3 + assert features[1].shape[1] == 32 + assert features[2].shape[1] == 24 + assert features[3].shape[1] == 40 + + encoder = EfficientNetEncoder( + stage_idxs=[2, 3, 5], + out_channels=[3, 32, 24, 40, 112, 320], + depth=1, + channel_multiplier=1.0, + depth_multiplier=1.0, + ) + x = torch.randn(1, 3, 32, 32) + features = encoder(x) + 
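+    # with depth=1 only the stem stage runs, so the encoder is expected to return
+    # the raw input plus a single 32-channel feature map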
assert len(features) == encoder._depth + 1 + assert torch.equal(features[0], x) + assert features[1].shape[1] == 32 + + encoder = EfficientNetEncoder( + stage_idxs=[2, 3, 5], + out_channels=[3, 32, 24, 40, 112, 320], + depth=2, + channel_multiplier=1.0, + depth_multiplier=1.0, + ) + x = torch.randn(1, 3, 32, 32) + features = encoder(x) + assert len(features) == encoder._depth + 1 + assert torch.equal(features[0], x) + assert features[1].shape[1] == 32 + assert features[2].shape[1] == 24 + + +def test_efficientnet_encoder_load_state_dict_drops_classifier_keys() -> None: + """Loading state dict with classifier keys should drop them silently.""" + # ensure classifier keys are dropped before loading into the model + encoder = EfficientNetEncoder( + stage_idxs=[2, 3, 5], + out_channels=[3, 32, 24, 40, 112, 320], + depth=3, + channel_multiplier=1.0, + depth_multiplier=1.0, + ) + extended_state = dict(encoder.state_dict()) + extended_state["classifier.bias"] = torch.tensor([1.0]) + extended_state["classifier.weight"] = torch.tensor([[1.0]]) + load_result = encoder.load_state_dict(extended_state, strict=True) + assert not load_result.missing_keys + assert not load_result.unexpected_keys + assert "classifier.bias" not in encoder.state_dict() + assert "classifier.weight" not in encoder.state_dict() diff --git a/tiatoolbox/data/pretrained_model.yaml b/tiatoolbox/data/pretrained_model.yaml index 8ab9a998f..3a4ccab9b 100644 --- a/tiatoolbox/data/pretrained_model.yaml +++ b/tiatoolbox/data/pretrained_model.yaml @@ -815,7 +815,7 @@ mapde-crchisto: threshold_abs: 250 num_classes: 1 ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - { "units": "mpp", "resolution": 0.5 } @@ -837,7 +837,7 @@ mapde-conic: threshold_abs: 205 num_classes: 1 ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - { "units": "mpp", "resolution": 0.5 } @@ -860,7 +860,7 @@ sccnn-crchisto: threshold_abs: 0.20 patch_output_shape: [ 13, 13 ] ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - { "units": "mpp", "resolution": 0.25 } @@ -883,7 +883,7 @@ sccnn-conic: threshold_abs: 0.05 patch_output_shape: [ 13, 13 ] ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - { "units": "mpp", "resolution": 0.25 } @@ -903,7 +903,7 @@ nuclick_original-pannuke: num_input_channels: 5 num_output_channels: 1 ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - {'units': 'baseline', 'resolution': 0.25} @@ -925,7 +925,7 @@ nuclick_light-pannuke: decoder_block: [3,3] skip_type: "add" ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - {'units': 'baseline', 'resolution': 0.25} @@ -934,3 +934,21 @@ nuclick_light-pannuke: patch_input_shape: [128, 128] patch_output_shape: [128, 128] save_resolution: {'units': 'baseline', 'resolution': 1.0} + +grandqc_tissue_detection: + hf_repo_id: TIACentre/GrandQC_Tissue_Detection + architecture: + class: grandqc.GrandQCModel + kwargs: + num_output_channels: 2 + ioconfig: + class: io_config.IOSegmentorConfig + kwargs: + input_resolutions: + - {'units': 'mpp', 'resolution': 10.0} + output_resolutions: + - {'units': 'mpp', 'resolution': 10.0} + patch_input_shape: [512, 512] + patch_output_shape: [512, 512] + 
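+      # a 256 px stride on 512 px patches gives 50% overlap between neighbouring tiles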
stride_shape: [256, 256] + save_resolution: {'units': 'mpp', 'resolution': 10.0} diff --git a/tiatoolbox/models/architecture/grandqc.py b/tiatoolbox/models/architecture/grandqc.py new file mode 100644 index 000000000..e4e22740c --- /dev/null +++ b/tiatoolbox/models/architecture/grandqc.py @@ -0,0 +1,702 @@ +"""GrandQC Tissue Detection Model Architecture [1]. + +This module defines the GrandQC model for tissue detection in digital pathology. +It implements a UNet++ architecture with an EfficientNetB0 encoder and a segmentation +head for high-resolution tissue segmentation. The model is designed to identify +tissue regions and background areas for quality control in whole slide images (WSIs). +Please cite the paper [1], if you use this model. + +Key Components: +--------------- +- SegmentationHead: + Final layer for segmentation output. +- Conv2dReLU: + Convolutional block with BatchNorm and ReLU activation. +- DecoderBlock: + Decoder block with skip connections for feature fusion. +- CenterBlock: + Bottleneck block for deep feature processing. +- UnetPlusPlusDecoder: + Decoder with dense skip connections for UNet++ architecture. +- GrandQCModel: + Main model class implementing encoder-decoder architecture for tissue detection. + +Features: +--------- +- JPEG compression and ImageNet normalization during preprocessing. +- Argmin-based postprocessing for generating tissue masks. +- Efficient inference pipeline for batch processing. + +Example: + >>> from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor + >>> segmentor = SemanticSegmentor(model="grandqc_tissue_detection_mpp10") + >>> results = segmentor.run( + ... ["/example_wsi.svs"], + ... masks=None, + ... auto_get_mask=False, + ... patch_mode=False, + ... save_dir=Path("/tissue_mask/"), + ... output_type="annotationstore", + ... ) + +References: + [1] Weng, Zhilong et al. "GrandQC: A comprehensive solution to quality control + problem in digital pathology." Nature Communications, 2024. + DOI: 10.1038/s41467-024-54769-y + URL: https://doi.org/10.1038/s41467-024-54769-y + +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: # pragma: no cover + from collections.abc import Sequence + +import cv2 +import numpy as np +import torch +from torch import nn + +from tiatoolbox.models.architecture.timm_efficientnet import EfficientNetEncoder +from tiatoolbox.models.models_abc import ModelABC + + +class SegmentationHead(nn.Sequential): + """Segmentation head for UNet++ architecture. + + This class defines the final segmentation layer for the UNet++ model. + It applies a convolution followed by optional upsampling and activation + to produce the segmentation output. + + Attributes: + conv2d (nn.Conv2d): + Convolutional layer for feature transformation. + upsampling_layer (nn.Module): + Upsampling layer (bilinear interpolation or identity). + activation (nn.Module): + Activation function applied after upsampling. + + Example: + >>> head = SegmentationHead(in_channels=64, out_channels=2) + >>> x = torch.randn(1, 64, 128, 128) + >>> output = head(x) + >>> output.shape + ... torch.Size([1, 2, 128, 128]) + + """ + + def __init__( + self: SegmentationHead, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + activation: nn.Module | None = None, + upsampling: int = 1, + ) -> None: + """Initialize the SegmentationHead module. + + This method sets up the segmentation head by creating a convolutional layer, + an optional upsampling layer, and an activation function. 
It is typically + used as the final stage in UNet++ architectures for semantic segmentation. + + Args: + in_channels (int): + Number of input channels to the segmentation head. + out_channels (int): + Number of output channels (usually equal to the number of classes). + kernel_size (int): + Size of the convolution kernel. Defaults to 3. + activation (nn.Module | None): + Activation function applied after convolution. Defaults to None. + upsampling (int): + Upsampling factor applied to the output. Defaults to 1. + + Raises: + ValueError: + If `kernel_size` or `upsampling` is not a positive integer. + + """ + conv2d = nn.Conv2d( + in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2 + ) + upsampling_layer = ( + nn.UpsamplingBilinear2d(scale_factor=upsampling) + if upsampling > 1 + else nn.Identity() + ) + if activation is None: + activation = nn.Identity() + super().__init__(conv2d, upsampling_layer, activation) + + +class Conv2dReLU(nn.Sequential): + """Conv2d + BatchNorm + ReLU block. + + This class implements a common convolutional block used in encoder-decoder + architectures. It consists of a 2D convolution followed by batch normalization + and a ReLU activation function. + + Attributes: + conv (nn.Conv2d): + Convolutional layer for feature extraction. + norm (nn.BatchNorm2d): + Batch normalization layer for stabilizing training. + activation (nn.ReLU): + ReLU activation function applied after normalization. + + Example: + >>> block = Conv2dReLU( + ... in_channels=32, out_channels=64, kernel_size=3, padding=1 + ... ) + >>> x = torch.randn(1, 32, 128, 128) + >>> output = block(x) + >>> output.shape + ... torch.Size([1, 64, 128, 128]) + + """ + + def __init__( + self: Conv2dReLU, + in_channels: int, + out_channels: int, + kernel_size: int, + padding: int = 0, + stride: int = 1, + *, + bias: bool = False, + ) -> None: + """Initialize Conv2dReLU block. + + Creates a convolutional layer followed by batch normalization and a ReLU + activation function. This block is commonly used in UNet++ and similar + architectures for feature extraction. + + Args: + in_channels (int): + Number of input channels. + out_channels (int): + Number of output channels. + kernel_size (int): + Size of the convolution kernel. + padding (int): + Padding applied to the input. Defaults to 0. + stride (int): + Stride of the convolution. Defaults to 1. + bias (bool): + If `True`, adds a learnable bias to the output. Default: `False` + + """ + norm = nn.BatchNorm2d(out_channels) + + conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ) + + activation = nn.ReLU(inplace=True) + + super().__init__(conv, norm, activation) + + +class DecoderBlock(nn.Module): + """Decoder block for UNet++ architecture. + + This block performs upsampling and feature fusion using skip connections + from the encoder. It consists of two convolutional layers with ReLU activation + and optional attention mechanisms (not implemented). + + Attributes: + conv1 (Conv2dReLU): + First convolutional block applied after concatenating input + and skip features. + conv2 (Conv2dReLU): + Second convolutional block for further refinement. + attention1 (nn.Module): + Attention mechanism applied before the first convolution + (currently Identity). + attention2 (nn.Module): + Attention mechanism applied after the second convolution + (currently Identity). 
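+    In the forward pass the input tensor is upsampled by a factor of two
+    (nearest-neighbour) before being concatenated with the skip tensor, if any.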
+ + Example: + >>> block = DecoderBlock(in_channels=128, skip_channels=64, out_channels=64) + >>> input_tensor = torch.randn(1, 128, 64, 64) + >>> skip = torch.randn(1, 64, 128, 128) + >>> output = block(input_tensor, skip) + >>> output.shape + ... torch.Size([1, 64, 128, 128]) + + """ + + def __init__( + self: DecoderBlock, + in_channels: int, + skip_channels: int, + out_channels: int, + ) -> None: + """Initialize DecoderBlock. + + Creates two convolutional layers and optional attention modules for + feature refinement during decoding. + + Args: + in_channels (int): + Number of input channels from the previous decoder layer. + skip_channels (int): + Number of channels from the skip connection. + out_channels (int): + Number of output channels for this block. + + """ + super().__init__() + self.conv1 = Conv2dReLU( + in_channels + skip_channels, + out_channels, + kernel_size=3, + padding=1, + ) + self.attention1 = nn.Identity() + self.conv2 = Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + ) + self.attention2 = nn.Identity() + + def forward( + self: DecoderBlock, + input_tensor: torch.Tensor, + skip: torch.Tensor | None = None, + ) -> torch.Tensor: + """Forward pass through the decoder block. + + Upsamples the input tensor, concatenates it with the skip connection + (if provided), and applies two convolutional layers with attention. + + Args: + input_tensor (torch.Tensor): + (B, C_in, H, W). Input tensor from the previous decoder layer. + skip (torch.Tensor | None): + (B, C_skip, H*2, W*2). + Skip connection tensor from the encoder. Defaults to None. + + Returns: + torch.Tensor: + (B, C_out, H*2, W*2). + Output tensor after decoding and feature refinement. + + """ + input_tensor = torch.nn.functional.interpolate( + input_tensor, scale_factor=2.0, mode="nearest" + ) + if skip is not None: + input_tensor = torch.cat([input_tensor, skip], dim=1) + input_tensor = self.attention1(input_tensor) + input_tensor = self.conv1(input_tensor) + input_tensor = self.conv2(input_tensor) + return self.attention2(input_tensor) + + +class CenterBlock(nn.Sequential): + """Center block for UNet++ architecture. + + This block can be placed at the bottleneck of the UNet++ architecture. + It consists of two convolutional layers with ReLU activation, used + to process the deepest feature maps before decoding begins. + + Attributes: + conv1 (Conv2dReLU): + First convolutional block for feature transformation. + conv2 (Conv2dReLU): + Second convolutional block for further refinement. + + Example: + >>> center = CenterBlock(in_channels=256, out_channels=512) + >>> input_tensor = torch.randn(1, 256, 32, 32) + >>> output = center(input_tensor) + >>> output.shape + ... torch.Size([1, 512, 32, 32]) + + """ + + def __init__( + self: CenterBlock, + in_channels: int, + out_channels: int, + ) -> None: + """Initialize CenterBlock. + + Creates two convolutional layers with batch normalization and ReLU + activation for processing the deepest encoder features. + + Args: + in_channels (int): + Number of input channels from the encoder. + out_channels (int): + Number of output channels for the center block. + + """ + conv1 = Conv2dReLU( + in_channels, + out_channels, + kernel_size=3, + padding=1, + ) + conv2 = Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + ) + super().__init__(conv1, conv2) + + +class UnetPlusPlusDecoder(nn.Module): + """UNet++ decoder with dense skip connections. + + This class implements the decoder portion of the UNet++ architecture. 
+ It reconstructs high-resolution feature maps from encoder outputs using + multiple decoder blocks and dense connections between intermediate layers. + + Raises: + ValueError: + If the number of decoder blocks does not match the length of + `decoder_channels`. + + Attributes: + blocks (nn.ModuleDict): + Dictionary of decoder blocks organized by depth and layer index. + center (nn.Module): + Center block (currently Identity). + depth (int): + Depth of the decoder network. + + Example: + >>> decoder = UnetPlusPlusDecoder( + ... encoder_channels=[3, 32, 64, 128, 256, 512], + ... decoder_channels=[256, 128, 64, 32, 16], + ... n_blocks=5 + ... ) + >>> # Generate dummy feature maps for testing + >>> features = [ + ... torch.randn(1, c, 64 // (2**i), 64 // (2**i)) + ... for i, c in enumerate([3, 32, 64, 128, 256, 512]) + ... ] + >>> output = decoder(features) + >>> output.shape + ... torch.Size([1, 16, 64, 64]) + + """ + + def __init__( + self, + encoder_channels: Sequence[int], + decoder_channels: Sequence[int], + n_blocks: int = 5, + ) -> None: + """Initialize UnetPlusPlusDecoder. + + Sets up the decoder blocks and dense connections for UNet++ architecture. + + Args: + encoder_channels (Sequence[int]): + List of channel sizes from the encoder stages. + decoder_channels (Sequence[int]): + List of channel sizes for each decoder block. + n_blocks (int): + Number of decoder blocks. Defaults to 5. + + Raises: + ValueError: + If `n_blocks` does not match the length of `decoder_channels`. + + """ + super().__init__() + + if n_blocks != len(decoder_channels): + msg = ( + f"Model depth is {n_blocks}, but you provide " + f"`decoder_channels` for {len(decoder_channels)} blocks." + ) + raise ValueError(msg) + + # remove first skip with same spatial resolution + encoder_channels = encoder_channels[1:] + # reverse channels to start from head of encoder + encoder_channels = encoder_channels[::-1] + + # computing blocks input and output channels + head_channels = encoder_channels[0] + self.in_channels = [head_channels, *list(decoder_channels[:-1])] + self.skip_channels = [*list(encoder_channels[1:]), 0] + self.out_channels = decoder_channels + + self.center = nn.Identity() + + blocks = {} + for layer_idx in range(len(self.in_channels) - 1): + for depth_idx in range(layer_idx + 1): + if depth_idx == 0: + in_ch = self.in_channels[layer_idx] + skip_ch = self.skip_channels[layer_idx] * (layer_idx + 1) + out_ch = self.out_channels[layer_idx] + else: + out_ch = self.skip_channels[layer_idx] + skip_ch = self.skip_channels[layer_idx] * ( + layer_idx + 1 - depth_idx + ) + in_ch = self.skip_channels[layer_idx - 1] + blocks[f"x_{depth_idx}_{layer_idx}"] = DecoderBlock( + in_ch, skip_ch, out_ch + ) + blocks[f"x_{0}_{len(self.in_channels) - 1}"] = DecoderBlock( + self.in_channels[-1], 0, self.out_channels[-1] + ) + self.blocks = nn.ModuleDict(blocks) + self.depth = len(self.in_channels) - 1 + + def forward(self, features: list[torch.Tensor]) -> torch.Tensor: + """Forward pass through UNet++ decoder. + + Reconstructs high-resolution feature maps from encoder outputs using + dense skip connections and multiple decoder blocks. + + Args: + features (list[torch.Tensor]): + List of feature maps from the encoder, ordered from shallow to deep. + + Returns: + torch.Tensor: + Decoded output tensor with spatial resolution restored. 
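+                The output has ``decoder_channels[-1]`` channels and the same spatial
+                size as the first (highest-resolution) encoder feature map.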
+ + """ + features = features[1:] # remove first skip with same spatial resolution + features = features[::-1] # reverse channels to start from head of encoder + + # start building dense connections + dense_x = {} + for layer_idx in range(len(self.in_channels) - 1): + for depth_idx in range(self.depth - layer_idx): + if layer_idx == 0: + output = self.blocks[f"x_{depth_idx}_{depth_idx}"]( + features[depth_idx], features[depth_idx + 1] + ) + dense_x[f"x_{depth_idx}_{depth_idx}"] = output + else: + dense_l_i = depth_idx + layer_idx + cat_features = [ + dense_x[f"x_{idx}_{dense_l_i}"] + for idx in range(depth_idx + 1, dense_l_i + 1) + ] + cat_features = torch.cat( + [*cat_features, features[dense_l_i + 1]], dim=1 + ) + dense_x[f"x_{depth_idx}_{dense_l_i}"] = self.blocks[ + f"x_{depth_idx}_{dense_l_i}" + ](dense_x[f"x_{depth_idx}_{dense_l_i - 1}"], cat_features) + dense_x[f"x_{0}_{self.depth}"] = self.blocks[f"x_{0}_{self.depth}"]( + dense_x[f"x_{0}_{self.depth - 1}"] + ) + return dense_x[f"x_{0}_{self.depth}"] + + +class GrandQCModel(ModelABC): + """GrandQC Tissue Detection Model. + + This model implements a UNet++ architecture with an EfficientNet encoder + for tissue detection in whole slide images (WSIs). It is designed to + identify tissue regions and background areas for quality control in + digital pathology workflows. + + The model uses JPEG compression and ImageNet normalization during + preprocessing and applies argmin-based postprocessing to generate + tissue masks. + + Example: + >>> from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor + >>> segmentor = SemanticSegmentor(model="grandqc_tissue_detection") + >>> results = segmentor.run( + ... ["/example_wsi.svs"], + ... masks=None, + ... auto_get_mask=False, + ... patch_mode=False, + ... save_dir=Path("/tissue_mask/"), + ... output_type="annotationstore", + ... ) + + References: + [1] Weng, Zhilong et al. "GrandQC: A comprehensive solution to quality control + problem in digital pathology." Nature Communications, 2024. + DOI: 10.1038/s41467-024-54769-y + URL: https://doi.org/10.1038/s41467-024-54769-y + + """ + + def __init__(self: GrandQCModel, num_output_channels: int = 2) -> None: + """Initialize GrandQCModel. + + Sets up the UNet++ decoder, EfficientNet encoder, and segmentation head + for tissue detection. + + Args: + num_output_channels (int): + Number of output classes. Defaults to 2 (Tissue and Background). + + """ + super().__init__() + self.num_output_channels = num_output_channels + self.decoder_channels = (256, 128, 64, 32, 16) + + self.encoder = EfficientNetEncoder( + out_channels=[3, 32, 24, 40, 112, 320], + stage_idxs=[2, 3, 5], + channel_multiplier=1.0, + depth_multiplier=1.0, + drop_rate=0.2, + ) + self.decoder = UnetPlusPlusDecoder( + encoder_channels=self.encoder.out_channels, + decoder_channels=self.decoder_channels, + n_blocks=5, + ) + + self.segmentation_head = SegmentationHead( + in_channels=self.decoder_channels[-1], + out_channels=num_output_channels, + kernel_size=3, + ) + + self.name = "unetplusplus-efficientnetb0" + + def forward( # skipcq: PYL-W0613 + self: GrandQCModel, + x: torch.Tensor, + *args: tuple[Any, ...], # noqa: ARG002 + **kwargs: dict, # noqa: ARG002 + ) -> torch.Tensor: + """Forward pass through the GrandQC model. + + Sequentially processes the input tensor through the encoder, decoder, + and segmentation head to produce tissue segmentation predictions. + + Args: + x (torch.Tensor): + Input tensor of shape (N, C, H, W). 
+ *args (tuple): + Additional positional arguments (unused). + **kwargs (dict): + Additional keyword arguments (unused). + + Returns: + torch.Tensor: + Segmentation output tensor of shape (N, num_classes, H, W). + + """ + features = self.encoder(x) + decoder_output = self.decoder(features) + + return self.segmentation_head(decoder_output) + + @staticmethod + def preproc(image: np.ndarray) -> np.ndarray: + """Preprocess input image for inference. + + Applies JPEG compression and ImageNet normalization to the input image. + + Args: + image (np.ndarray): + Input image as a NumPy array of shape (H, W, C) in uint8 format. + + Returns: + np.ndarray: + Preprocessed image normalized to ImageNet statistics. + + Example: + >>> img = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8) + >>> processed = GrandQCModel.preproc(img) + >>> processed.shape + ... (256, 256, 3) + + """ + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 80] + _, compressed_image = cv2.imencode(".jpg", image, encode_param) + compressed_image = np.array(cv2.imdecode(compressed_image, 1)) + + mean = np.array([0.485, 0.456, 0.406]) + std = np.array([0.229, 0.224, 0.225]) + return (compressed_image / 255.0 - mean) / std + + @staticmethod + def postproc(image: np.ndarray) -> np.ndarray: + """Postprocess model output to generate tissue mask. + + Applies argmin across channels to classify pixels as tissue or background. + + Args: + image (np.ndarray): + Input probability map as a NumPy array of shape (H, W, C). + + Returns: + np.ndarray: + Binary tissue mask where 0 = Tissue and 1 = Background. + + Example: + >>> probs = np.random.rand(256, 256, 2) + >>> mask = GrandQCModel.postproc(probs) + >>> mask.shape + ... (256, 256) + + """ + return image.argmin(axis=-1) + + @staticmethod + def infer_batch( + model: torch.nn.Module, + batch_data: torch.Tensor, + *, + device: str, + ) -> np.ndarray: + """Run inference on a batch of images. + + Transfers the model and input batch to the specified device, performs + forward pass, and returns softmax probabilities. + + Args: + model (torch.nn.Module): + PyTorch model instance. + batch_data (torch.Tensor): + Batch of input images in NHWC format. + device (str): + Device for inference (e.g., "cpu" or "cuda"). + + Returns: + np.ndarray: + Inference results as a NumPy array of shape (N, H, W, C). + + Example: + >>> batch = torch.randn(4, 256, 256, 3) + >>> probs = GrandQCModel.infer_batch(model, batch, device="cpu") + >>> probs.shape + (4, 256, 256, 2) + + """ + model = model.to(device) + model.eval() + + imgs = batch_data + imgs = imgs.to(device).type(torch.float32) + imgs = imgs.permute(0, 3, 1, 2) # to NCHW + + with torch.inference_mode(): + logits = model(imgs) + probs = torch.nn.functional.softmax(logits, 1) + probs = probs.permute(0, 2, 3, 1) # to NHWC + + return probs.cpu().numpy() diff --git a/tiatoolbox/models/architecture/timm_efficientnet.py b/tiatoolbox/models/architecture/timm_efficientnet.py new file mode 100644 index 000000000..d63c64461 --- /dev/null +++ b/tiatoolbox/models/architecture/timm_efficientnet.py @@ -0,0 +1,650 @@ +"""EfficientNet Encoder Implementation using timm. + +This module provides an implementation of EfficientNet-based encoders for use in +semantic segmentation and other computer vision tasks. It leverages the `timm` +library for model components and adds encoder-specific functionality such as +custom input channels, dilation support, and configurable scaling parameters. 
+ +Key Components: +--------------- +- patch_first_conv: + Utility to modify the first convolution layer for arbitrary input channels. +- replace_strides_with_dilation: + Utility to convert strides into dilations for atrous convolutions. +- EncoderMixin: + Mixin class adding encoder-specific features like output channels and stride. +- EfficientNetBaseEncoder: + Base encoder combining EfficientNet backbone with encoder functionality. +- EfficientNetEncoder: + Configurable EfficientNet encoder supporting depth and channel scaling. +- timm_efficientnet_encoders: + Dictionary of available EfficientNet encoder configurations and pretrained settings. + +Features: +--------- +- Supports arbitrary input channels (e.g., grayscale or multi-channel images). +- Allows conversion to dilated versions for semantic segmentation. +- Provides pretrained weights from multiple sources (ImageNet, AdvProp, Noisy Student). +- Implements scaling rules for EfficientNet architecture. + +Example: + >>> from tiatoolbox.models.architecture.timm_efficientnet import EfficientNetEncoder + >>> encoder = EfficientNetEncoder( + ... stage_idxs=[2, 3, 5], + ... out_channels=[3, 32, 24, 40, 112, 320], + ... channel_multiplier=1.0, + ... depth_multiplier=1.0, + ... drop_rate=0.2 + ... ) + >>> x = torch.randn(1, 3, 224, 224) + >>> features = encoder(x) + >>> [f.shape for f in features] + [torch.Size([1, 3, 224, 224]), torch.Size([1, 32, 112, 112]), ...] + +References: + - Tan, Mingxing, and Quoc V. Le. "EfficientNet: Rethinking Model Scaling for + Convolutional Neural Networks." arXiv preprint arXiv:1905.11946 (2019). + URL: https://arxiv.org/abs/1905.11946 + +""" + +from __future__ import annotations + +from functools import partial +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: # pragma: no cover + from collections.abc import Mapping, Sequence + +import torch +from timm.layers.activations import Swish +from timm.models._efficientnet_builder import decode_arch_def, round_channels +from timm.models.efficientnet import EfficientNet +from torch import nn + +MAX_DEPTH = 5 +MIN_DEPTH = 1 +DEFAULT_IN_CHANNELS = 3 + + +def patch_first_conv( + model: nn.Module, + new_in_channels: int, + default_in_channels: int = 3, + *, + pretrained: bool = True, +) -> None: + """Update the first convolution layer for a new input channel size. + + This function updates the first convolutional layer of a model to handle + arbitrary input channels. It optionally reuses pretrained weights or + initializes weights randomly. + + Args: + model (nn.Module): + The neural network model whose first convolution layer will be patched. + new_in_channels (int): + Number of input channels for the new first layer. + default_in_channels (int): + Original number of input channels. Defaults to 3. + pretrained (bool): + Whether to reuse pretrained weights. Defaults to True. + + Notes: + - If `new_in_channels` == 1 or 2 → reuse original weights. + - If `new_in_channels` > 3 → initialize weights using Kaiming normal. 
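+        - If `pretrained` is False → the patched layer is re-initialized via
+          ``reset_parameters()`` regardless of the new channel count.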
+ + Example: + >>> patch_first_conv(model, new_in_channels=1, pretrained=True) + + """ + # get first conv + conv_module: nn.Conv2d | None = None + for module in model.modules(): + if isinstance(module, nn.Conv2d) and module.in_channels == default_in_channels: + conv_module = module + break + + if conv_module is None: + return + + weight = conv_module.weight.detach() + conv_module.in_channels = new_in_channels + + if not pretrained: + conv_module.weight = nn.parameter.Parameter( + torch.Tensor( + conv_module.out_channels, + new_in_channels // conv_module.groups, + *conv_module.kernel_size, + ) + ) + conv_module.reset_parameters() + + elif new_in_channels == 1: + new_weight = weight.sum(1, keepdim=True) + conv_module.weight = nn.parameter.Parameter(new_weight) + + else: + new_weight = torch.Tensor( + conv_module.out_channels, + new_in_channels // conv_module.groups, + *conv_module.kernel_size, + ) + + for i in range(new_in_channels): + new_weight[:, i] = weight[:, i % default_in_channels] + + new_weight = new_weight * (default_in_channels / new_in_channels) + conv_module.weight = nn.parameter.Parameter(new_weight) + + +def replace_strides_with_dilation(module: nn.Module, dilation_rate: int) -> None: + """Replace strides with dilation in Conv2d layers. + + Converts convolutional layers to use dilation instead of stride, enabling + atrous convolutions for semantic segmentation tasks. + + Args: + module (nn.Module): + Module containing Conv2d layers to patch. + dilation_rate (int): + Dilation rate to apply to all Conv2d layers. + + Example: + >>> replace_strides_with_dilation(model, dilation_rate=2) + + """ + for mod in module.modules(): + if isinstance(mod, nn.Conv2d): + mod.stride = (1, 1) + mod.dilation = (dilation_rate, dilation_rate) + kh, _ = mod.kernel_size + mod.padding = ((kh // 2) * dilation_rate, (kh // 2) * dilation_rate) + + # Workaround for EfficientNet + if hasattr(mod, "static_padding"): + mod.static_padding = nn.Identity() # type: ignore[attr-defined] + + +class EncoderMixin: + """Mixin class adding encoder-specific functionality. + + Provides methods for: + - Managing output channels for encoder feature maps. + - Patching the first convolution for arbitrary input channels. + - Converting encoder to dilated version for segmentation tasks. + + Attributes: + _depth (int): + Encoder depth (number of stages). + _in_channels (int): + Number of input channels. + _output_stride (int): + Output stride of the encoder. + _out_channels (list[int]): + List of output channel dimensions for each depth level. + + Example: + >>> encoder = EncoderMixin() + >>> encoder.set_in_channels(1) + + """ + + _is_torch_scriptable = True + _is_torch_exportable = True + _is_torch_compilable = True + + def __init__(self) -> None: + """Initialize EncoderMixin with default parameters. + + Sets default values for encoder depth, input channels, output stride, + and output channel list. + + """ + self._depth = 5 + self._in_channels = 3 + self._output_stride = 32 + self._out_channels: list[int] = [] + + @property + def out_channels(self) -> list[int]: + """Return output channel dimensions for encoder feature maps. + + Returns: + list[int]: + List of output channel dimensions for each depth level. + + Example: + >>> encoder.out_channels + ... [3, 32, 64, 128, 256, 512] + + """ + return self._out_channels[: self._depth + 1] + + @property + def output_stride(self) -> int: + """Return the effective output stride of the encoder. + + The output stride is the minimum of the configured stride and 2^depth. 
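+        For example, ``_output_stride = 32`` with ``_depth = 3`` gives
+        ``min(32, 2**3) = 8``.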
+ + Returns: + int: + Effective output stride. + + Example: + >>> encoder.output_stride + ... 32 + + """ + return min(self._output_stride, 2**self._depth) + + def set_in_channels(self, in_channels: int, *, pretrained: bool = True) -> None: + """Update the encoder to accept a different number of input channels. + + Args: + in_channels (int): + Number of input channels. + pretrained (bool): + Whether to use pretrained weights. Defaults to True. + + Example: + >>> encoder.set_in_channels(1, pretrained=False) + + """ + if in_channels == DEFAULT_IN_CHANNELS: + return + + self._in_channels = in_channels + if self._out_channels[0] == DEFAULT_IN_CHANNELS: + self._out_channels = [in_channels, *self._out_channels[1:]] + # Type ignore needed because self is a mixin that will be used with nn.Module + patch_first_conv(model=self, new_in_channels=in_channels, pretrained=pretrained) # type: ignore[arg-type] + + def get_stages(self) -> dict[int, Sequence[torch.nn.Module]]: + """Return encoder stages for dilation modification. + + This method should be overridden by subclasses to provide stage mappings + for converting strides to dilations. + + Returns: + dict[int, Sequence[torch.nn.Module]]: + Dictionary mapping output stride to corresponding module sequences. + + Raises: + NotImplementedError: + If the method is not implemented by the subclass. + + Example: + >>> stages = encoder.get_stages() + + """ + raise NotImplementedError + + def make_dilated(self, output_stride: int) -> None: + """Convert encoder to a dilated version for segmentation. + + Args: + output_stride (int): + Target output stride (must be 8 or 16). + + Raises: + ValueError: + If `output_stride` is not 8 or 16. + + Example: + >>> encoder.make_dilated(output_stride=16) + + """ + if output_stride not in [8, 16]: + msg = f"Output stride should be 16 or 8, got {output_stride}." + raise ValueError(msg) + + stages = self.get_stages() + for stage_stride, stage_modules in stages.items(): + if stage_stride <= output_stride: + continue + + dilation_rate = stage_stride // output_stride + for module in stage_modules: + replace_strides_with_dilation(module, dilation_rate) + + +def get_efficientnet_kwargs( + channel_multiplier: float = 1.0, + depth_multiplier: float = 1.0, + drop_rate: float = 0.2, +) -> dict[str, Any]: + """Generate configuration parameters for EfficientNet. + + Args: + channel_multiplier (float): + Multiplier for number of channels per layer. Defaults to 1.0. + depth_multiplier (float): + Multiplier for number of repeats per stage. Defaults to 1.0. + drop_rate (float): + Dropout rate. Defaults to 0.2. + + Reference implementation: + https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py + + Paper: + https://arxiv.org/abs/1905.11946 + + EfficientNet parameters: + - 'efficientnet-b0': (1.0, 1.0, 224, 0.2) + - 'efficientnet-b1': (1.0, 1.1, 240, 0.2) + - 'efficientnet-b2': (1.1, 1.2, 260, 0.3) + - 'efficientnet-b3': (1.2, 1.4, 300, 0.3) + - 'efficientnet-b4': (1.4, 1.8, 380, 0.4) + - 'efficientnet-b5': (1.6, 2.2, 456, 0.4) + - 'efficientnet-b6': (1.8, 2.6, 528, 0.5) + - 'efficientnet-b7': (2.0, 3.1, 600, 0.5) + - 'efficientnet-b8': (2.2, 3.6, 672, 0.5) + - 'efficientnet-l2': (4.3, 5.3, 800, 0.5) + + Args: + channel_multiplier: Multiplier to number of channels per layer. Defaults to 1.0. + depth_multiplier: Multiplier to number of repeats per stage. Defaults to 1.0. + drop_rate: Dropout rate. Defaults to 0.2. 
+ + + Returns: + dict[str, Any]: + Dictionary containing EfficientNet configuration parameters + + Example: + >>> kwargs = get_efficientnet_kwargs( + ... channel_multiplier=1.2, + ... depth_multiplier=1.4, + ... ) + + """ + arch_def = [ + ["ds_r1_k3_s1_e1_c16_se0.25"], + ["ir_r2_k3_s2_e6_c24_se0.25"], + ["ir_r2_k5_s2_e6_c40_se0.25"], + ["ir_r3_k3_s2_e6_c80_se0.25"], + ["ir_r3_k5_s1_e6_c112_se0.25"], + ["ir_r4_k5_s2_e6_c192_se0.25"], + ["ir_r1_k3_s1_e6_c320_se0.25"], + ] + return { + "block_args": decode_arch_def(arch_def, depth_multiplier), + "num_features": round_channels(1280, channel_multiplier, 8, None), + "stem_size": 32, + "round_chs_fn": partial(round_channels, multiplier=channel_multiplier), + "act_layer": Swish, + "drop_rate": drop_rate, + "drop_path_rate": 0.2, + } + + +class EfficientNetBaseEncoder(EfficientNet, EncoderMixin): + """Base class for EfficientNet encoder. + + Combines EfficientNet backbone from `timm` with encoder-specific functionality + for feature extraction in segmentation and classification tasks. + + Features: + - Supports configurable depth and output stride. + - Provides intermediate feature maps for multi-scale processing. + - Removes classifier for encoder-only usage. + + Raises: + ValueError: + If `depth` is not in range [1, 5]. + + Example: + >>> encoder = EfficientNetBaseEncoder( + ... stage_idxs=[2, 3, 5], + ... out_channels=[3, 32, 24, 40, 112, 320], + ... depth=5, + ... output_stride=32 + ... ) + >>> x = torch.randn(1, 3, 224, 224) + >>> features = encoder(x) + >>> [f.shape for f in features] + ... [torch.Size([1, 3, 224, 224]), torch.Size([1, 32, 112, 112]), ...] + + """ + + def __init__( + self, + stage_idxs: list[int], + out_channels: list[int], + depth: int = 5, + output_stride: int = 32, + **kwargs: dict[str, Any], + ) -> None: + """Initialize EfficientNetBaseEncoder. + + Args: + stage_idxs (list[int]): + Indices of stages for feature extraction. + out_channels (list[int]): + Output channels for each depth level. + depth (int): + Encoder depth (1-5). Defaults to 5. + output_stride (int): + Output stride of encoder. Defaults to 32. + **kwargs (dict[str, Any]): + Additional keyword arguments for EfficientNet initialization. + + Raises: + ValueError: + If `depth` is not in range [1, 5]. + + """ + if depth > MAX_DEPTH or depth < MIN_DEPTH: + msg = f"{self.__class__.__name__} depth should be in range \ + [1, 5], got {depth}" + raise ValueError(msg) + super().__init__(**kwargs) + + self._stage_idxs = stage_idxs + self._depth = depth + self._in_channels = 3 + self._out_channels = out_channels + self._output_stride = output_stride + + del self.classifier + + def get_stages(self) -> dict[int, Sequence[torch.nn.Module]]: + """Return encoder stages for dilation modification. + + Provides mapping of output strides to corresponding module sequences, + enabling conversion to dilated versions for segmentation tasks. + + Returns: + dict[int, Sequence[torch.nn.Module]]: + Dictionary mapping output stride to module sequences. + + Example: + >>> stages = encoder.get_stages() + >>> print(stages.keys()) + ... dict_keys([16, 32]) + + """ + return { + 16: [self.blocks[self._stage_idxs[1] : self._stage_idxs[2]]], # type: ignore[attr-defined] + 32: [self.blocks[self._stage_idxs[2] :]], # type: ignore[attr-defined] + } + + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: + """Forward pass through EfficientNet encoder. + + Extracts feature maps from multiple stages of the encoder for use in + decoder networks or multi-scale processing. 
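+        The returned list always starts with the unmodified input tensor and
+        contains ``self._depth + 1`` feature maps in total.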
+ + Args: + x (torch.Tensor): + Input tensor of shape (N, C, H, W). + + Returns: + list[torch.Tensor]: + List of feature maps from different encoder depths. + + Example: + >>> x = torch.randn(1, 3, 224, 224) + >>> features = encoder(x) + >>> len(features) + ... 6 + + """ + features = [x] + + if self._depth >= 1: + x = self.conv_stem(x) # type: ignore[attr-defined] + x = self.bn1(x) # type: ignore[attr-defined] + features.append(x) + + if self._depth >= 2: # noqa: PLR2004 + x = self.blocks[0](x) # type: ignore[attr-defined] + x = self.blocks[1](x) # type: ignore[attr-defined] + features.append(x) + + if self._depth >= 3: # noqa: PLR2004 + x = self.blocks[2](x) # type: ignore[attr-defined] + features.append(x) + + if self._depth >= 4: # noqa: PLR2004 + x = self.blocks[3](x) # type: ignore[attr-defined] + x = self.blocks[4](x) # type: ignore[attr-defined] + features.append(x) + + if self._depth >= 5: # noqa: PLR2004 + x = self.blocks[5](x) # type: ignore[attr-defined] + x = self.blocks[6](x) # type: ignore[attr-defined] + features.append(x) + + return features + + def load_state_dict( + self, + state_dict: Mapping[str, Any], + **kwargs: bool, + ) -> torch.nn.modules.module._IncompatibleKeys: + """Load state dictionary, excluding classifier weights. + + Removes classifier weights from the state dictionary before loading, + as the encoder does not include a classification head. + + Args: + state_dict (Mapping[str, Any]): + State dictionary to load. + **kwargs (bool): + Additional keyword arguments for `load_state_dict`. + + Returns: + torch.nn.modules.module._IncompatibleKeys: + Result of parent class `load_state_dict` method. + + Example: + >>> encoder.load_state_dict(torch.load("efficientnet_weights.pth")) + + """ + # Create a mutable copy of the state dict to modify + state_dict_copy = dict(state_dict) + state_dict_copy.pop("classifier.bias", None) + state_dict_copy.pop("classifier.weight", None) + return super().load_state_dict(state_dict_copy, **kwargs) + + +class EfficientNetEncoder(EfficientNetBaseEncoder): + """EfficientNet encoder with configurable scaling parameters. + + This class extends `EfficientNetBaseEncoder` to provide scaling options + for depth and channel multipliers, enabling flexible encoder configurations + for segmentation and classification tasks. + + Features: + - Supports depth and channel scaling. + - Provides pretrained weights for multiple variants. + - Outputs multi-scale feature maps for downstream tasks. + + Example: + >>> encoder = EfficientNetEncoder( + ... stage_idxs=[2, 3, 5], + ... out_channels=[3, 32, 24, 40, 112, 320], + ... channel_multiplier=1.0, + ... depth_multiplier=1.0, + ... drop_rate=0.2 + ... ) + >>> x = torch.randn(1, 3, 224, 224) + >>> features = encoder(x) + >>> [f.shape for f in features] + ... [torch.Size([1, 3, 224, 224]), torch.Size([1, 32, 112, 112]), ...] + + """ + + def __init__( + self, + stage_idxs: list[int], + out_channels: list[int], + depth: int = 5, + channel_multiplier: float = 1.0, + depth_multiplier: float = 1.0, + drop_rate: float = 0.2, + output_stride: int = 32, + ) -> None: + """Initialize EfficientNetEncoder. + + Creates an EfficientNet encoder with configurable scaling parameters + for depth and channel multipliers. + + Args: + stage_idxs (list[int]): + Indices of stages for feature extraction. + out_channels (list[int]): + Output channels for each depth level. + depth (int): + Encoder depth (1-5). Defaults to 5. + channel_multiplier (float): + Channel scaling factor. Defaults to 1.0. 
+ depth_multiplier (float): + Depth scaling factor. Defaults to 1.0. + drop_rate (float): + Dropout rate. Defaults to 0.2. + output_stride (int): + Output stride of encoder. Defaults to 32. + + """ + kwargs = get_efficientnet_kwargs( + channel_multiplier, depth_multiplier, drop_rate + ) + super().__init__( + stage_idxs=stage_idxs, + depth=depth, + out_channels=out_channels, + output_stride=output_stride, + **kwargs, + ) + + +timm_efficientnet_encoders = { + "timm-efficientnet-b0": { + "encoder": EfficientNetEncoder, + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/timm-efficientnet-b0.imagenet", + "revision": "8419e9cc19da0b68dcd7bb12f19b7c92407ad7c4", + }, + "advprop": { + "repo_id": "smp-hub/timm-efficientnet-b0.advprop", + "revision": "a5870af2d24ce79e0cc7fae2bbd8e0a21fcfa6d8", + }, + "noisy-student": { + "repo_id": "smp-hub/timm-efficientnet-b0.noisy-student", + "revision": "bea8b0ff726a50e48774d2d360c5fb1ac4815836", + }, + }, + "params": { + "out_channels": [3, 32, 24, 40, 112, 320], + "stage_idxs": [2, 3, 5], + "channel_multiplier": 1.0, + "depth_multiplier": 1.0, + "drop_rate": 0.2, + }, + }, +}