1- from typing import Callable , Optional
2- from io import BytesIO
31from abc import ABC
4-
2+ from io import BytesIO
3+ from typing import Callable , Optional , Union
4+ from typing_extensions import Literal
55import numpy as np
6- from pydantic import BaseModel
76import requests
7+ from PIL import Image
88from google .api_core import retry
9- from typing_extensions import Literal
9+ from pydantic import BaseModel
1010from pydantic import root_validator
11- from PIL import Image
1211
1312from .base_data import BaseData
1413from ..types import TypedArray
1514
1615
1716class RasterData (BaseModel , ABC ):
18- """
19- Represents an image or segmentation mask.
17+ """Represents an image or segmentation mask.
2018 """
2119 im_bytes : Optional [bytes ] = None
2220 file_path : Optional [str ] = None
2321 url : Optional [str ] = None
2422 arr : Optional [TypedArray [Literal ['uint8' ]]] = None
2523
2624 @classmethod
27- def from_2D_arr (cls , arr : TypedArray [Literal ['uint8' ]], ** kwargs ):
25+ def from_2D_arr (cls , arr : Union [TypedArray [Literal ['uint8' ]],
26+ TypedArray [Literal ['int' ]]], ** kwargs ):
27+ """Construct from a 2D numpy array
28+
29+ Args:
30+ arr: uint8 compatible numpy array
31+
32+ Returns:
33+ RasterData
34+ """
35+
2836 if len (arr .shape ) != 2 :
2937 raise ValueError (
30- f"Found array with shape { arr .shape } . Expected two dimensions ([W,H])"
38+ f"Found array with shape { arr .shape } . Expected two dimensions [H, W]"
39+ )
40+
41+ if not np .issubdtype (arr .dtype , np .integer ):
42+ raise ValueError ("Array must be an integer subtype" )
43+
44+ if np .can_cast (arr , np .uint8 ):
45+ arr = arr .astype (np .uint8 )
46+ else :
47+ raise ValueError (
48+ "Could not cast array to uint8, check that values are between 0 and 255"
3149 )
50+
3251 arr = np .stack ((arr ,) * 3 , axis = - 1 )
3352 return cls (arr = arr , ** kwargs )
3453
@@ -153,10 +172,10 @@ def validate_args(cls, values):
153172
154173 def __repr__ (self ) -> str :
155174 symbol_or_none = lambda data : '...' if data is not None else None
156- return f"{ self .__class__ .__name__ } (im_bytes={ symbol_or_none (self .im_bytes )} ," \
157- f"file_path={ self .file_path } ," \
158- f"url={ self .url } ," \
159- f"arr={ symbol_or_none (self .arr )} )"
175+ return f"{ self .__class__ .__name__ } (im_bytes={ symbol_or_none (self .im_bytes )} ," \
176+ f"file_path={ self .file_path } ," \
177+ f"url={ self .url } ," \
178+ f"arr={ symbol_or_none (self .arr )} )"
160179
161180 class Config :
162181 # Required for sharing references
@@ -166,7 +185,25 @@ class Config:
166185
167186
168187class MaskData (RasterData ):
169- ...
188+ """Used to represent a segmentation Mask
189+
190+ All segments within a mask must be mutually exclusive. At a
191+ single cell, only one class can be present. All Mask data is
192+ converted to a [H,W,3] image. Classes are
193+
194+ >>> # 3x3 mask with two classes and back ground
195+ >>> MaskData.from_2D_arr([
196+ >>> [0, 0, 0],
197+ >>> [1, 1, 1],
198+ >>> [2, 2, 2],
199+ >>>])
200+
201+ Args:
202+ im_bytes: Optional[bytes] = None
203+ file_path: Optional[str] = None
204+ url: Optional[str] = None
205+ arr: Optional[TypedArray[Literal['uint8']]] = None
206+ """
170207
171208
172209class ImageData (RasterData , BaseData ):
0 commit comments