Skip to content

Commit 338fcb5

Browse files
Support Project.labels filting on Dataset
1 parent 9e55dd2 commit 338fcb5

File tree

5 files changed

+92
-32
lines changed

5 files changed

+92
-32
lines changed

labelbox/orm/model.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -194,11 +194,7 @@ def __init__(self, relationship_type, destination_type_name,
194194

195195
@property
196196
def destination_type(self):
197-
for t in Entity.subclasses():
198-
if t.type_name() == self.destination_type_name:
199-
return t
200-
raise LabelboxError("Failed to find Entity for name: %s" %
201-
self.destination_type_name)
197+
return Entity.named(self.destination_type_name)
202198

203199

204200
class Entity:
@@ -279,3 +275,20 @@ def subclasses(cls):
279275
yield subclass
280276
for subsub in subclass.subclasses():
281277
yield subsub
278+
279+
@classmethod
280+
def named(cls, name):
281+
""" Returns an Entity (direct or indirect subclass of `Entity`) that
282+
whose class name is equal to `name`.
283+
284+
Args:
285+
name (str): Name of the sought entity, for example "Project".
286+
Return:
287+
An Entity subtype that has the given `name`.
288+
Raises:
289+
LabelboxError: if there is no such class.
290+
"""
291+
for t in Entity.subclasses():
292+
if t.type_name() == name:
293+
return t
294+
raise LabelboxError("Failed to find Entity for name: %s" % name)

labelbox/orm/query.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def unset_labeling_parameter_overrides(project, data_rows):
374374
return query_str, {}
375375

376376

377-
def create_metadata(asset_type, meta_type, meta_value, data_row_id):
377+
def create_metadata(meta_type, meta_value, data_row_id):
378378
meta_type_param = "meta_type"
379379
meta_value_param = "meta_value"
380380
data_row_id_param = "data_row_id"
@@ -384,7 +384,8 @@ def create_metadata(asset_type, meta_type, meta_value, data_row_id):
384384
metaType: $%s metaValue: $%s dataRowId: $%s}) {%s}} """ % (
385385
meta_type_param, meta_value_param, data_row_id_param,
386386
meta_type_param, meta_value_param, data_row_id_param,
387-
" ".join(field.graphql_name for field in asset_type.fields()))
387+
" ".join(field.graphql_name for field
388+
in Entity.named("AssetMetadata").fields()))
388389
return query_str, {meta_type_param: meta_type,
389390
meta_value_param: meta_value,
390391
data_row_id_param: data_row_id}
@@ -484,6 +485,39 @@ def delete(db_object):
484485
return query_str, {id_param: db_object.uid}
485486

486487

488+
def project_labels(project, datasets, order_by):
489+
""" Returns the query and params for getting a Project's labels
490+
relationship. A non-standard relationship query is used to support
491+
filtering on Datasets.
492+
Args:
493+
datasets (list or None): The datasets filter. If None it's
494+
ignored.
495+
Return:
496+
(query_string, params)
497+
"""
498+
label_entity = Entity.named("Label")
499+
500+
if datasets is not None:
501+
where = " where:{dataRow: {dataset: {id_in: [%s]}}}" % ", ".join(
502+
'"%s"' % dataset.uid for dataset in datasets)
503+
else:
504+
where = ""
505+
506+
if order_by is not None:
507+
check_order_by_clause(label_entity, order_by)
508+
order_by_str = "orderBy: %s_%s" % (
509+
order_by[0].graphql_name, order_by[1].name.upper())
510+
else:
511+
order_by_str = ""
512+
513+
query_str = """query GetProjectLabelsPyApi($project_id: ID!)
514+
{project (where: {id: $project_id})
515+
{labels (skip: %%d first: %%d%s%s) {%s}}}""" % (
516+
where, order_by_str, " ".join(f.graphql_name
517+
for f in label_entity.fields()))
518+
return query_str, {"project_id": project.uid}
519+
520+
487521
def export_labels():
488522
""" Returns the query and ID param for exporting a Project's
489523
labels.
@@ -521,9 +555,10 @@ def bulk_delete(db_objects, use_where_clause):
521555
return query_str, {}
522556

523557

524-
def create_webhook(entity, topics, url, secret, project):
558+
def create_webhook(topics, url, secret, project):
525559
project_str = "" if project is None else ("project:{id:\"%s\"}," % project.uid)
526-
fields_str = " ".join(field.graphql_name for field in entity.fields())
560+
fields_str = " ".join(field.graphql_name for field
561+
in Entity.named("Webhook").fields())
527562

