|
2 | 2 | import torch |
3 | 3 | import segmentation_models_pytorch as smp |
4 | 4 | import segmentation_models_pytorch.losses._functional as F |
5 | | -from segmentation_models_pytorch.losses import DiceLoss, JaccardLoss, SoftBCEWithLogitsLoss, SoftCrossEntropyLoss |
| 5 | +from segmentation_models_pytorch.losses import DiceLoss, JaccardLoss, SoftBCEWithLogitsLoss, SoftCrossEntropyLoss, \ |
| 6 | + TverskyLoss, TverskyLossFocal |
6 | 7 |
|
7 | 8 |
|
8 | 9 | def test_focal_loss_with_logits(): |
@@ -71,6 +72,21 @@ def test_soft_dice_score(y_true, y_pred, expected, eps): |
71 | 72 | assert float(actual) == pytest.approx(expected, eps) |
72 | 73 |
|
73 | 74 |
|
@pytest.mark.parametrize(
    ["y_true", "y_pred", "expected", "eps", "alpha", "beta"],
    [
        # Perfect agreement -> score is exactly 1.0.
        [[1, 1, 1, 1], [1, 1, 1, 1], 1.0, 1e-5, 0.5, 0.5],
        [[0, 1, 1, 0], [0, 1, 1, 0], 1.0, 1e-5, 0.5, 0.5],
        # Half the positives missed: TP=2, FP=0, FN=2
        # -> 2 / (2 + 0.5*0 + 0.5*2) = 2/3.
        [[1, 1, 1, 1], [1, 1, 0, 0], 2.0 / 3.0, 1e-5, 0.5, 0.5],
    ],
)
def test_soft_tversky_score(y_true, y_pred, expected, eps, alpha, beta):
    """Check soft_tversky_score against hand-computed values.

    With alpha = beta = 0.5 the Tversky score reduces to the Dice score.
    """
    target = torch.tensor(y_true, dtype=torch.float32)
    output = torch.tensor(y_pred, dtype=torch.float32)
    score = F.soft_tversky_score(output, target, eps=eps, alpha=alpha, beta=beta)
    assert float(score) == pytest.approx(expected, eps)
| 88 | + |
| 89 | + |
74 | 90 | @torch.no_grad() |
75 | 91 | def test_dice_loss_binary(): |
76 | 92 | eps = 1e-5 |
@@ -109,6 +125,45 @@ def test_dice_loss_binary(): |
109 | 125 | assert float(loss) == pytest.approx(1.0, abs=eps) |
110 | 126 |
|
111 | 127 |
|
@torch.no_grad()
def test_tversky_loss_binary():
    """TverskyLoss in binary mode on ideal and fully-wrong predictions.

    With alpha = beta = 0.5 the Tversky loss is equal to the Dice loss, so the
    expected values mirror test_dice_loss_binary: 0.0 when the prediction
    matches the target exactly, 1.0 when it is the exact complement.
    """
    eps = 1e-5
    # with alpha=0.5; beta=0.5 it is equal to DiceLoss
    criterion = TverskyLoss(mode=smp.losses.BINARY_MODE, from_logits=False, alpha=0.5, beta=0.5)

    # Ideal cases: prediction equals target -> loss ~ 0
    y_pred = torch.tensor([1.0, 1.0, 1.0]).view(1, 1, 1, -1)
    y_true = torch.tensor([1, 1, 1]).view(1, 1, 1, -1)
    loss = criterion(y_pred, y_true)
    assert float(loss) == pytest.approx(0.0, abs=eps)

    y_pred = torch.tensor([1.0, 0.0, 1.0]).view(1, 1, 1, -1)
    y_true = torch.tensor([1, 0, 1]).view(1, 1, 1, -1)
    loss = criterion(y_pred, y_true)
    assert float(loss) == pytest.approx(0.0, abs=eps)

    y_pred = torch.tensor([0.0, 0.0, 0.0]).view(1, 1, 1, -1)
    y_true = torch.tensor([0, 0, 0]).view(1, 1, 1, -1)
    loss = criterion(y_pred, y_true)
    assert float(loss) == pytest.approx(0.0, abs=eps)

    # Worst cases: prediction is the complement of the target -> loss ~ 1
    # FIX: this first case previously asserted approx(0.0); an all-ones
    # prediction against an all-zeros target has TP=0 and FP=3, giving a
    # Tversky score of ~0 and therefore a loss of ~1.0, matching the two
    # sibling worst cases below.
    # FIX: y_pred views were .view(1, 1, -1) here while every other tensor in
    # this test uses .view(1, 1, 1, -1); made the shapes consistent so the
    # comparison does not depend on broadcasting.
    y_pred = torch.tensor([1.0, 1.0, 1.0]).view(1, 1, 1, -1)
    y_true = torch.tensor([0, 0, 0]).view(1, 1, 1, -1)
    loss = criterion(y_pred, y_true)
    assert float(loss) == pytest.approx(1.0, abs=eps)

    y_pred = torch.tensor([1.0, 0.0, 1.0]).view(1, 1, 1, -1)
    y_true = torch.tensor([0, 1, 0]).view(1, 1, 1, -1)
    loss = criterion(y_pred, y_true)
    assert float(loss) == pytest.approx(1.0, abs=eps)

    y_pred = torch.tensor([0.0, 0.0, 0.0]).view(1, 1, 1, -1)
    y_true = torch.tensor([1, 1, 1]).view(1, 1, 1, -1)
    loss = criterion(y_pred, y_true)
    assert float(loss) == pytest.approx(1.0, abs=eps)
| 165 | + |
| 166 | + |
112 | 167 | @torch.no_grad() |
113 | 168 | def test_binary_jaccard_loss(): |
114 | 169 | eps = 1e-5 |
|
0 commit comments