Skip to content

Commit cca0ea2

Browse files
Support Organization relationships (query without source ID)
1 parent f9b9a0a commit cca0ea2

File tree

5 files changed

+30
-14
lines changed

5 files changed

+30
-14
lines changed

labelbox/orm/db_object.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def __init__(self, source, relationship):
103103
self.relationship = relationship
104104
self.supports_filtering = True
105105
self.supports_sorting = True
106+
self.filter_on_id = True
106107

107108
def __call__(self, *args, **kwargs ):
108109
""" Forwards the call to either `_to_many` or `_to_one` methods,
@@ -135,10 +136,12 @@ def _to_many(self, where=None, order_by=None):
135136
not_deleted = rel.destination_type.deleted == False
136137
where = not_deleted if where is None else where & not_deleted
137138

138-
query_string, params = query.relationship(self.source, rel, where, order_by)
139+
query_string, params = query.relationship(
140+
self.source if self.filter_on_id else type(self.source),
141+
rel, where, order_by)
139142
return PaginatedCollection(
140143
self.source.client, query_string, params,
141-
[utils.camel_case(type(self.source).type_name()),
144+
[utils.camel_case(self.source.type_name()),
142145
rel.graphql_name],
143146
rel.destination_type)
144147

labelbox/orm/query.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,10 @@ def relationship(source, relationship, where, order_by):
260260
pagination parameters.
261261
262262
Args:
263-
source (DbObject): A database object whose related objects are sought.
263+
source (DbObject or type): If a `DbObject` then the source of the
264+
relationship (the query originates from that particular object).
265+
If `type`, then the source of the relationship is implicit, even
266+
without the ID. Used for expanding from Organization.
264267
relationship (Relationship): The relationship.
265268
where (Comparison, LogicalExpression or None): The `where` clause
266269
for filtering.
@@ -274,11 +277,11 @@ def relationship(source, relationship, where, order_by):
274277
to_many = relationship.relationship_type == Relationship.Type.ToMany
275278
subquery = Query(relationship.graphql_name, relationship.destination_type,
276279
where, to_many, order_by)
277-
source_type_name = type(source).type_name()
278-
query = Query(utils.camel_case(source_type_name), subquery,
279-
type(source).uid == source.uid)
280+
query_where = type(source).uid == source.uid if isinstance(source, Entity) \
281+
else None
282+
query = Query(utils.camel_case(source.type_name()), subquery, query_where)
280283
return query.format_top(
281-
"Get" + source_type_name + utils.title_case(relationship.graphql_name))
284+
"Get" + source.type_name() + utils.title_case(relationship.graphql_name))
282285

283286

284287
def create(entity, data):

labelbox/schema.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,15 @@ class User(DbObject):
387387

388388

389389
class Organization(DbObject):
390+
391+
# RelationshipManagers in Organization use the type in Query (and
392+
# not the source object) because the server-side does not support
393+
# filtering on ID in the query for getting a single organization.
394+
def __init__(self, *args, **kwargs):
395+
super().__init__(*args, **kwargs)
396+
for relationship in self.relationships():
397+
getattr(self, relationship.name).filter_on_id = False
398+
390399
updated_at = Field.DateTime("updated_at")
391400
created_at = Field.DateTime("created_at")
392401
name = Field.String("name")

tests/integration/test_user_and_org.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,19 @@ def test_user(client):
1313
def test_organization(client):
1414
organization = client.get_organization()
1515
assert organization.uid is not None
16-
17-
# TODO make organization fetchable on ID
18-
with pytest.raises(InvalidQueryError):
19-
list(organization.users())
20-
list(organization.projects())
16+
assert client.get_user() in set(organization.users())
2117

2218

2319
def test_user_and_org_projects(client, rand_gen):
2420
user = client.get_user()
25-
projects = set(user.projects())
21+
org = client.get_organization()
22+
user_projects = set(user.projects())
23+
org_projects = set(org.projects())
2624

2725
project = client.create_project(name=rand_gen(Project.name))
2826
assert project.created_by() == user
29-
assert set(user.projects()) == projects.union({project})
27+
assert project.organization() == org
28+
assert set(user.projects()) == user_projects.union({project})
29+
assert set(org.projects()) == org_projects.union({project})
3030

3131
project.delete()

tests/integration/test_webhook.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def test_webhook_create_update(client, rand_gen):
1818
assert webhook.topics == topics
1919
assert webhook.status == Webhook.ACTIVE
2020
assert list(project.webhooks()) == [webhook]
21+
assert webhook in set(client.get_organization().webhooks())
2122

2223
webhook.update(status=Webhook.REVOKED, topics=[Webhook.LABEL_UPDATED])
2324
assert webhook.topics == [Webhook.LABEL_UPDATED]

0 commit comments

Comments
 (0)