Skip to content

Commit 9a17864

Browse files
author
Matt Sokoloff
committed
automatic queries for cached relationships
1 parent 89d3d53 commit 9a17864

File tree

9 files changed

+91
-6
lines changed

9 files changed

+91
-6
lines changed

labelbox/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ def __init__(self,
7676
'X-User-Agent': f'python-sdk {SDK_VERSION}'
7777
}
7878

79-
@retry.Retry(predicate=retry.if_exception_type(
80-
labelbox.exceptions.InternalServerError))
79+
#@retry.Retry(predicate=retry.if_exception_type(
80+
# labelbox.exceptions.InternalServerError))
8181
def execute(self, query, params=None, timeout=30.0, experimental=False):
8282
""" Sends a request to the server for the execution of the
8383
given query.

labelbox/orm/db_object.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,8 @@ def __init__(self, client, field_values):
4343
"""
4444
self.client = client
4545
self._set_field_values(field_values)
46-
4746
for relationship in self.relationships():
48-
value = field_values.get(relationship.name)
47+
value = field_values.get(utils.camel_case(relationship.name))
4948
if relationship.cache and value is None:
5049
raise KeyError(
5150
f"Expected field values for {relationship.name}")
@@ -168,6 +167,7 @@ def _to_one(self):
168167
result = result and result.get(rel.graphql_name)
169168
if result is None:
170169
return None
170+
171171
return rel.destination_type(self.source.client, result)
172172

173173
def connect(self, other):

labelbox/orm/model.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,43 @@ class EntityMeta(type):
242242

243243
def __init__(cls, clsname, superclasses, attributedict):
244244
super().__init__(clsname, superclasses, attributedict)
245+
cls.validate_cached_relationships()
245246
if clsname != "Entity":
246247
setattr(Entity, clsname, cls)
248+
if not hasattr(Entity, 'entities'):
249+
setattr(Entity, 'entities', [])
250+
Entity.entities.append(cls)
251+
252+
def validate_cached_relationships(cls):
253+
"""
254+
Graphql doesn't allow for infinite nesting in queries.
255+
This function checks that cached relationships result in valid queries.
256+
- A cached object and not have its own cached relationships.
257+
"""
258+
259+
cached_rels = [r for r in cls.relationships() if r.cache]
260+
print(cached_rels)
261+
# Check if any cached classes have their own cached fields
262+
for rel in cached_rels:
263+
child_name = utils.title_case(rel.name)
264+
if hasattr(Entity, child_name):
265+
for sub_rel in getattr(Entity, child_name).relationships():
266+
if sub_rel.cache:
267+
raise TypeError(
268+
"Cannot cache a relationship to an Entity with its own cached relationship(s). "
269+
f"`{utils.snake_case(cls.__name__)}` caches `{rel.name}` which caches `{sub_rel}`"
270+
)
271+
272+
# If this cls has cached fields check if any existing object caches this cls.
273+
if cached_rels:
274+
for entity in Entity.entities:
275+
attr = {rel.name: rel for rel in entity.relationships()
276+
}.get(utils.snake_case(cls.__name__))
277+
if attr and attr.cache:
278+
raise TypeError(
279+
"Cannot cache a relationship to an Entity with its own cached relationship(s). "
280+
f"`{utils.snake_case(entity.__name__)}` caches `{utils.snake_case(cls.__name__)}` which caches `{cached_rels}`"
281+
)
247282

248283

249284
class Entity(metaclass=EntityMeta):

labelbox/orm/query.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,13 @@ def results_query_part(entity):
4040
Args:
4141
entity (type): The entity which needs fetching.
4242
"""
43-
return " ".join(field.graphql_name for field in entity.fields())
43+
fields = [field.graphql_name for field in entity.fields()]
44+
for relationship in entity.relationships():
45+
if relationship.cache:
46+
fields.append(
47+
Query(relationship.graphql_name,
48+
relationship.destination_type).format()[0])
49+
return " ".join(fields)
4450

4551

4652
class Query:
@@ -292,9 +298,13 @@ def relationship(source, relationship, where, order_by):
292298
to_many = relationship.relationship_type == Relationship.Type.ToMany
293299
subquery = Query(relationship.graphql_name, relationship.destination_type,
294300
where, to_many, order_by)
301+
302+
303+
295304
query_where = type(source).uid == source.uid if isinstance(source, Entity) \
296305
else None
297306
query = Query(utils.camel_case(source.type_name()), subquery, query_where)
307+
298308
return query.format_top("Get" + source.type_name() +
299309
utils.title_case(relationship.graphql_name))
300310

labelbox/schema/project.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,7 @@ class LabelingParameterOverride(DbObject):
641641
"""
642642
priority = Field.Int("priority")
643643
number_of_labels = Field.Int("number_of_labels")
644+
data_row = Relationship.ToOne("DataRow", cache=True)
644645

645646

646647
LabelerPerformance = namedtuple(

labelbox/schema/webhook.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def delete(self):
112112
"""
113113
Deletes the webhook
114114
"""
115-
self.update(status=self.Status.INACTIVE)
115+
self.update(status=self.Status.INACTIVE.value)
116116

117117
def update(self, topics=None, url=None, status=None):
118118
""" Updates the Webhook.

tests/integration/test_labeling_parameter_overrides.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ def test_labeling_parameter_overrides(project, rand_gen):
2626
assert {o.number_of_labels for o in overrides} == {3, 2, 5}
2727
assert {o.priority for o in overrides} == {4, 3, 8}
2828

29+
for override in overrides:
30+
assert isinstance(override.data_row(), DataRow)
31+
2932
success = project.unset_labeling_parameter_overrides(
3033
[data[0][0], data[1][0]])
3134
assert success

tests/integration/test_webhook.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,5 @@ def test_webhook_create_update(project, rand_gen):
3838
webhook.update(topics="invalid..")
3939
assert str(exc_info.value) == \
4040
"Topics must be List[Webhook.Topic]. Found `invalid..`"
41+
42+
webhook.delete()

tests/test_entity_meta.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import pytest
2+
3+
from labelbox.orm.model import Relationship
4+
from labelbox.orm.db_object import DbObject
5+
6+
7+
def test_illegal_cache_cond1():
8+
9+
class TestEntityA(DbObject):
10+
test_entity_b = Relationship.ToOne("TestEntityB", cache=True)
11+
12+
with pytest.raises(TypeError) as exc_info:
13+
14+
class TestEntityB(DbObject):
15+
another_entity = Relationship.ToOne("AnotherEntity", cache=True)
16+
17+
assert str(exc_info.value) == \
18+
"Cannot cache a relationship to an Entity with its own cached relationship(s)." \
19+
" `test_entity_a` caches `test_entity_b` which caches `[<Relationship: 'another_entity'>]`"
20+
21+
22+
def test_illegal_cache_cond2():
23+
24+
class TestEntityD(DbObject):
25+
another_entity = Relationship.ToOne("AnotherEntity", cache=True)
26+
27+
with pytest.raises(TypeError) as exc_info:
28+
29+
class TestEntityC(DbObject):
30+
test_entity_d = Relationship.ToOne("TestEntityD", cache=True)
31+
32+
assert str(exc_info.value) == \
33+
"Cannot cache a relationship to an Entity with its own cached relationship(s)." \
34+
" `test_entity_c` caches `test_entity_d` which caches `another_entity`"

0 commit comments

Comments
 (0)