Skip to content

Commit e8f4d22

Browse files
author
Kevin Kim
committed
Addressed review comments
1 parent 19fd4df commit e8f4d22

File tree

3 files changed

+81
-55
lines changed

3 files changed

+81
-55
lines changed

labelbox/client.py

Lines changed: 58 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# type: ignore
22
from datetime import datetime, timezone
33
import json
4-
from typing import List, Dict
4+
from typing import Any, List, Dict, Union
55
from collections import defaultdict
66

77
import logging
@@ -20,8 +20,10 @@
2020
from labelbox.orm.db_object import DbObject
2121
from labelbox.orm.model import Entity
2222
from labelbox.pagination import PaginatedCollection
23+
from labelbox.schema.data_row import DataRow
2324
from labelbox.schema.data_row_metadata import DataRowMetadataOntology
2425
from labelbox.schema.dataset import Dataset
26+
from labelbox.schema.enums import TaskResult
2527
from labelbox.schema.iam_integration import IAMIntegration
2628
from labelbox.schema import role
2729
from labelbox.schema.labeling_frontend import LabelingFrontend
@@ -944,7 +946,7 @@ def get_model_run(self, model_run_id: str) -> ModelRun:
944946
def assign_global_keys_to_data_rows(
945947
self,
946948
global_key_to_data_row_inputs: List[Dict[str, str]],
947-
timeout_seconds=60) -> Dict[str, List[Dict[str, str]]]:
949+
timeout_seconds=60) -> Dict[str, Union[str, List[Any]]]:
948950
"""
949951
Assigns global keys to data rows.
950952
@@ -975,22 +977,21 @@ def assign_global_keys_to_data_rows(
975977
[{'data_row_id': 'cl7tpjzw30031ka6g4evqdfoy', 'global_key': 'gk"', 'error': 'Invalid global key'}]
976978
"""
977979

978-
def _format_successful_assignments(
979-
assignments: Dict[str, str],
980-
sanitized: bool) -> List[Dict[str, str]]:
980+
def _format_successful_rows(rows: Dict[str, str],
981+
sanitized: bool) -> List[Dict[str, str]]:
981982
return [{
982-
'data_row_id': a['dataRowId'],
983-
'global_key': a['globalKey'],
983+
'data_row_id': r['dataRowId'],
984+
'global_key': r['globalKey'],
984985
'sanitized': sanitized
985-
} for a in assignments]
986+
} for r in rows]
986987

987-
def _format_failed_assignments(assignments: Dict[str, str],
988-
error_msg: str) -> List[Dict[str, str]]:
988+
def _format_failed_rows(rows: Dict[str, str],
989+
error_msg: str) -> List[Dict[str, str]]:
989990
return [{
990-
'data_row_id': a['dataRowId'],
991-
'global_key': a['globalKey'],
991+
'data_row_id': r['dataRowId'],
992+
'global_key': r['globalKey'],
992993
'error': error_msg
993-
} for a in assignments]
994+
} for r in rows]
994995

995996
# Validate input dict
996997
validation_errors = []
@@ -1056,31 +1057,27 @@ def _format_failed_assignments(assignments: Dict[str, str],
10561057
res = res['assignGlobalKeysToDataRowsResult']['data']
10571058
# Successful assignments
10581059
results.extend(
1059-
_format_successful_assignments(
1060-
assignments=res['sanitizedAssignments'],
1061-
sanitized=True))
1060+
_format_successful_rows(rows=res['sanitizedAssignments'],
1061+
sanitized=True))
10621062
results.extend(
1063-
_format_successful_assignments(
1064-
assignments=res['unmodifiedAssignments'],
1065-
sanitized=False))
1063+
_format_successful_rows(rows=res['unmodifiedAssignments'],
1064+
sanitized=False))
10661065
# Failed assignments
10671066
errors.extend(
1068-
_format_failed_assignments(
1069-
assignments=res['invalidGlobalKeyAssignments'],
1070-
error_msg="Invalid global key"))
1067+
_format_failed_rows(rows=res['invalidGlobalKeyAssignments'],
1068+
error_msg="Invalid global key"))
10711069
errors.extend(
1072-
_format_failed_assignments(
1073-
assignments=res['accessDeniedAssignments'],
1074-
error_msg="Access denied to Data Row"))
1075-
1076-
if len(errors) == 0:
1077-
status = "Success"
1078-
elif len(results) > 0:
1079-
status = "Partial Success"
1070+
_format_failed_rows(rows=res['accessDeniedAssignments'],
1071+
error_msg="Access denied to Data Row"))
1072+
1073+
if not errors:
1074+
status = TaskResult.SUCCESS.value
1075+
elif errors and results:
1076+
status = TaskResult.PARTIAL_SUCCESS.value
10801077
else:
1081-
status = "Failure"
1078+
status = TaskResult.FAILURE.value
10821079

