Skip to content

Commit f648522

Browse files
Use Entity metaclass for schema type registration.
1 parent 56dcf5c commit f648522

File tree

6 files changed

+33
-48
lines changed

6 files changed

+33
-48
lines changed

labelbox/orm/model.py

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def __init__(self, relationship_type, destination_type_name,
194194

195195
@property
196196
def destination_type(self):
197-
return Entity.named(self.destination_type_name)
197+
return getattr(Entity, self.destination_type_name)
198198

199199
def __str__(self):
200200
return self.name
@@ -203,7 +203,18 @@ def __repr__(self):
203203
return "<Relationship: %r>" % self.name
204204

205205

206-
class Entity:
206+
class EntityMeta(type):
207+
""" Entity metaclass. Registers Entity subclasses as attributes
208+
of the Entity class object so they can be referenced for example like:
209+
Entity.Project.
210+
"""
211+
def __init__(cls, clsname, superclasses, attributedict):
212+
super().__init__(clsname, superclasses, attributedict)
213+
if clsname != "Entity":
214+
setattr(Entity, clsname, cls)
215+
216+
217+
class Entity(metaclass=EntityMeta):
207218
""" An entity that contains fields and relationships. Base class
208219
for DbObject (which is base class for concrete schema classes). """
209220

@@ -279,27 +290,3 @@ def type_name(cls):
279290
Project, DataRow, ...
280291
"""
281292
return cls.__name__.split(".")[-1]
282-
283-
284-
@classmethod
285-
def named(cls, name):
286-
""" Returns an Entity (direct or indirect subclass of `Entity`) that
287-
whose class name is equal to `name`.
288-
289-
Args:
290-
name (str): Name of the sought entity, for example "Project".
291-
Return:
292-
An Entity subtype that has the given `name`.
293-
Raises:
294-
LabelboxError: if there is no such class.
295-
"""
296-
def subclasses(cls):
297-
for subclass in cls.__subclasses__():
298-
yield subclass
299-
for subsub in subclasses(subclass):
300-
yield subsub
301-
302-
for t in subclasses(Entity):
303-
if t.type_name() == name:
304-
return t
305-
raise LabelboxError("Failed to find Entity for name: %s" % name)

labelbox/schema/data_row.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,10 @@ def create_metadata(self, meta_type, meta_value):
5151
metaType: $%s metaValue: $%s dataRowId: $%s}) {%s}} """ % (
5252
meta_type_param, meta_value_param, data_row_id_param,
5353
meta_type_param, meta_value_param, data_row_id_param,
54-
query.results_query_part(Entity.named("AssetMetadata")))
54+
query.results_query_part(Entity.AssetMetadata))
5555

