Skip to content

Commit c71bc1b

Browse files
author
Matt Sokoloff
committed
created basics notebook and related bug fixes
1 parent 00b4727 commit c71bc1b

File tree

18 files changed

+1284
-57
lines changed

18 files changed

+1284
-57
lines changed

examples/annotation_types/annotation_type_basics.ipynb

Lines changed: 1060 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from .geometry import Line
2+
from .geometry import Point
3+
from .geometry import Mask
4+
from .geometry import Polygon
5+
from .geometry import Rectangle
6+
from .geometry import Geometry
7+
8+
from .annotation import ClassificationAnnotation
9+
from .annotation import VideoClassificationAnnotation
10+
from .annotation import ObjectAnnotation
11+
from .annotation import VideoObjectAnnotation
12+
13+
from .ner import TextEntity
14+
15+
from .classification import Checklist
16+
from .classification import ClassificationAnswer
17+
from .classification import Dropdown
18+
from .classification import Radio
19+
from .classification import Text
20+
21+
from .data import RasterData
22+
from .data import TextData
23+
from .data import VideoData
24+
25+
from .label import Label
26+
27+
from .collection import LabelCollection
28+
from .collection import LabelGenerator

labelbox/data/annotation_types/data/raster.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33

44
import numpy as np
55
import requests
6+
from typing_extensions import Literal
67
from pydantic import root_validator
78
from PIL import Image
89

910
from .base_data import BaseData
11+
from ..types import TypedArray
1012

1113

1214
class RasterData(BaseData):
@@ -16,7 +18,7 @@ class RasterData(BaseData):
1618
im_bytes: Optional[bytes] = None
1719
file_path: Optional[str] = None
1820
url: Optional[str] = None
19-
arr: Optional[np.ndarray] = None
21+
arr: Optional[TypedArray[Literal['uint8']]] = None
2022

2123
def bytes_to_np(self, image_bytes: bytes) -> np.ndarray:
2224
"""
@@ -64,14 +66,23 @@ def data(self) -> np.ndarray:
6466
self.im_bytes = im_bytes
6567
return self.bytes_to_np(im_bytes)
6668
elif self.url is not None:
67-
response = requests.get(self.url)
68-
response.raise_for_status()
69-
im_bytes = response.content
69+
im_bytes = self.fetch_remote()
7070
self.im_bytes = im_bytes
7171
return self.bytes_to_np(im_bytes)
7272
else:
7373
raise ValueError("Must set either url, file_path or im_bytes")
7474

75+
def fetch_remote(self) -> bytes:
76+
"""
77+
Method for accessing url.
78+
79+
If url is not publicly accessible or requires another access pattern
80+
simply override this function
81+
"""
82+
response = requests.get(self.url)
83+
response.raise_for_status()
84+
return response.content
85+
7586
def create_url(self, signer: Callable[[bytes], str]) -> str:
7687
"""
7788
Utility for creating a url from any of the other image representations.
@@ -95,7 +106,7 @@ def create_url(self, signer: Callable[[bytes], str]) -> str:
95106
"One of url, im_bytes, file_path, arr must not be None.")
96107
return self.url
97108

98-
@root_validator
109+
@root_validator()
99110
def validate_args(cls, values):
100111
file_path = values.get("file_path")
101112
im_bytes = values.get("im_bytes")
@@ -118,8 +129,6 @@ def validate_args(cls, values):
118129
return values
119130

120131
class Config:
121-
# Required for numpy arrays
122-
arbitrary_types_allowed = True
123132
# Required for sharing references
124133
copy_on_model_validation = False
125134
# Required for discriminating between data types

labelbox/data/annotation_types/data/text.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,23 @@ def data(self) -> str:
3030
self.text = text
3131
return text
3232
elif self.url:
33-
response = requests.get(self.url)
34-
response.raise_for_status()
35-
text = response.text
33+
text = self.fetch_remote()
3634
self.text = text
3735
return text
3836
else:
3937
raise ValueError("Must set either url, file_path or im_bytes")
4038

39+
def fetch_remote(self) -> str:
40+
"""
41+
Method for accessing url.
42+
43+
If url is not publicly accessible or requires another access pattern
44+
simply override this function
45+
"""
46+
response = requests.get(self.url)
47+
response.raise_for_status()
48+
return response.text
49+
4150
def create_url(self, signer: Callable[[bytes], str]) -> None:
4251
"""
4352
Utility for creating a url from any of the other text references.

labelbox/data/annotation_types/data/video.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22
import os
33
import urllib.request
44
from typing import Callable, Dict, Generator, Optional, Tuple
5+
from typing_extensions import Literal
56
from uuid import uuid4
67

78
import cv2
89
import numpy as np
910
from pydantic import root_validator
1011

1112
from .base_data import BaseData
13+
from ..types import TypedArray
1214

1315
logger = logging.getLogger(__name__)
1416

@@ -19,7 +21,7 @@ class VideoData(BaseData):
1921
"""
2022
file_path: Optional[str] = None
2123
url: Optional[str] = None
22-
frames: Optional[Dict[int, np.ndarray]] = None
24+
frames: Optional[Dict[int, TypedArray[Literal['uint8']]]] = None
2325