528563
query_str = """mutation CreateWebhookPyApi {
529564
createWebhook(data:{%s topics:{set:[%s]}, url:"%s", secret:"%s" }){%s}

labelbox/schema.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from labelbox.orm.db_object import (DbObject, Updateable, Deletable,
1111
BulkDeletable)
1212
from labelbox.orm.model import Field, Relationship
13+
from labelbox.pagination import PaginatedCollection
1314

1415

1516
""" Client-side object type definitions. """
@@ -19,12 +20,6 @@
1920

2021

2122
class Project(DbObject, Updateable, Deletable):
22-
23-
def __init__(self, *args, **kwargs):
24-
super().__init__(*args, **kwargs)
25-
26-
self.labels.supports_filtering = False
27-
2823
name = Field.String("name")
2924
description = Field.String("description")
3025
updated_at = Field.DateTime("updated_at")
@@ -41,7 +36,6 @@ def __init__(self, *args, **kwargs):
4136
labeling_frontend = Relationship.ToOne("LabelingFrontend")
4237
labeling_frontend_options = Relationship.ToMany(
4338
"LabelingFrontendOptions", False, "labeling_frontend_options")
44-
labels = Relationship.ToMany("Label", True)
4539
labeling_parameter_overrides = Relationship.ToMany(
4640
"LabelingParameterOverride", False, "labeling_parameter_overrides")
4741
webhooks = Relationship.ToMany("Webhook", False)
@@ -72,6 +66,11 @@ def create_label(self, **kwargs):
7266
res = res["data"]["createLabel"]
7367
return Label(self.client, res)
7468

69+
def labels(self, datasets=None, order_by=None):
70+
query_string, params = query.project_labels(self, datasets, order_by)
71+
return PaginatedCollection(self.client, query_string, params,
72+
["project", "labels"], Label)
73+
7574
def export_labels(self, timeout_seconds=60):
7675
""" Calls the server-side Label exporting that generates a JSON
7776
payload, and returns the URL to that payload.
@@ -311,8 +310,7 @@ def create_metadata(self, meta_type, meta_value):
311310
Return:
312311
AssetMetadata DB object.
313312
"""
314-
query_str, params = query.create_metadata(
315-
AssetMetadata, meta_type, meta_value, self.uid)
313+
query_str, params = query.create_metadata(meta_type, meta_value, self.uid)
316314
res = self.client.execute(query_str, params)
317315
return AssetMetadata(self.client, res["data"]["createAssetMetadata"])
318316

@@ -454,8 +452,7 @@ class Webhook(DbObject):
454452

455453
@staticmethod
456454
def create(client, topics, url, secret, project):
457-
query_str, params = query.create_webhook(Webhook, topics, url, secret,
458-
project)
455+
query_str, params = query.create_webhook(topics, url, secret, project)
459456
res = client.execute(query_str, params)
460457
return Webhook(client, res["data"]["createWebhook"])
461458

tests/integration/test_label.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -80,27 +80,33 @@ def test_label_update(client, rand_gen):
8080

8181
def test_label_filter_order(client, rand_gen):
8282
project = client.create_project(name=rand_gen(str))
83-
dataset = client.create_dataset(name=rand_gen(str), projects=project)
84-
data_row = dataset.create_data_row(row_data=IMG_URL)
83+
dataset_1 = client.create_dataset(name=rand_gen(str), projects=project)
84+
dataset_2 = client.create_dataset(name=rand_gen(str), projects=project)
85+
data_row_1 = dataset_1.create_data_row(row_data=IMG_URL)
86+
data_row_2 = dataset_2.create_data_row(row_data=IMG_URL)
8587

86-
l1 = project.create_label(data_row=data_row, label="l1", seconds_to_label=0.3)
87-
l2 = project.create_label(data_row=data_row, label="l2", seconds_to_label=0.1)
88+
l1 = project.create_label(data_row=data_row_1, label="l1",
89+
seconds_to_label=0.3)
90+
l2 = project.create_label(data_row=data_row_2, label="l2",
91+
seconds_to_label=0.1)
8892

8993
# Labels are not visible in the project immediately.
9094
time.sleep(10)
9195

92-
# Filtering is not supported
93-
with pytest.raises(InvalidQueryError) as exc_info:
94-
project.labels(where=Label.label=="l1")
95-
assert exc_info.value.message == \
96-
"Relationship Project.labels doesn't support filtering"
96+
# Filtering supported on dataset
97+
assert set(project.labels()) == {l1, l2}
98+
assert set(project.labels(datasets=[])) == set()
99+
assert set(project.labels(datasets=[dataset_1])) == {l1}
100+
assert set(project.labels(datasets=[dataset_2])) == {l2}
101+
assert set(project.labels(datasets=[dataset_1, dataset_2])) == {l1, l2}
97102

98103
assert list(project.labels(order_by=Label.label.asc)) == [l1, l2]
99104
assert list(project.labels(order_by=Label.label.desc)) == [l2, l1]
100105
assert list(project.labels(order_by=Label.seconds_to_label.asc)) == [l2, l1]
101106
assert list(project.labels(order_by=Label.seconds_to_label.desc)) == [l1, l2]
102107

103-
dataset.delete()
108+
dataset_1.delete()
109+
dataset_2.delete()
104110
project.delete()
105111

106112

tools/db_object_doc_gen.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,17 @@
1111
sys.path.append(sys.argv[1])
1212

1313

14-
from labelbox import (Project, Dataset, DataRow, Label, User, Organization, Task,
15-
LabelingFrontend, Webhook)
14+
from labelbox import (Project, Dataset, DataRow, Label, User, Organization,
15+
Task, LabelingFrontend, Webhook)
1616
from labelbox.schema import Field, Relationship
1717

1818

19+
# Map {cls: [fields_info, ...]}
20+
more_fields = {}
21+
# Map {cls: [relationship_info, ...]}
22+
more_rels = {Project: [("labels", Label, "Many")]}
23+
24+
1925
for cls in (Project, Dataset, DataRow, Label, User, Organization, Task,
2026
LabelingFrontend, Webhook):
2127
print("")
@@ -34,3 +40,6 @@
3440
if isinstance(r, Relationship)):
3541
print(relationship.name, "|", relationship.destination_type_name, "|",
3642
relationship.relationship_type.name[2:])
43+
44+
for name, dest_type, card in more_rels.get(cls, []):
45+
print(name, "|", dest_type.type_name(), "|", card)

0 commit comments

Comments
 (0)