|
| 1 | +from typing import Optional, Union, List |
| 2 | +from .decoder import UnetPlusPlusDecoder |
| 3 | +from ..encoders import get_encoder |
| 4 | +from ..base import SegmentationModel |
| 5 | +from ..base import SegmentationHead, ClassificationHead |
| 6 | + |
| 7 | + |
| 8 | +class UnetPlusPlus(SegmentationModel): |
| 9 | + """Unet++_ is a fully convolution neural network for image semantic segmentation |
| 10 | +
|
| 11 | + Args: |
| 12 | + encoder_name: name of classification model (without last dense layers) used as feature |
| 13 | + extractor to build segmentation model. |
| 14 | + encoder_depth (int): number of stages used in decoder, larger depth - more features are generated. |
| 15 | + e.g. for depth=3 encoder will generate list of features with following spatial shapes |
| 16 | + [(H,W), (H/2, W/2), (H/4, W/4), (H/8, W/8)], so in general the deepest feature tensor will have |
| 17 | + spatial resolution (H/(2^depth), W/(2^depth)] |
| 18 | + encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet). |
| 19 | + decoder_channels: list of numbers of ``Conv2D`` layer filters in decoder blocks |
| 20 | + decoder_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers |
| 21 | + is used. If 'inplace' InplaceABN will be used, allows to decrease memory consumption. |
| 22 | + One of [True, False, 'inplace'] |
| 23 | + decoder_attention_type: attention module used in decoder of the model |
| 24 | + One of [``None``, ``scse``] |
| 25 | + in_channels: number of input channels for model, default is 3. |
| 26 | + classes: a number of classes for output (output shape - ``(batch, classes, h, w)``). |
| 27 | + activation: activation function to apply after final convolution; |
| 28 | + One of [``sigmoid``, ``softmax``, ``logsoftmax``, ``identity``, callable, None] |
| 29 | + aux_params: if specified model will have additional classification auxiliary output |
| 30 | + build on top of encoder, supported params: |
| 31 | + - classes (int): number of classes |
| 32 | + - pooling (str): one of 'max', 'avg'. Default is 'avg'. |
| 33 | + - dropout (float): dropout factor in [0, 1) |
| 34 | + - activation (str): activation function to apply "sigmoid"/"softmax" (could be None to return logits) |
| 35 | +
|
| 36 | + Returns: |
| 37 | + ``torch.nn.Module``: **Unet++** |
| 38 | +
|
| 39 | + .. _UnetPlusPlus: |
| 40 | + https://arxiv.org/pdf/1807.10165.pdf |
| 41 | +
|
| 42 | + """ |
| 43 | + |
| 44 | + def __init__( |
| 45 | + self, |
| 46 | + encoder_name: str = "resnet34", |
| 47 | + encoder_depth: int = 5, |
| 48 | + encoder_weights: str = "imagenet", |
| 49 | + decoder_use_batchnorm: bool = True, |
| 50 | + decoder_channels: List[int] = (256, 128, 64, 32, 16), |
| 51 | + decoder_attention_type: Optional[str] = None, |
| 52 | + in_channels: int = 3, |
| 53 | + classes: int = 1, |
| 54 | + activation: Optional[Union[str, callable]] = None, |
| 55 | + aux_params: Optional[dict] = None, |
| 56 | + ): |
| 57 | + super().__init__() |
| 58 | + |
| 59 | + self.encoder = get_encoder( |
| 60 | + encoder_name, |
| 61 | + in_channels=in_channels, |
| 62 | + depth=encoder_depth, |
| 63 | + weights=encoder_weights, |
| 64 | + ) |
| 65 | + |
| 66 | + self.decoder = UnetPlusPlusDecoder( |
| 67 | + encoder_channels=self.encoder.out_channels, |
| 68 | + decoder_channels=decoder_channels, |
| 69 | + n_blocks=encoder_depth, |
| 70 | + use_batchnorm=decoder_use_batchnorm, |
| 71 | + center=True if encoder_name.startswith("vgg") else False, |
| 72 | + attention_type=decoder_attention_type, |
| 73 | + ) |
| 74 | + |
| 75 | + self.segmentation_head = SegmentationHead( |
| 76 | + in_channels=decoder_channels[-1], |
| 77 | + out_channels=classes, |
| 78 | + activation=activation, |
| 79 | + kernel_size=3, |
| 80 | + ) |
| 81 | + |
| 82 | + if aux_params is not None: |
| 83 | + self.classification_head = ClassificationHead( |
| 84 | + in_channels=self.encoder.out_channels[-1], **aux_params |
| 85 | + ) |
| 86 | + else: |
| 87 | + self.classification_head = None |
| 88 | + |
| 89 | + self.name = "u-{}".format(encoder_name) |
| 90 | + self.initialize() |
0 commit comments