2426
def load_frames(self, overwrite: bool = False) -> None:
2527
"""
@@ -33,8 +35,14 @@ def load_frames(self, overwrite: bool = False) -> None:
3335
return
3436

3537
for count, frame in self.frame_generator():
38+
if self.frames is None:
39+
self.frames = {}
3640
self.frames[count] = frame
3741

42+
@property
43+
def data(self):
44+
return self.frame_generator()
45+
3846
def frame_generator(
3947
self,
4048
cache_frames=False,
@@ -48,26 +56,27 @@ def frame_generator(
4856
download_dir (str): Directory to save the video to. Defaults to `/tmp` dir
4957
"""
5058
if self.frames is not None:
51-
for idx, img in self.frames.items():
52-
yield idx, img
59+
for idx, frame in self.frames.items():
60+
yield idx, frame
5361
return
5462
elif self.url and not self.file_path:
5563
file_path = os.path.join(download_dir, f"{uuid4()}.mp4")
5664
logger.info("Downloading the video locally to %s", file_path)
57-
urllib.request.urlretrieve(self.url, file_path)
65+
self.fetch_remote(file_path)
5866
self.file_path = file_path
5967

6068
vidcap = cv2.VideoCapture(self.file_path)
6169

62-
success, img = vidcap.read()
70+
success, frame = vidcap.read()
6371
count = 0
64-
self.frames = {}
72+
if cache_frames:
73+
self.frames = {}
6574
while success:
66-
img = img[:, :, ::-1]
67-
yield count, img
75+
frame = frame[:, :, ::-1]
76+
yield count, frame
6877
if cache_frames:
69-
self.frames[count] = img
70-
success, img = vidcap.read()
78+
self.frames[count] = frame
79+
success, frame = vidcap.read()
7180
count += 1
7281

7382
def __getitem__(self, idx: int) -> np.ndarray:
@@ -77,6 +86,18 @@ def __getitem__(self, idx: int) -> np.ndarray:
7786
)
7887
return self.frames[idx]
7988