1083-
if len(errors) > 0:
1080+
if errors:
10841081
logger.warning(
10851082
"There are errors present. Please look at 'errors' in the returned dict for more details"
10861083
)
@@ -1104,7 +1101,7 @@ def _format_failed_assignments(assignments: Dict[str, str],
11041101
def get_data_row_ids_for_global_keys(
11051102
self,
11061103
global_keys: List[str],
1107-
timeout_seconds=60) -> List[Dict[str, str]]:
1104+
timeout_seconds=60) -> Dict[str, Union[str, List[Any]]]:
11081105
"""
11091106
Gets data row ids for a list of global keys.
11101107
@@ -1131,9 +1128,16 @@ def get_data_row_ids_for_global_keys(
11311128
[{'global_key': 'asdf', 'error': 'Data Row not found'}]
11321129
"""
11331130

1134-
def _format_failed_retrieval(retrieval: List[str],
1135-
error_msg: str) -> List[Dict[str, str]]:
1136-
return [{'global_key': a, 'error': error_msg} for a in retrieval]
1131+
def _format_successful_rows(
1132+
rows: List[Dict[str, str]]) -> List[Dict[str, str]]:
1133+
return [{
1134+
'data_row_id': r['id'],
1135+
'global_key': r['globalKey']
1136+
} for r in rows]
1137+
1138+
def _format_failed_rows(rows: List[str],
1139+
error_msg: str) -> List[Dict[str, str]]:
1140+
return [{'global_key': r, 'error': error_msg} for r in rows]
11371141

