Skip to content

Commit 18911a1

Browse files
committed
Add MediaType support to Project
1 parent 90c22ee commit 18911a1

File tree

4 files changed

+101
-8
lines changed

4 files changed

+101
-8
lines changed

labelbox/client.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,20 @@ def create_project(self, **kwargs) -> Project:
611611
InvalidAttributeError: If the Project type does not contain
612612
any of the attribute names given in kwargs.
613613
"""
614+
media_type = kwargs.get("media_type")
615+
if media_type:
616+
if isinstance(media_type, Project.MediaType
617+
) and media_type is not Project.MediaType.Unknown:
618+
kwargs["media_type"] = media_type.value
619+
else:
620+
media_types = [
621+
item for item in Project.MediaType.__members__
622+
if item != "Unknown"
623+
]
624+
raise TypeError(
625+
f"{media_type} is not a supported type. Please use any of {media_types} from the {type(media_type).__name__} enumeration."
626+
)
627+
614628
return self._create(Entity.Project, kwargs)
615629

616630
def get_roles(self) -> List[Role]:

labelbox/orm/db_object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def _set_field_values(self, field_values):
6969
"Failed to convert value '%s' to datetime for "
7070
"field %s", value, field)
7171
elif isinstance(field.field_type, Field.EnumType):
72-
value = field.field_type.enum_cls[value]
72+
value = field.field_type.enum_cls(value)
7373
setattr(self, field.name, value)
7474

7575
def __repr__(self):

labelbox/schema/project.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import enum
1+
from enum import Enum
22
import json
33
import logging
44
import time
@@ -63,6 +63,37 @@ class Project(DbObject, Updateable, Deletable):
6363
benchmarks (Relationship): `ToMany` relationship to Benchmark
6464
ontology (Relationship): `ToOne` relationship to Ontology
6565
"""
66+
67+
class MediaType(Enum):
68+
"""add DOCUMENT, GEOSPATIAL_TILE, SIMPLE_TILE to match the UI choices"""
69+
Audio = "AUDIO"
70+
Conversational = "CONVERSATIONAL"
71+
Dicom = "DICOM"
72+
Document = "PDF"
73+
Geospatial_Tile = "TMS_GEO"
74+
Image = "IMAGE"
75+
Json = "JSON"
76+
Pdf = "PDF"
77+
Simple_Tile = "TMS_SIMPLE"
78+
Text = "TEXT"
79+
Tms_Geo = "TMS_GEO"
80+
Tms_Simple = "TMS_SIMPLE"
81+
Video = "VIDEO"
82+
Unknown = "UNKNOWN"
83+
84+
@classmethod
85+
def _missing_(cls, name):
86+
"""Handle missing null data types for projects
87+
created without setting allowedMediaType"""
88+
# return Project.MediaType.UNKNOWN
89+
90+
if name is None:
91+
return cls.Unknown
92+
93+
for member in cls.__members__:
94+
if member.name == name.upper():
95+
return member
96+
6697
name = Field.String("name")
6798
description = Field.String("description")
6899
updated_at = Field.DateTime("updated_at")
@@ -71,6 +102,8 @@ class Project(DbObject, Updateable, Deletable):
71102
last_activity_time = Field.DateTime("last_activity_time")
72103
auto_audit_number_of_labels = Field.Int("auto_audit_number_of_labels")
73104
auto_audit_percentage = Field.Float("auto_audit_percentage")
105+
# Bind data_type and allowedMediaTYpe using the GraphQL type MediaType
106+
media_type = Field.Enum(MediaType, "media_type", "allowedMediaType")
74107

75108
# Relationships
76109
datasets = Relationship.ToMany("Dataset", True)
@@ -85,7 +118,7 @@ class Project(DbObject, Updateable, Deletable):
85118
benchmarks = Relationship.ToMany("Benchmark", False)
86119
ontology = Relationship.ToOne("Ontology", True)
87120

88-
class QueueMode(enum.Enum):
121+
class QueueMode(Enum):
89122
Batch = "Batch"
90123
Dataset = "Dataset"
91124

@@ -94,6 +127,20 @@ def update(self, **kwargs):
94127
if mode:
95128
self._update_queue_mode(mode)
96129

130+
media_type = kwargs.get("media_type")
131+
if media_type:
132+
if isinstance(media_type, Project.MediaType
133+
) and media_type != Project.MediaType.Unknown:
134+
kwargs["media_type"] = media_type.value
135+
else:
136+
media_types = [
137+
item for item in Project.MediaType.__members__
138+
if item != "Unknown"
139+
]
140+
raise TypeError(
141+
f"{media_type} is not a supported type. Please use any of {media_types} from the {type(media_type).__name__} enumeration."
142+
)
143+
97144
return super().update(**kwargs)
98145

99146
def members(self) -> PaginatedCollection:

tests/integration/test_project.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import json
21
import time
32
import os
43

@@ -44,11 +43,19 @@ def test_project(client, rand_gen):
4443
assert set(final) == set(before)
4544

4645

47-
@pytest.mark.skip(
48-
reason="this will fail if run multiple times, limit is defaulted to 3 per org"
49-
"add this back in when either all test orgs have unlimited, or we delete all tags befoer running"
50-
)
5146
def test_update_project_resource_tags(client, rand_gen):
47+
48+
def delete_tag(tag_id: str):
49+
"""Deletes a tag given the tag uid. Currently internal use only so this is not public"""
50+
res = client.execute(
51+
"""mutation deleteResourceTagPyApi($tag_id: String!) {
52+
deleteResourceTag(input: {id: $tag_id}) {
53+
id
54+
}
55+
}
56+
""", {"tag_id": tag_id})
57+
return res
58+
5259
before = list(client.get_projects())
5360
for o in before:
5461
assert isinstance(o, Project)
@@ -92,6 +99,9 @@ def test_update_project_resource_tags(client, rand_gen):
9299
assert len(project_resource_tag) == 1
93100
assert project_resource_tag[0].uid == tagA.uid
94101

102+
delete_tag(tagA.uid)
103+
delete_tag(tagB.uid)
104+
95105

96106
def test_project_filtering(client, rand_gen):
97107
name_1 = rand_gen(str)
@@ -191,3 +201,25 @@ def test_queue_mode(configured_project: Project):
191201
) == configured_project.QueueMode.Dataset
192202
configured_project.update(queue_mode=configured_project.QueueMode.Batch)
193203
assert configured_project.queue_mode() == configured_project.QueueMode.Batch
204+
205+
206+
def test_media_type(client, configured_project: Project, rand_gen):
207+
# Existing project
208+
assert configured_project.media_type is None or isinstance(
209+
configured_project.media_type, Project.MediaType)
210+
211+
# No media_type
212+
project = client.create_project(name=rand_gen(str))
213+
assert project.media_type == Project.MediaType.Unknown
214+
project.update(media_type=Project.MediaType.Image)
215+
assert project.media_type == Project.MediaType.Image
216+
project.delete()
217+
218+
for media_type in Project.MediaType:
219+
if media_type == Project.MediaType.Unknown:
220+
continue
221+
222+
project = client.create_project(name=rand_gen(str),
223+
media_type=media_type)
224+
assert project.media_type == media_type
225+
project.delete()

0 commit comments

Comments
 (0)