Skip to content

Commit d17d9fc

Browse files
author
Kevin Kim
committed
Return async friendly payload for both functions
1 parent 14e2338 commit d17d9fc

File tree

2 files changed

+177
-51
lines changed

2 files changed

+177
-51
lines changed

labelbox/client.py

Lines changed: 174 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -944,46 +944,79 @@ def get_model_run(self, model_run_id: str) -> ModelRun:
944944
def assign_global_keys_to_data_rows(
945945
self,
946946
global_key_to_data_row_inputs: List[Dict[str, str]],
947-
timeout_seconds=60) -> List[Dict[str, str]]:
947+
timeout_seconds=60) -> Dict[str, List[Dict[str, str]]]:
948948
"""
949-
Assigns global keys to the related data rows.
949+
Assigns global keys to data rows.
950950
951-
>>> global_key_data_row_inputs = [
952-
{"data_row_id": "cl7asgri20yvo075b4vtfedjb", "global_key": "key1"},
953-
{"data_row_id": "cl7asgri10yvg075b4pz176ht", "global_key": "key2"},
954-
]
955-
>>> client.assign_global_keys_to_data_rows(global_key_data_row_inputs)
956-
957951
Args:
958952
A list of dicts containing data_row_id and global_key.
959953
Returns:
960-
Returns successful assigned global keys and data rows
954+
Dictionary containing 'status', 'results' and 'errors'.
955+
956+
'Status' contains the outcome of this job. It can be one of
957+
'Success', 'Partial Success', or 'Failure'.
958+
959+
'Results' contains the successful global_key assignments, including
960+
global_keys that have been sanitized to Labelbox standards.
961+
962+
'Errors' contains global_key assignments that failed, along with
963+
the reasons for failure.
964+
Examples:
965+
>>> global_key_data_row_inputs = [
966+
{"data_row_id": "cl7asgri20yvo075b4vtfedjb", "global_key": "key1"},
967+
{"data_row_id": "cl7asgri10yvg075b4pz176ht", "global_key": "key2"},
968+
]
969+
>>> job_result = client.assign_global_keys_to_data_rows(global_key_data_row_inputs)
970+
>>> print(job_result['status'])
971+
Partial Success
972+
>>> print(job_result['results'])
973+
[{'data_row_id': 'cl7tv9wry00hlka6gai588ozv', 'global_key': 'gk', 'sanitized': False}]
974+
>>> print(job_result['errors'])
975+
[{'data_row_id': 'cl7tpjzw30031ka6g4evqdfoy', 'global_key': 'gk"', 'error': 'Invalid global key'}]
961976
"""
977+
978+
def _format_successful_assignments(
979+
assignments: Dict[str, str],
980+
sanitized: bool) -> List[Dict[str, str]]:
981+
return [{
982+
'data_row_id': a['dataRowId'],
983+
'global_key': a['globalKey'],
984+
'sanitized': sanitized
985+
} for a in assignments]
986+
987+
def _format_failed_assignments(assignments: Dict[str, str],
988+
error_msg: str) -> List[Dict[str, str]]:
989+
return [{
990+
'data_row_id': a['dataRowId'],
991+
'global_key': a['globalKey'],
992+
'error': error_msg
993+
} for a in assignments]
994+
995+
# Validate input dict
962996
validation_errors = []
963997
for input in global_key_to_data_row_inputs:
964998
if "data_row_id" not in input or "global_key" not in input:
965999
validation_errors.append(input)
966-
9671000
if len(validation_errors) > 0:
9681001
raise ValueError(
9691002
f"Must provide a list of dicts containing both `data_row_id` and `global_key`. The following dict(s) are invalid: {validation_errors}."
9701003
)
9711004

1005+
# Start assign global keys to data rows job
9721006
query_str = """mutation assignGlobalKeysToDataRowsPyApi($globalKeyDataRowLinks: [AssignGlobalKeyToDataRowInput!]!) {
9731007
assignGlobalKeysToDataRows(data: {assignInputs: $globalKeyDataRowLinks}) {
9741008
jobId
9751009
}
9761010
}
9771011
"""
978-
9791012
params = {
9801013
'globalKeyDataRowLinks': [{
9811014
utils.camel_case(key): value for key, value in input.items()
9821015
} for input in global_key_to_data_row_inputs]
9831016
}
984-
9851017
assign_global_keys_to_data_rows_job = self.execute(query_str, params)
9861018