11381142
# Start get data rows for global keys job
11391143
query_str = """query getDataRowsForGlobalKeysPyApi($globalKeys: [ID!]!) {
@@ -1145,7 +1149,7 @@ def _format_failed_retrieval(retrieval: List[str],
11451149
# Query string for retrieving job status and result, if job is done
11461150
result_query_str = """query getDataRowsForGlobalKeysResultPyApi($jobId: ID!) {
11471151
dataRowsForGlobalKeysResult(jobId: {id: $jobId}) { data {
1148-
fetchedDataRows { id }
1152+
fetchedDataRows { id globalKey }
11491153
notFoundGlobalKeys
11501154
accessDeniedGlobalKeys
11511155
deletedDataRowGlobalKeys
@@ -1164,25 +1168,24 @@ def _format_failed_retrieval(retrieval: List[str],
11641168
if res["dataRowsForGlobalKeysResult"]['jobStatus'] == "COMPLETE":
11651169
data = res["dataRowsForGlobalKeysResult"]['data']
11661170
results, errors = [], []
1167-
fetched_data_rows = [dr['id'] for dr in data['fetchedDataRows']]
1168-
results.extend(fetched_data_rows)
1171+
results.extend(_format_successful_rows(data['fetchedDataRows']))
11691172
errors.extend(
1170-
_format_failed_retrieval(data['notFoundGlobalKeys'],
1171-
"Data Row not found"))
1173+
_format_failed_rows(data['notFoundGlobalKeys'],
1174+
"Data Row not found"))
11721175
errors.extend(
1173-
_format_failed_retrieval(data['accessDeniedGlobalKeys'],
1174-
"Access denied to Data Row"))
1176+
_format_failed_rows(data['accessDeniedGlobalKeys'],
1177+
"Access denied to Data Row"))
11751178
errors.extend(
1176-
_format_failed_retrieval(data['deletedDataRowGlobalKeys'],
1177-
"Data Row deleted"))
1178-
if len(errors) == 0:
1179-
status = "Success"
1180-
elif len(results) > 0:
1181-
status = "Partial Success"
1179+
_format_failed_rows(data['deletedDataRowGlobalKeys'],
1180+
"Data Row deleted"))
1181+
if not errors:
1182+
status = TaskResult.SUCCESS.value
1183+
elif errors and results:
1184+
status = TaskResult.PARTIAL_SUCCESS.value
11821185
else:
1183-
status = "Failure"
1186+
status = TaskResult.FAILURE.value
11841187

1185-
if len(errors) > 0:
1188+
if errors:
11861189
logger.warning(
11871190
"There are errors present. Please look at 'errors' in the returned dict for more details"
11881191
)
@@ -1201,7 +1204,7 @@ def _format_failed_retrieval(retrieval: List[str],
12011204
def get_data_rows_for_global_keys(
12021205
self,
12031206
global_keys: List[str],
1204-
timeout_seconds=60) -> List[Dict[str, List[str]]]:
1207+
timeout_seconds=60) -> Dict[str, Union[str, List[Any]]]:
12051208
"""
12061209
Gets data rows for a list of global keys.
12071210
@@ -1232,12 +1235,12 @@ def get_data_rows_for_global_keys(
12321235

12331236
# Query for data row by data_row_id to ensure we get all fields populated in DataRow instances
12341237
data_rows = []
1235-
for data_row_id in job_result['results']:
1238+
for data_row in job_result['results']:
12361239
try:
1237-
data_rows.append(self.get_data_row(data_row_id))
1240+
data_rows.append(self.get_data_row(data_row['data_row_id']))
12381241
except labelbox.exceptions.ResourceNotFoundError:
12391242
raise labelbox.exceptions.ResourceNotFoundError(
1240-
f"Failed to fetch Data Row for id {data_row_id}. Please verify that DataRow is not deleted"
1243+
f"Failed to fetch Data Row for id {data_row['data_row_id']}. Please verify that DataRow is not deleted"
12411244
)
12421245

12431246
job_result['results'] = data_rows

labelbox/schema/enums.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,9 @@ class AnnotationImportState(Enum):
4444
RUNNING = "RUNNING"
4545
FAILED = "FAILED"
4646
FINISHED = "FINISHED"
47+
48+
49+
class TaskResult(Enum):
    """Closed set of outcome labels for a task.

    Each member's ``.value`` is the human-readable string form of the
    outcome ("Success", "Partial Success" or "Failure").
    """

    # Declaration order is preserved because it defines iteration order.
    SUCCESS = "Success"
    PARTIAL_SUCCESS = "Partial Success"
    FAILURE = "Failure"

labelbox/schema/task_result.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from typing import Dict, List, Optional
2+
3+
4+
class BulkTaskResult:
    """Plain value holder for the three fields of a bulk task result.

    The constructor stores its arguments unchanged; no validation or
    copying is performed.

    Attributes:
        status: Value passed as ``status`` to the constructor.
        success: Value passed as ``success`` to the constructor.
        errors: Value passed as ``errors`` to the constructor.
    """

    # Class-level defaults are kept so the attributes still resolve to None
    # when read from the class itself; they are Optional because of that
    # None default.
    status: Optional[str] = None
    success: Optional[List[Dict[str, str]]] = None
    errors: Optional[Dict[str, List[str]]] = None

    def __init__(self, status: Optional[str],
                 success: Optional[List[Dict[str, str]]],
                 errors: Optional[Dict[str, List[str]]]) -> None:
        """Store ``status``, ``success`` and ``errors`` on the instance."""
        self.status = status
        self.success = success
        self.errors = errors

    def __repr__(self) -> str:
        """Return a constructor-like representation of this result.

        Fields are rendered with ``!r`` so that string values are quoted,
        and are separated by ", " so the output is unambiguous.
        """
        return (f"{type(self).__name__}(status={self.status!r}, "
                f"success={self.success!r}, "
                f"errors={self.errors!r})")

0 commit comments

Comments
 (0)