11from __future__ import annotations
22
3+ from numbers import Integral
34from typing import TYPE_CHECKING
45
56import numpy as np
@@ -21,7 +22,7 @@ class ImageSegmenter:
2122 def __init__ ( # type: ignore
2223 self ,
2324 img ,
24- nclasses = 1 ,
25+ classes = 1 ,
2526 mask = None ,
2627 mask_colors = None ,
2728 mask_alpha = 0.75 ,
@@ -39,8 +40,8 @@ def __init__( # type: ignore
3940 ----------
4041 img : array_like
4142 A valid argument to imshow
42- nclasses : int, default 1
43- How many classes to have in the mask.
43+ classes : int, iterable[string] , default 1
44+ If a number How many classes to have in the mask.
4445 mask : arraylike, optional
4546 If you want to pre-seed the mask
4647 mask_colors : None, color, or array of colors, optional
@@ -69,13 +70,20 @@ def __init__( # type: ignore
6970
7071 self .mask_alpha = mask_alpha
7172
73+ if isinstance (classes , Integral ):
74+ self ._classes : list [str | int ] = list (range (classes ))
75+ else :
76+ self ._classes = classes
77+ self ._n_classes = len (self ._classes )
7278 if mask_colors is None :
73- # this will break if there are more than 10 classes
74- if nclasses <= 10 :
75- self .mask_colors = to_rgba_array (list (TABLEAU_COLORS )[:nclasses ])
79+ if self ._n_classes <= 10 :
80+ # There are only 10 tableau colors
81+ self .mask_colors = to_rgba_array (
82+ list (TABLEAU_COLORS )[: self ._n_classes ]
83+ )
7684 else :
7785 # up to 949 classes. Hopefully that is always enough....
78- self .mask_colors = to_rgba_array (list (XKCD_COLORS )[:nclasses ])
86+ self .mask_colors = to_rgba_array (list (XKCD_COLORS )[: self . _n_classes ])
7987 else :
8088 self .mask_colors = to_rgba_array (np .atleast_1d (mask_colors ))
8189 # should probably check the shape here
@@ -90,8 +98,7 @@ def __init__( # type: ignore
9098 self .mask = mask
9199
92100 self ._overlay = np .zeros ((* self ._img .shape [:2 ], 4 ))
93- self .nclasses = nclasses
94- for i in range (nclasses + 1 ):
101+ for i in range (self ._n_classes + 1 ):
95102 idx = self .mask == i
96103 if i == 0 :
97104 self ._overlay [idx ] = [0 , 0 , 0 , 0 ]
@@ -160,6 +167,26 @@ def erasing(self, val: bool) -> None:
160167 raise TypeError (f"Erasing must be a bool - got type { type (val )} " )
161168 self ._erasing = val
162169
170+ @property
171+ def current_class (self ) -> int | str :
172+ return self ._classes [self ._cur_class_idx - 1 ]
173+
174+ @current_class .setter
175+ def current_class (self , val : int | str ) -> None :
176+ if isinstance (val , str ):
177+ if val not in self ._classes :
178+ raise ValueError (f"{ val } is not one of the classes: { self ._classes } " )
179+ # offset by one for the background
180+ self ._cur_class_idx = self ._classes .index (val ) + 1
181+ elif isinstance (val , Integral ):
182+ if 0 < val < self ._n_classes + 1 :
183+ self ._cur_class_idx = val
184+ else :
185+ raise ValueError (
186+ f"Current class must be bewteen 1 and { self ._n_classes } ."
187+ " It cannot be 0 as 0 is the background."
188+ )
189+
163190 def get_paths (self ) -> dict [str , list [Path ]]:
164191 """
165192 Get a dictionary of all the paths used to create the mask.
@@ -179,8 +206,8 @@ def _onselect(self, verts: Any) -> None:
179206 self ._overlay [self .indices ] = [0 , 0 , 0 , 0 ]
180207 self ._paths ["erasing" ].append (p )
181208 else :
182- self .mask [self .indices ] = self .current_class
183- self ._overlay [self .indices ] = self .mask_colors [self .current_class - 1 ]
209+ self .mask [self .indices ] = self ._cur_class_idx
210+ self ._overlay [self .indices ] = self .mask_colors [self ._cur_class_idx - 1 ]
184211 self ._paths ["adding" ].append (p )
185212
186213 self ._mask .set_data (self ._overlay )
0 commit comments