Skip to content

Commit 029cc49

Browse files
author
Matt Sokoloff
committed
update tests
1 parent 55360e1 commit 029cc49

File tree

15 files changed

+204
-46
lines changed

15 files changed

+204
-46
lines changed

labelbox/data/annotation_types/geometry/rectangle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def geometry(self) -> geojson.geometry.Geometry:
2727
]])
2828

2929
def raster(self, height: int, width: int,
30-
color: int = (255, 255, 255)) -> np.ndarray:
30+
color = (255, 255, 255)) -> np.ndarray:
3131
"""
3232
Draw the rectangle onto a 3d mask
3333

labelbox/data/ontology.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
from typing import Dict, List, Tuple
2+
3+
from labelbox.schema import ontology
4+
from .annotation_types import (
5+
Text,
6+
Dropdown,
7+
Checklist,
8+
Radio,
9+
ClassificationAnnotation,
10+
ObjectAnnotation,
11+
Mask,
12+
Point,
13+
Line,
14+
Polygon,
15+
Rectangle,
16+
TextEntity
17+
)
18+
19+
20+
def get_feature_schema_lookup(
21+
ontology_builder: ontology.OntologyBuilder
22+
) -> Tuple[Dict[str, str], Dict[str, str]]:
23+
tool_lookup = {}
24+
classification_lookup = {}
25+
26+
def flatten_classification(classifications):
27+
for classification in classifications:
28+
if isinstance(classification, ontology.Classification):
29+
classification_lookup[
30+
classification.
31+
instructions] = classification.feature_schema_id
32+
elif isinstance(classification, ontology.Option):
33+
classification_lookup[
34+
classification.value] = classification.feature_schema_id
35+
else:
36+
raise TypeError(
37+
f"Unexpected type found in ontology. `{type(classification)}`"
38+
)
39+
flatten_classification(classification.options)
40+
41+
for tool in ontology_builder.tools:
42+
tool_lookup[tool.name] = tool.feature_schema_id
43+
flatten_classification(tool.classifications)
44+
flatten_classification(ontology_builder.classifications)
45+
return tool_lookup, classification_lookup
46+
47+
48+
def _get_options(annotation: ClassificationAnnotation,
49+
existing_options: List[ontology.Option]):
50+
if isinstance(annotation.value, Radio):
51+
answers = [annotation.value.answer]
52+
elif isinstance(annotation.value, Text):
53+
return existing_options
54+
elif isinstance(annotation.value, (Checklist, Dropdown)):
55+
answers = annotation.value.answer
56+
else:
57+
raise TypeError(
58+
f"Expected one of Radio, Text, Checklist, Dropdown. Found {type(annotation.value)}"
59+
)
60+
61+
option_names = {option.value for option in existing_options}
62+
for answer in answers:
63+
if answer.name not in option_names:
64+
existing_options.append(ontology.Option(value=answer.name))
65+
option_names.add(answer.name)
66+
return existing_options
67+
68+
69+
def get_classifications(
70+
annotations: List[ClassificationAnnotation],
71+
existing_classifications: List[ontology.Classification]
72+
) -> List[ontology.Classification]:
73+
existing_classifications = {
74+
classification.instructions: classification
75+
for classification in existing_classifications
76+
}
77+
for annotation in annotations:
78+
# If the classification exists then we just want to add options to it
79+
classification_feature = existing_classifications.get(annotation.name)
80+
if classification_feature:
81+
classification_feature.options = _get_options(
82+
annotation, classification_feature.options)
83+
elif annotation.name not in existing_classifications:
84+
existing_classifications[annotation.name] = ontology.Classification(
85+
class_type=classification_mapping(annotation),
86+
instructions=annotation.name,
87+
options=_get_options(annotation, []))
88+
return list(existing_classifications.values())
89+
90+
91+
def get_tools(annotations: List[ObjectAnnotation],
92+
existing_tools: List[ontology.Classification]):
93+
existing_tools = {tool.name: tool for tool in existing_tools}
94+
for annotation in annotations:
95+
if annotation.name in existing_tools:
96+
# We just want to update classifications
97+
existing_tools[
98+
annotation.name].classifications = get_classifications(
99+
annotation.classifications,
100+
existing_tools[annotation.name].classifications)
101+
else:
102+
existing_tools[annotation.name] = ontology.Tool(
103+
tool=tool_mapping(annotation),
104+
name=annotation.name,
105+
classifications=get_classifications(annotation.classifications,
106+
[]))
107+
return list(existing_tools.values())
108+
109+
110+
def tool_mapping(annotation):
111+
tool_types = ontology.Tool.Type
112+
mapping = {
113+
Mask: tool_types.SEGMENTATION,
114+
Polygon: tool_types.POLYGON,
115+
Point: tool_types.POINT,
116+
Rectangle: tool_types.BBOX,
117+
Line: tool_types.LINE,
118+
TextEntity: tool_types.NER,
119+
}
120+
result = mapping.get(type(annotation.value))
121+
if result is None:
122+
raise TypeError(
123+
f"Unexpected type found. {type(annotation.value)}. Expected one of {list(mapping.keys())}"
124+
)
125+
return result
126+
127+
128+
def classification_mapping(annotation):
129+
classification_types = ontology.Classification.Type
130+
mapping = {
131+
Text: classification_types.TEXT,
132+
Checklist: classification_types.CHECKLIST,
133+
Radio: classification_types.RADIO,
134+
Dropdown: classification_types.DROPDOWN
135+
}
136+
result = mapping.get(type(annotation.value))
137+
if result is None:
138+
raise TypeError(
139+
f"Unexpected type found. {type(annotation.value)}. Expected one of {list(mapping.keys())}"
140+
)
141+
return result

tests/data/annotation_types/classification/test_classification.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
import pytest
2-
from labelbox.data.annotation_types.annotation import ClassificationAnnotation
3-
from labelbox.data.annotation_types.classification import (Checklist,
4-
ClassificationAnswer,
5-
Dropdown, Radio,
6-
Text)
72
from pydantic import ValidationError
83

4+
from labelbox.data.annotation_types import (
5+
Checklist,
6+
ClassificationAnswer,
7+
Dropdown,
8+
Radio,
9+
Text,
10+
ClassificationAnnotation
11+
)
12+
913

1014
def test_classification_answer():
1115
with pytest.raises(ValidationError):

tests/data/annotation_types/data/test_raster.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33

44
import numpy as np
55
import pytest
6-
from labelbox.data.annotation_types.data.raster import RasterData
76
from PIL import Image
87
from pydantic import ValidationError
98

9+
from labelbox.data.annotation_types.data import RasterData
10+
1011

1112
def test_validate_schema():
1213
with pytest.raises(ValidationError):

tests/data/annotation_types/data/test_text.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
2-
from labelbox.data.annotation_types.data.text import TextData
32
from pydantic import ValidationError
43

4+
from labelbox.data.annotation_types import TextData
55

66
def test_validate_schema():
77
with pytest.raises(ValidationError):

tests/data/annotation_types/data/test_video.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import numpy as np
22
import pytest
3-
from labelbox.data.annotation_types.data.video import VideoData
43
from pydantic import ValidationError
54

5+
from labelbox.data.annotation_types import VideoData
6+
67

78
def test_validate_schema():
89
with pytest.raises(ValidationError):

tests/data/annotation_types/geometry/test_line.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
import pytest
33
import cv2
44

5-
from labelbox.data.annotation_types.geometry import Point
6-
from labelbox.data.annotation_types.geometry import Line
5+
from labelbox.data.annotation_types.geometry import Point, Line
76

87

98
def test_line():
@@ -21,4 +20,4 @@ def test_line():
2120
assert line.shapely.__geo_interface__ == expected
2221

2322
raster = line.raster(height=32, width=32)
24-
assert (cv2.imread("tests/data/assets/line.png")[:, :, 0] == raster).all()
23+
assert (cv2.imread("tests/data/assets/line.png") == raster).all()

tests/data/annotation_types/geometry/test_mask.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
import numpy as np
55
import cv2
66

7-
from labelbox.data.annotation_types.geometry import Point, Rectangle, Mask
8-
from labelbox.data.annotation_types.data.raster import RasterData
7+
from labelbox.data.annotation_types import RasterData, Point, Rectangle, Mask
98

109

1110
def test_mask():
@@ -50,10 +49,10 @@ def test_mask():
5049
gt1 = Rectangle(start=Point(x=0, y=0),
5150
end=Point(x=10, y=10)).raster(height=raster1.shape[0],
5251
width=raster1.shape[1],
53-
color=1)
52+
color=(255,255,255))
5453
gt2 = Rectangle(start=Point(x=20, y=20),
5554
end=Point(x=30, y=30)).raster(height=raster2.shape[0],
5655
width=raster2.shape[1],
57-
color=1)
58-
assert ((np.max(raster1, axis=-1) > 0) == gt1).all()
59-
assert ((np.max(raster2, axis=-1) > 0) == gt2).all()
56+
color=(0,255,255))
57+
assert (raster1 == gt1).all()
58+
assert (raster2 == gt2).all()

tests/data/annotation_types/geometry/test_point.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import pytest
33
import cv2
44

5-
from labelbox.data.annotation_types.geometry import Point
5+
from labelbox.data.annotation_types import Point
66

77

88
def test_point():
@@ -19,4 +19,4 @@ def test_point():
1919
assert point.shapely.__geo_interface__ == expected
2020

2121
raster = point.raster(height=32, width=32)
22-
assert (cv2.imread("tests/data/assets/point.png")[:, :, 0] == raster).all()
22+
assert (cv2.imread("tests/data/assets/point.png") == raster).all()

tests/data/annotation_types/geometry/test_polygon.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
import pytest
33
import cv2
44

5-
from labelbox.data.annotation_types.geometry import Point
6-
from labelbox.data.annotation_types.geometry import Polygon
5+
from labelbox.data.annotation_types import Polygon, Point
76

87

98
def test_polygon():
@@ -25,5 +24,4 @@ def test_polygon():
2524
assert polygon.shapely.__geo_interface__ == expected
2625

2726
raster = polygon.raster(10, 10)
28-
assert (cv2.imread("tests/data/assets/polygon.png")[:, :,
29-
0] == raster).all()
27+
assert (cv2.imread("tests/data/assets/polygon.png") == raster).all()

0 commit comments

Comments
 (0)