@@ -20,6 +20,15 @@ class RasterData(BaseData):
2020 url : Optional [str ] = None
2121 arr : Optional [TypedArray [Literal ['uint8' ]]] = None
2222
23+ @classmethod
24+ def from_2D_arr (cls , arr : TypedArray [Literal ['uint8' ]], ** kwargs ):
25+ if len (arr .shape ):
26+ raise ValueError (
27+ f"Found array with shape { arr .shape } . Expected two dimensions ([W,H])"
28+ )
29+ arr = np .stack ((arr ,) * 3 , axis = - 1 )
30+ return cls (arr = arr , ** kwargs )
31+
2332 def bytes_to_np (self , image_bytes : bytes ) -> np .ndarray :
2433 """
2534 Converts image bytes to a numpy array
@@ -38,9 +47,9 @@ def np_to_bytes(self, arr: np.ndarray) -> bytes:
3847 Returns:
3948 png encoded bytes
4049 """
41- if len (arr .shape ) not in [ 2 , 3 ] :
42- raise ValueError ("unsupported image format" )
43-
50+ if len (arr .shape ) != 3 :
51+ raise ValueError ("unsupported image format. Must be 3D ([H,W,C])."
52+ "Use RasterData.from_2D_arr to construct from 2D" )
4453 if arr .dtype != np .uint8 :
4554 raise TypeError (f"image data type must be uint8. Found { arr .dtype } " )
4655
@@ -72,6 +81,9 @@ def data(self) -> np.ndarray:
7281 else :
7382 raise ValueError ("Must set either url, file_path or im_bytes" )
7483
84+ def set_fetch_fn (self , fn ):
85+ object .__setattr__ (self , 'fetch_remote' , lambda : fn (self ))
86+
7587 def fetch_remote (self ) -> bytes :
7688 """
7789 Method for accessing url.
@@ -122,12 +134,19 @@ def validate_args(cls, values):
122134 raise TypeError (
123135 "Numpy array representing segmentation mask must be np.uint8"
124136 )
125- elif len (arr .shape ) not in [ 2 , 3 ] :
126- raise TypeError (
127- f"Numpy array must have 2 or 3 dims. Found shape { arr . shape } "
128- )
137+ elif len (arr .shape ) != 3 :
138+ raise ValueError (
139+ "unsupported image format. Must be 3D ([H,W,C]). "
140+ "Use RasterData.from_2D_arr to construct from 2D" )
129141 return values
130142
143+ def __repr__ (self ) -> str :
144+ symbol_or_none = lambda data : '...' if data is not None else None
145+ return f"RasterData(im_bytes={ symbol_or_none (self .im_bytes )} ," \
146+ f"file_path={ self .file_path } ," \
147+ f"url={ self .url } ," \
148+ f"arr={ symbol_or_none (self .arr )} )"
149+
131150 class Config :
132151 # Required for sharing references
133152 copy_on_model_validation = False
0 commit comments