Skip to content

Commit c2766f0

Browse files
committed
feat: add more robust class setting + tests
1 parent 0d84868 commit c2766f0

File tree

2 files changed

+79
-13
lines changed

2 files changed

+79
-13
lines changed

src/mpl_image_segmenter/_segmenter.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from numbers import Integral
34
from typing import TYPE_CHECKING
45

56
import numpy as np
@@ -21,7 +22,7 @@ class ImageSegmenter:
2122
def __init__( # type: ignore
2223
self,
2324
img,
24-
nclasses=1,
25+
classes=1,
2526
mask=None,
2627
mask_colors=None,
2728
mask_alpha=0.75,
@@ -39,8 +40,8 @@ def __init__( # type: ignore
3940
----------
4041
img : array_like
4142
A valid argument to imshow
42-
nclasses : int, default 1
43-
How many classes to have in the mask.
43+
classes : int, iterable[string], default 1
44+
If a number How many classes to have in the mask.
4445
mask : arraylike, optional
4546
If you want to pre-seed the mask
4647
mask_colors : None, color, or array of colors, optional
@@ -69,13 +70,20 @@ def __init__( # type: ignore
6970

7071
self.mask_alpha = mask_alpha
7172

73+
if isinstance(classes, Integral):
74+
self._classes: list[str | int] = list(range(classes))
75+
else:
76+
self._classes = classes
77+
self._n_classes = len(self._classes)
7278
if mask_colors is None:
73-
# this will break if there are more than 10 classes
74-
if nclasses <= 10:
75-
self.mask_colors = to_rgba_array(list(TABLEAU_COLORS)[:nclasses])
79+
if self._n_classes <= 10:
80+
# There are only 10 tableau colors
81+
self.mask_colors = to_rgba_array(
82+
list(TABLEAU_COLORS)[: self._n_classes]
83+
)
7684
else:
7785
# up to 949 classes. Hopefully that is always enough....
78-
self.mask_colors = to_rgba_array(list(XKCD_COLORS)[:nclasses])
86+
self.mask_colors = to_rgba_array(list(XKCD_COLORS)[: self._n_classes])
7987
else:
8088
self.mask_colors = to_rgba_array(np.atleast_1d(mask_colors))
8189
# should probably check the shape here
@@ -90,8 +98,7 @@ def __init__( # type: ignore
9098
self.mask = mask
9199

92100
self._overlay = np.zeros((*self._img.shape[:2], 4))
93-
self.nclasses = nclasses
94-
for i in range(nclasses + 1):
101+
for i in range(self._n_classes + 1):
95102
idx = self.mask == i
96103
if i == 0:
97104
self._overlay[idx] = [0, 0, 0, 0]
@@ -160,6 +167,26 @@ def erasing(self, val: bool) -> None:
160167
raise TypeError(f"Erasing must be a bool - got type {type(val)}")
161168
self._erasing = val
162169

170+
@property
171+
def current_class(self) -> int | str:
172+
return self._classes[self._cur_class_idx - 1]
173+
174+
@current_class.setter
175+
def current_class(self, val: int | str) -> None:
176+
if isinstance(val, str):
177+
if val not in self._classes:
178+
raise ValueError(f"{val} is not one of the classes: {self._classes}")
179+
# offset by one for the background
180+
self._cur_class_idx = self._classes.index(val) + 1
181+
elif isinstance(val, Integral):
182+
if 0 < val < self._n_classes + 1:
183+
self._cur_class_idx = val
184+
else:
185+
raise ValueError(
186+
f"Current class must be bewteen 1 and {self._n_classes}."
187+
" It cannot be 0 as 0 is the background."
188+
)
189+
163190
def get_paths(self) -> dict[str, list[Path]]:
164191
"""
165192
Get a dictionary of all the paths used to create the mask.
@@ -179,8 +206,8 @@ def _onselect(self, verts: Any) -> None:
179206
self._overlay[self.indices] = [0, 0, 0, 0]
180207
self._paths["erasing"].append(p)
181208
else:
182-
self.mask[self.indices] = self.current_class
183-
self._overlay[self.indices] = self.mask_colors[self.current_class - 1]
209+
self.mask[self.indices] = self._cur_class_idx
210+
self._overlay[self.indices] = self.mask_colors[self._cur_class_idx - 1]
184211
self._paths["adding"].append(p)
185212

186213
self._mask.set_data(self._overlay)

tests/test_mpl_image_segmenter.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,41 @@
1-
def test_something():
2-
pass
1+
import re
2+
3+
import numpy as np
4+
import pytest
5+
from mpl_image_segmenter import ImageSegmenter
6+
7+
8+
def test_current_class():
9+
img = np.zeros([128, 128])
10+
seg = ImageSegmenter(img, classes=["a", "b", "c"])
11+
12+
top = 25
13+
left = 25
14+
right = 100
15+
bottom = 100
16+
seg.current_class = "a"
17+
seg._onselect([(left, top), (left, bottom), (right, bottom), (right, top)])
18+
seg.current_class = "b"
19+
top = 30
20+
seg._onselect([(left, top), (left, bottom), (right, bottom), (right, top)])
21+
seg.current_class = "c"
22+
top = 40
23+
seg._onselect([(left, top), (left, bottom), (right, bottom), (right, top)])
24+
25+
assert seg.mask.sum() == 15170
26+
27+
with pytest.raises(
28+
ValueError, match=re.escape("d is not one of the classes: ['a', 'b', 'c']")
29+
):
30+
seg.current_class = "d"
31+
assert seg._cur_class_idx == 3
32+
33+
seg.current_class = 1
34+
with pytest.raises(
35+
ValueError,
36+
match=re.escape(
37+
"Current class must be bewteen 1 and 3."
38+
" It cannot be 0 as 0 is the background."
39+
),
40+
):
41+
seg.current_class = 5

0 commit comments

Comments
 (0)