89+
def fetch_remote(self, local_path) -> None:
90+
"""
91+
Method for downloading data from self.url
92+
93+
If url is not publicly accessible or requires another access pattern
94+
simply override this function
95+
96+
Args:
97+
local_path: Where to save the thing too.
98+
"""
99+
urllib.request.urlretrieve(self.url, local_path)
100+
80101
def create_url(self, signer: Callable[[bytes], str]) -> None:
81102
"""
82103
Utility for creating a url from any of the other video references.
@@ -134,7 +155,5 @@ def validate_data(cls, values):
134155
return values
135156

136157
class Config:
137-
# Required for numpy arrays
138-
arbitrary_types_allowed = True
139158
# Required for discriminating between data types
140159
extra = 'forbid'

labelbox/data/annotation_types/feature.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,5 @@ class FeatureSchema(BaseModel):
2121
def must_set_one(cls, values):
2222
if values['schema_id'] is None and values['name'] is None:
2323
raise ValueError(
24-
"Must set either schema_id or display name for all feature schemas"
25-
)
24+
"Must set either schema_id or name for all feature schemas")
2625
return values

labelbox/data/annotation_types/geometry/mask.py

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from typing import Callable, Tuple
1+
from typing import Callable, Tuple, Union
22

33
import numpy as np
4+
from pydantic.class_validators import validator
45
from rasterio.features import shapes
56
from shapely.geometry import MultiPolygon, shape
67

@@ -11,21 +12,39 @@
1112
class Mask(Geometry):
1213
# Raster data can be shared across multiple masks... or not
1314
mask: RasterData
14-
color_rgb: Tuple[int, int, int]
15+
# RGB or Grayscale
16+
color: Union[int, Tuple[int, int, int]]
1517

1618
@property
1719
def geometry(self):
18-
mask = self.mask.data
19-
mask = np.alltrue(mask == self.color_rgb, axis=2).astype(np.uint8)
20+
mask = self.raster(binary=True)
2021
polygons = (
2122
shape(shp)
2223
for shp, val in shapes(mask, mask=None)
2324
# ignore if shape is area of smaller than 1 pixel
2425
if val >= 1)
2526
return MultiPolygon(polygons).__geo_interface__
2627

27-
def raster(self) -> np.ndarray:
28-
return self.mask.data
28+
def raster(self, binary=False) -> np.ndarray:
29+
"""
30+
Removes all pixels from the segmentation mask that do not equal self.color
31+
32+
Returns:
33+
np.ndarray representing only this object
34+
"""
35+
mask = self.mask.data
36+
if len(mask.shape) == 2:
37+
mask = np.expand_dims(mask, axis=-1)
38+
mask = np.alltrue(mask == self.color, axis=2).astype(np.uint8)
39+
if binary:
40+
return mask
41+
elif isinstance(self.color, int):
42+
return mask * self.color
43+
else:
44+
color_image = np.zeros((mask.shape[0], mask.shape[1], 3),
45+
dtype=np.uint8)
46+
color_image[mask.astype(np.bool)] = self.color
47+
return color_image
2948

3049
def create_url(self, signer: Callable[[bytes], str]) -> str:
3150
"""
@@ -38,3 +57,21 @@ def create_url(self, signer: Callable[[bytes], str]) -> str:
3857
the url for the mask
3958
"""
4059
return self.mask.create_url(signer)
60+
61+
@validator('color')
62+
def is_valid_color(cls, color):
63+
#Does the dtype matter? Can it be a float?
64+
if isinstance(color, (tuple, list)):
65+
if len(color) != 3:
66+
raise ValueError(
67+
"Segmentation colors must be either a (r,g,b) tuple or a single grayscale value"
68+
)
69+
elif not all([0 <= c <= 255 for c in color]):
70+
raise ValueError(
71+
f"All rgb colors must be between 0 and 255. Found : {color}"
72+
)
73+
elif not (0 <= color <= 255):
74+
raise ValueError(
75+
f"All rgb colors must be between 0 and 255. Found : {color}")
76+
77+
return color

labelbox/data/annotation_types/geometry/rectangle.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
class Rectangle(Geometry):
1010
"""
1111
Represents a 2d rectangle. Also known as a bounding box.
12+
13+
start: Top left coordinate of the rectangle
14+
end: Bottom right coordinate of the rectangle
1215
"""
1316
start: Point
1417
end: Point

labelbox/data/annotation_types/label.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from .annotation import (ClassificationAnnotation, ObjectAnnotation,
1212
VideoClassificationAnnotation, VideoObjectAnnotation)
1313

14+
from pydantic import validator
15+
1416

1517
class Label(BaseModel):
1618
data: Union[VideoData, RasterData, TextData]
@@ -65,7 +67,7 @@ def create_data_row(self, dataset: "Entity.Dataset",
6567
Returns:
6668
Label with updated references to new data row
6769
"""
68-
args = {'row_data': self.add_url_to_data(signer)}
70+
args = {'row_data': self.data.create_url(signer)}
6971
if self.data.external_id is not None:
7072
args.update({'external_id': self.data.external_id})
7173

@@ -151,3 +153,19 @@ def _assign_option(self, classification: ClassificationAnnotation,
151153
raise TypeError(
152154
f"Unexpected type for answer found. {type(classification.value.answer)}"
153155
)
156+
157+
@validator("annotations", pre=True)
158+
def validate_union(cls, value):
159+
supported = tuple([
160+
field.type_
161+
for field in cls.__fields__['annotations'].sub_fields[0].sub_fields
162+
])
163+
if not isinstance(value, list):
164+
raise TypeError(f"Annotations must be a list. Found {type(value)}")
165+
166+
for v in value:
167+
if not isinstance(v, supported):
168+
raise TypeError(
169+
f"Annotations should be a list containing the following classes : {supported}. Found {type(v)}"
170+
)
171+
return value
Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,29 @@
1-
from pydantic import Field
1+
from typing import Generic, TypeVar
22
from typing_extensions import Annotated
33

4+
from pydantic import Field
5+
from pydantic.fields import ModelField
6+
import numpy as np
7+
48
Cuid = Annotated[str, Field(min_length=25, max_length=25)]
9+
10+
DType = TypeVar('DType')
11+
12+
13+
class TypedArray(np.ndarray, Generic[DType]):
14+
15+
@classmethod
16+
def __get_validators__(cls):
17+
yield cls.validate
18+
19+
@classmethod
20+
def validate(cls, val, field: ModelField):
21+
if not isinstance(val, np.ndarray):
22+
raise TypeError(f"Expected numpy array. Found {type(val)}")
23+
dtype_field = field.sub_fields[0]
24+
actual_dtype = dtype_field.type_.__args__[0]
25+
if val.dtype != actual_dtype:
26+
raise TypeError(
27+
f"Expected numpy array have type {actual_dtype}. Found {val.dtype}"
28+
)
29+
return val

0 commit comments

Comments
 (0)