|
1 | | -from typing import Callable, Optional |
2 | | -from io import BytesIO |
3 | 1 | from abc import ABC |
| 2 | +from io import BytesIO |
| 3 | +from typing import Callable, Optional |
4 | 4 |
|
5 | 5 | import numpy as np |
6 | | -from pydantic import BaseModel |
7 | 6 | import requests |
| 7 | +from PIL import Image |
8 | 8 | from google.api_core import retry |
9 | | -from typing_extensions import Literal |
| 9 | +from pydantic import BaseModel |
10 | 10 | from pydantic import root_validator |
11 | | -from PIL import Image |
| 11 | +from typing_extensions import Literal |
12 | 12 |
|
13 | 13 | from .base_data import BaseData |
14 | 14 | from ..types import TypedArray |
15 | 15 |
|
16 | 16 |
|
17 | 17 | class RasterData(BaseModel, ABC): |
18 | 18 | """Represents an image or segmentation mask. |
19 | | -
|
20 | 19 | """ |
21 | 20 | im_bytes: Optional[bytes] = None |
22 | 21 | file_path: Optional[str] = None |
23 | 22 | url: Optional[str] = None |
24 | 23 | arr: Optional[TypedArray[Literal['uint8']]] = None |
25 | 24 |
|
26 | | - |
27 | 25 | @classmethod |
28 | | - def from_2D_arr(cls, arr: TypedArray[Literal['uint8']], **kwargs): |
29 | | - """Construct |
| 26 | + def from_2D_arr(cls, arr: Union[TypedArray[Literal['uint8']], TypedArray[Literal['int']]], **kwargs): |
| 27 | + """Construct from a 2D numpy array |
30 | 28 |
|
31 | 29 | Args: |
32 | | - arr: |
33 | | - **kwargs: |
| 30 | + arr: uint8 compatible numpy array |
34 | 31 |
|
35 | 32 | Returns: |
36 | | -
|
| 33 | + RasterData |
37 | 34 | """ |
38 | 35 |
|
39 | 36 | if len(arr.shape) != 2: |
40 | 37 | raise ValueError( |
41 | 38 | f"Found array with shape {arr.shape}. Expected two dimensions [H, W]" |
42 | 39 | ) |
| 40 | + |
| 41 | + if not np.issubdtype(arr.dtype, np.integer): |
| 42 | + raise ValueError("Array must be an integer subtype") |
| 43 | + |
| 44 | + if np.can_cast(arr, np.uint8): |
| 45 | + arr = arr.astype(np.uint8) |
| 46 | + else: |
| 47 | + raise ValueError("Could not cast array to uint8, check that values are between 0 and 255") |
| 48 | + |
43 | 49 | arr = np.stack((arr,) * 3, axis=-1) |
44 | 50 | return cls(arr=arr, **kwargs) |
45 | 51 |
|
@@ -164,10 +170,10 @@ def validate_args(cls, values): |
164 | 170 |
|
165 | 171 | def __repr__(self) -> str: |
166 | 172 | symbol_or_none = lambda data: '...' if data is not None else None |
167 | | - return f"{self.__class__.__name__}(im_bytes={symbol_or_none(self.im_bytes)}," \ |
168 | | - f"file_path={self.file_path}," \ |
169 | | - f"url={self.url}," \ |
170 | | - f"arr={symbol_or_none(self.arr)})" |
| 173 | + return f"{self.__class__.__name__}(im_bytes={symbol_or_none(self.im_bytes)}," \ |
| 174 | + f"file_path={self.file_path}," \ |
| 175 | + f"url={self.url}," \ |
| 176 | + f"arr={symbol_or_none(self.arr)})" |
171 | 177 |
|
172 | 178 | class Config: |
173 | 179 | # Required for sharing references |
@@ -198,6 +204,5 @@ class MaskData(RasterData): |
198 | 204 | """ |
199 | 205 |
|
200 | 206 |
|
201 | | - |
202 | 207 | class ImageData(RasterData, BaseData): |
203 | 208 | ... |
0 commit comments