1019+
# Query string for retrieving job status and result, if job is done
9871020
result_query_str = """query assignGlobalKeysToDataRowsResultPyApi($jobId: ID!) {
9881021
assignGlobalKeysToDataRowsResult(jobId: {id: $jobId}) {
9891022
jobStatus
@@ -1011,33 +1044,52 @@ def assign_global_keys_to_data_rows(
10111044
assign_global_keys_to_data_rows_job["assignGlobalKeysToDataRows"
10121045
]["jobId"]
10131046
}
1047+
1048+
# Poll job status until finished, then retrieve results
10141049
sleep_time = 2
10151050
start_time = time.time()
10161051
while True:
10171052
res = self.execute(result_query_str, result_params)
10181053
if res["assignGlobalKeysToDataRowsResult"][
10191054
"jobStatus"] == "COMPLETE":
1020-
errors = []
1055+
results, errors = [], []
10211056
res = res['assignGlobalKeysToDataRowsResult']['data']
1022-
if res['invalidGlobalKeyAssignments']:
1023-
errors.append("Invalid Global Keys: " +
1024-
str(res['invalidGlobalKeyAssignments']))
1025-
if res['accessDeniedAssignments']:
1026-
errors.append("Access Denied Assignments: " +
1027-
str(res['accessDeniedAssignments']))
1028-
success = []
1029-
if res['sanitizedAssignments']:
1030-
success.append("Sanitized Assignments: " +
1031-
str(res['sanitizedAssignments']))
1032-
if res['unmodifiedAssignments']:
1033-
success.append("Unmodified Assignments: " +
1034-
str(res['unmodifiedAssignments']))
1057+
# Successful assignments
1058+
results.extend(
1059+
_format_successful_assignments(
1060+
assignments=res['sanitizedAssignments'],
1061+
sanitized=True))
1062+
results.extend(
1063+
_format_successful_assignments(
1064+
assignments=res['unmodifiedAssignments'],
1065+
sanitized=False))
1066+
# Failed assignments
1067+
errors.extend(
1068+
_format_failed_assignments(
1069+
assignments=res['invalidGlobalKeyAssignments'],
1070+
error_msg="Invalid global key"))
1071+
errors.extend(
1072+
_format_failed_assignments(
1073+
assignments=res['accessDeniedAssignments'],
1074+
error_msg="Access denied to Data Row"))
1075+
1076+
if len(errors) == 0:
1077+
status = "Success"
1078+
elif len(results) > 0:
1079+
status = "Partial Success"
1080+
else:
1081+
status = "Failure"
10351082

10361083
if len(errors) > 0:
1037-
raise Exception(
1038-
"Failed to assign global keys to data rows: " +
1039-
str(errors) + "\n" + str(success))
1040-
return res['sanitizedAssignments'] + res['unmodifiedAssignments']
1084+
logger.warning(
1085+
"There are errors present. Please look at 'errors' in the returned dict for more details"
1086+
)
1087+
1088+
return {
1089+
"status": status,
1090+
"results": results,
1091+
"errors": errors,
1092+
}
10411093
elif res["assignGlobalKeysToDataRowsResult"][
10421094
"jobStatus"] == "FAILED":
10431095
raise labelbox.exceptions.LabelboxError(
@@ -1049,31 +1101,51 @@ def assign_global_keys_to_data_rows(
10491101
)
10501102
time.sleep(sleep_time)
10511103

1052-
def get_data_rows_for_global_keys(
1104+
def get_data_row_ids_for_global_keys(
10531105
self,
10541106
global_keys: List[str],
10551107
timeout_seconds=60) -> List[Dict[str, str]]:
10561108
"""
1057-
Gets data rows for a list of global keys.
1058-
1059-
>>> data_rows = client.get_data_row_ids_for_global_keys(["key1",])
1109+
Gets data row ids for a list of global keys.
10601110
10611111
Args:
10621112
A list of global keys
10631113
Returns:
1064-
TODO: Better description
1114+
Dictionary containing 'status', 'results' and 'errors'.
1115+
1116+
'Status' contains the outcome of this job. It can be one of
1117+
'Success', 'Partial Success', or 'Failure'.
1118+
1119+
'Results' contains a list of data row ids successfully fetchced. It may
1120+
not necessarily contain all data rows requested.
1121+
1122+
'Errors' contains a list of global_keys that could not be fetched, along
1123+
with the failure reason
1124+
Examples:
1125+
>>> job_result = client.get_data_row_ids_for_global_keys(["key1","key2"])
1126+
>>> print(job_result['status'])
1127+
Partial Success
1128+
>>> print(job_result['results'])
1129+
['cl7tv9wry00hlka6gai588ozv', 'cl7tv9wxg00hpka6gf8sh81bj']
1130+
>>> print(job_result['errors'])
1131+
[{'global_key': 'asdf', 'error': 'Data Row not found'}]
10651132
"""
10661133

1134+
def _format_failed_retrieval(retrieval: List[str],
1135+
error_msg: str) -> List[Dict[str, str]]:
1136+
return [{'global_key': a, 'error': error_msg} for a in retrieval]
1137+
1138+
# Start get data rows for global keys job
10671139
query_str = """query getDataRowsForGlobalKeysPyApi($globalKeys: [ID!]!) {
10681140
dataRowsForGlobalKeys(where: {ids: $globalKeys}) { jobId}}
10691141
"""
10701142
params = {"globalKeys": global_keys}
1071-
10721143
data_rows_for_global_keys_job = self.execute(query_str, params)
10731144

1145+
# Query string for retrieving job status and result, if job is done
10741146
result_query_str = """query getDataRowsForGlobalKeysResultPyApi($jobId: ID!) {
10751147
dataRowsForGlobalKeysResult(jobId: {id: $jobId}) { data {
1076-
fetchedDataRows {id}
1148+
fetchedDataRows { id }
10771149
notFoundGlobalKeys
10781150
accessDeniedGlobalKeys
10791151
deletedDataRowGlobalKeys
@@ -1084,36 +1156,90 @@ def get_data_rows_for_global_keys(
10841156
data_rows_for_global_keys_job["dataRowsForGlobalKeys"]["jobId"]
10851157
}
10861158

1159+
# Poll job status until finished, then retrieve results
10871160
sleep_time = 2
10881161
start_time = time.time()
10891162
while True:
10901163
res = self.execute(result_query_str, result_params)
10911164
if res["dataRowsForGlobalKeysResult"]['jobStatus'] == "COMPLETE":
1092-
return res["dataRowsForGlobalKeysResult"]['data'][
1093-
'fetchedDataRows']
1165+
data = res["dataRowsForGlobalKeysResult"]['data']
1166+
results, errors = [], []
1167+
fetched_data_rows = [dr['id'] for dr in data['fetchedDataRows']]
1168+
results.extend(fetched_data_rows)
1169+
errors.extend(
1170+
_format_failed_retrieval(data['notFoundGlobalKeys'],
1171+
"Data Row not found"))
1172+
errors.extend(
1173+
_format_failed_retrieval(data['accessDeniedGlobalKeys'],
1174+
"Access denied to Data Row"))
1175+
errors.extend(
1176+
_format_failed_retrieval(data['deletedDataRowGlobalKeys'],
1177+
"Data Row deleted"))
1178+
if len(errors) == 0:
1179+
status = "Success"
1180+
elif len(results) > 0:
1181+
status = "Partial Success"
1182+
else:
1183+
status = "Failure"
1184+
1185+
if len(errors) > 0:
1186+
logger.warning(
1187+
"There are errors present. Please look at 'errors' in the returned dict for more details"
1188+
)
1189+
1190+
return {"status": status, "results": results, "errors": errors}
10941191
elif res["dataRowsForGlobalKeysResult"]['jobStatus'] == "FAILED":
10951192
raise labelbox.exceptions.LabelboxError(
1096-
"Job get_data_rows_for_global_keys failed.")
1193+
"Job dataRowsForGlobalKeys failed.")
10971194
current_time = time.time()
10981195
if current_time - start_time > timeout_seconds:
10991196
raise labelbox.exceptions.TimeoutError(
11001197
"Timed out waiting for get_data_rows_for_global_keys job to complete."
11011198
)
11021199
time.sleep(sleep_time)
11031200

1104-
def get_data_row_ids_for_global_keys(
1201+
def get_data_rows_for_global_keys(
11051202
self,
11061203
global_keys: List[str],
1107-
timeout_seconds=60) -> List[Dict[str, str]]:
1204+
timeout_seconds=60) -> List[Dict[str, List[str]]]:
11081205
"""
1109-
Gets data row ids for a list of global keys.
1110-
1111-
>>> data_row_ids = client.get_data_row_ids_for_global_keys(["key1",])
1206+
Gets data rows for a list of global keys.
11121207
11131208
Args:
11141209
A list of global keys
11151210
Returns:
1116-
TODO: Better description
1211+
Dictionary containing 'status', 'results' and 'errors'.
1212+
1213+
'Status' contains the outcome of this job. It can be one of
1214+
'Success', 'Partial Success', or 'Failure'.
1215+
1216+
'Results' contains a list of `DataRow` instances successfully fetchced. It may
1217+
not necessarily contain all data rows requested.
1218+
1219+
'Errors' contains a list of global_keys that could not be fetched, along
1220+
with the failure reason
1221+
Examples:
1222+
>>> job_result = client.get_data_rows_for_global_keys(["key1","key2"])
1223+
>>> print(job_result['status'])
1224+
Partial Success
1225+
>>> print(job_result['results'])
1226+
[<DataRow ID: cl7tvvybc00icka6ggipyh8tj>, <DataRow ID: cl7tvvyfp00igka6gblrw2idc>]
1227+
>>> print(job_result['errors'])
1228+
[{'global_key': 'asdf', 'error': 'Data Row not found'}]
11171229
"""
1118-
# TODO: Invoke get_data_rows_for_global_keys to extract data row ids
1119-
return self.get_data_rows_for_global_keys(global_keys, timeout_seconds)
1230+
job_result = self.get_data_row_ids_for_global_keys(
1231+
global_keys, timeout_seconds)
1232+
1233+
# Query for data row by data_row_id to ensure we get all fields populated in DataRow instances
1234+
data_rows = []
1235+
for data_row_id in job_result['results']:
1236+
try:
1237+
data_rows.append(self.get_data_row(data_row_id))
1238+
except labelbox.exceptions.ResourceNotFoundError:
1239+
raise labelbox.exceptions.ResourceNotFoundError(
1240+
f"Failed to fetch Data Row for id {data_row_id}. Please verify that DataRow is not deleted"
1241+
)
1242+
1243+
job_result['results'] = data_rows
1244+
1245+
return job_result

tests/integration/test_assign_global_key_to_data_row.py renamed to tests/integration/test_global_keys.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ def test_assign_global_keys_to_data_rows(client, dataset, image_url):
2525

2626
res = client.get_data_row_ids_for_global_keys([gk_1, gk_2])
2727

28-
assert len(res) == 2
29-
successful_assignments = set(a['id'] for a in res)
28+
assert len(res['results']) == 2
29+
successful_assignments = set(res['results'])
3030
assert successful_assignments == row_ids
3131

3232

@@ -48,4 +48,4 @@ def test_assign_global_keys_to_data_rows_validation_error(client):
4848
with pytest.raises(ValueError) as excinfo:
4949
client.assign_global_keys_to_data_rows(assignment_inputs)
5050
e = """[{'data_row_id': 'test uid', 'wrong_key': 'gk 1'}, {'wrong_key': 'test uid 3', 'global_key': 'gk 3'}, {'data_row_id': 'test uid 4'}, {'global_key': 'gk 5'}, {}]"""
51-
assert e in str(excinfo.value)
51+
assert e in str(excinfo.value)

0 commit comments

Comments
 (0)