5656
res = self.client.execute(
5757
query_str, {meta_type_param: meta_type, meta_value_param: meta_value,
5858
data_row_id_param: self.uid})
59-
return Entity.named("AssetMetadata")(
59+
return Entity.AssetMetadata(
6060
self.client, res["data"]["createAssetMetadata"])

labelbox/schema/dataset.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def create_data_row(self, **kwargs):
3838
any of the field names given in `data`.
3939
4040
"""
41-
DataRow = Entity.named("DataRow")
41+
DataRow = Entity.DataRow
4242
if DataRow.row_data.name not in kwargs:
4343
raise InvalidQueryError(
4444
"DataRow.row_data missing when creating DataRow.")
@@ -80,8 +80,7 @@ def create_data_rows(self, items):
8080
a DataRow.
8181
"""
8282
file_upload_thread_count = 20
83-
DataRow = Entity.named("DataRow")
84-
Task = Entity.named("Task")
83+
DataRow = Entity.DataRow
8584

8685
def upload_if_necessary(item):
8786
if isinstance(item, str):
@@ -136,12 +135,12 @@ def convert_item(item):
136135
# Fetch and return the task.
137136
task_id = res["taskId"]
138137
user = self.client.get_user()
139-
task = list(user.created_tasks(where=Task.uid == task_id))
138+
task = list(user.created_tasks(where=Entity.Task.uid == task_id))
140139
# Cache user in a private variable as the relationship can't be
141140
# resolved due to server-side limitations (see Task.created_by)
142141
# for more info.
143142
if len(task) != 1:
144-
raise ResourceNotFoundError(Task, task_id)
143+
raise ResourceNotFoundError(Entity.Task, task_id)
145144
task = task[0]
146145
task._user = user
147146
return task
@@ -161,7 +160,7 @@ def data_row_for_external_id(self, external_id):
161160
in this `DataSet` with the given external ID, or if there are
162161
multiple `DataRows` for it.
163162
"""
164-
DataRow = Entity.named("DataRow")
163+
DataRow = Entity.DataRow
165164
where = DataRow.external_id==external_id
166165

167166
data_rows = self.data_rows(where=where)

labelbox/schema/label.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,9 @@ def create_review(self, **kwargs):
3838
Review attributes. At a minimum a `Review.score` field
3939
value must be provided.
4040
"""
41-
Review = Entity.named("Review")
42-
kwargs[Review.label.name] = self
43-
kwargs[Review.project.name] = self.project()
44-
return self.client._create(Review, kwargs)
41+
kwargs[Entity.Review.label.name] = self
42+
kwargs[Entity.Review.project.name] = self.project()
43+
return self.client._create(Entity.Review, kwargs)
4544

4645
def create_benchmark(self):
4746
""" Creates a Benchmark for this Label.
@@ -52,7 +51,7 @@ def create_benchmark(self):
5251
query_str = """mutation CreateBenchmarkPyApi($%s: ID!) {
5352
createBenchmark(data: {labelId: $%s}) {%s}} """ % (
5453
label_id_param, label_id_param,
55-
query.results_query_part(Entity.named("Benchmark")))
54+
query.results_query_part(Entity.Benchmark))
5655
res = self.client.execute(query_str, {label_id_param: self.uid})
5756
res = res["data"]["createBenchmark"]
58-
return Entity.named("Benchmark")(self.client, res)
57+
return Entity.Benchmark(self.client, res)

labelbox/schema/project.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def create_label(self, **kwargs):
5454
# about them. At the same time they're connected to a Label at
5555
# label creation in a non-standard way (connect via name).
5656

57-
Label = Entity.named("Label")
57+
Label = Entity.Label
5858
kwargs[Label.project] = self
5959
data = {Label.attribute(attr) if isinstance(attr, str) else attr:
6060
value.uid if isinstance(value, DbObject) else value
@@ -69,7 +69,7 @@ def create_label(self, **kwargs):
6969
return Label(self.client, res)
7070

7171
def labels(self, datasets=None, order_by=None):
72-
Label = Entity.named("Label")
72+
Label = Entity.Label
7373

7474
if datasets is not None:
7575
where = " where:{dataRow: {dataset: {id_in: [%s]}}}" % ", ".join(
@@ -135,10 +135,10 @@ def labeler_performance(self):
135135
count user {%s} secondsPerLabel totalTimeLabeling consensus
136136
averageBenchmarkAgreement lastActivityTime}
137137
}}""" % (project_id_param, project_id_param,
138-
query.results_query_part(Entity.named("User")))
138+
query.results_query_part(Entity.User))
139139

140140
def create_labeler_performance(client, result):
141-
result["user"] = Entity.named("User")(client, result["user"])
141+
result["user"] = Entity.User(client, result["user"])
142142
result["lastActivityTime"] = datetime.fromtimestamp(
143143
result["lastActivityTime"] / 1000, timezone.utc)
144144
return LabelerPerformance(**{utils.snake_case(key): value
@@ -155,7 +155,7 @@ def review_metrics(self, net_score):
155155
Return:
156156
int, aggregation count of reviews for given net_score.
157157
"""
158-
if net_score not in (None,) + tuple(Entity.named("Review").NetScore):
158+
if net_score not in (None,) + tuple(Entity.Review.NetScore):
159159
raise InvalidQueryError("Review metrics net score must be either None "
160160
"or one of Review.NetScore values")
161161
project_id_param = "project_id"
@@ -179,7 +179,7 @@ def setup(self, labeling_frontend, labeling_frontend_options):
179179
if not isinstance(labeling_frontend_options, str):
180180
labeling_frontend_options = json.dumps(labeling_frontend_options)
181181

182-
LFO = Entity.named("LabelingFrontendOptions")
182+
LFO = Entity.LabelingFrontendOptions
183183
labeling_frontend_options = self.client._create(
184184
LFO, {LFO.project: self, LFO.labeling_frontend: labeling_frontend,
185185
LFO.customization_options: labeling_frontend_options,

labelbox/schema/webhook.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def create(client, topics, url, secret, project):
4848
query_str = """mutation CreateWebhookPyApi {
4949
createWebhook(data:{%s topics:{set:[%s]}, url:"%s", secret:"%s" }){%s}
5050
} """ % (project_str, " ".join(topics), url, secret,
51-
query.results_query_part(Entity.named("Webhook")))
51+
query.results_query_part(Entity.Webhook))
5252

5353
res = client.execute(query_str)
5454
return Webhook(client, res["data"]["createWebhook"])
@@ -74,7 +74,7 @@ def update(self, topics=None, url=None, status=None):
7474
query_str = """mutation UpdateWebhookPyApi {
7575
updateWebhook(where: {id: "%s"} data:{%s}){%s}} """ % (
7676
self.uid, ", ".join(filter(None, (topics_str, url_str, status_str))),
77-
query.results_query_part(Entity.named("Webhook")))
77+
query.results_query_part(Entity.Webhook))
7878

7979
res = self.client.execute(query_str)
8080
res = res["data"]["updateWebhook"]

0 commit comments

Comments
 (0)