|
| 1 | +from ast import Bytes |
| 2 | +from io import BytesIO |
1 | 3 | from typing import Any, Dict, List, Tuple, Union |
| 4 | +import base64 |
| 5 | +import numpy as np |
2 | 6 |
|
3 | 7 | from pydantic import BaseModel |
| 8 | +from PIL import Image |
4 | 9 |
|
5 | 10 | from ...annotation_types.data import ImageData, TextData, MaskData |
6 | 11 | from ...annotation_types.ner import TextEntity |
@@ -113,28 +118,48 @@ def from_common(cls, rectangle: Rectangle, |
113 | 118 | classifications=classifications) |
114 | 119 |
|
115 | 120 |
|
116 | | -class _Mask(BaseModel): |
| 121 | +class _URIMask(BaseModel): |
117 | 122 | instanceURI: str |
118 | 123 | colorRGB: Tuple[int, int, int] |
119 | 124 |
|
120 | 125 |
|
| 126 | +class _PNGMask(BaseModel): |
| 127 | + png: str |
| 128 | + |
| 129 | + |
121 | 130 | class NDMask(NDBaseObject): |
122 | | - mask: _Mask |
| 131 | + mask: Union[_URIMask, _PNGMask] |
123 | 132 |
|
124 | 133 | def to_common(self) -> Mask: |
125 | | - return Mask(mask=MaskData(url=self.mask.instanceURI), |
126 | | - color=self.mask.colorRGB) |
| 134 | + if isinstance(self.mask, _URIMask): |
| 135 | + return Mask(mask=MaskData(url=self.mask.instanceURI), |
| 136 | + color=self.mask.colorRGB) |
| 137 | + else: |
| 138 | + encoded_image_bytes = self.mask.png.encode('utf-8') |
| 139 | + image_bytes = base64.b64decode(encoded_image_bytes) |
| 140 | + image = np.array(Image.open(BytesIO(image_bytes))) |
| 141 | + if np.max(image) > 1: |
| 142 | + raise ValueError( |
| 143 | + f"Expected binary mask. Found max value of {np.max(image)}") |
| 144 | + # Color is 1,1,1 because it is a binary array and we are just stacking it into 3 channels |
| 145 | + return Mask(mask=MaskData.from_2D_arr(image), color=(1, 1, 1)) |
127 | 146 |
|
128 | 147 | @classmethod |
129 | 148 | def from_common(cls, mask: Mask, |
130 | 149 | classifications: List[ClassificationAnnotation], |
131 | 150 | feature_schema_id: Cuid, extra: Dict[str, Any], |
132 | 151 | data: Union[ImageData, TextData]) -> "NDMask": |
133 | | - if mask.mask.url is None: |
134 | | - raise ValueError( |
135 | | - "Mask does not have a url. Use `LabelGenerator.add_url_to_masks`, `LabelList.add_url_to_masks`, or `Label.add_url_to_masks`." |
136 | | - ) |
137 | | - return cls(mask=_Mask(instanceURI=mask.mask.url, colorRGB=mask.color), |
| 152 | + |
| 153 | + if mask.mask.url is not None: |
| 154 | + lbv1_mask = _URIMask(instanceURI=mask.mask.url, colorRGB=mask.color) |
| 155 | + else: |
| 156 | + binary = np.all(mask.mask.value == mask.color, axis=-1) |
| 157 | + im_bytes = BytesIO() |
| 158 | + Image.fromarray(binary, 'L').save(im_bytes, format="PNG") |
| 159 | + lbv1_mask = _PNGMask( |
| 160 | + png=base64.b64encode(im_bytes.getvalue()).decode('utf-8')) |
| 161 | + |
| 162 | + return cls(mask=lbv1_mask, |
138 | 163 | dataRow=DataRow(id=data.uid), |
139 | 164 | schema_id=feature_schema_id, |
140 | 165 | uuid=extra.get('uuid'), |
|
0 commit comments