|
1 | 1 | # type: ignore |
2 | 2 | from datetime import datetime, timezone |
3 | 3 | import json |
4 | | -from typing import List, Dict |
| 4 | +from typing import Any, List, Dict, Union |
5 | 5 | from collections import defaultdict |
6 | 6 |
|
7 | 7 | import logging |
8 | 8 | import mimetypes |
9 | 9 | import os |
| 10 | +import time |
10 | 11 |
|
11 | 12 | from google.api_core import retry |
12 | 13 | import requests |
|
21 | 22 | from labelbox.pagination import PaginatedCollection |
22 | 23 | from labelbox.schema.data_row_metadata import DataRowMetadataOntology |
23 | 24 | from labelbox.schema.dataset import Dataset |
| 25 | +from labelbox.schema.enums import CollectionJobStatus |
24 | 26 | from labelbox.schema.iam_integration import IAMIntegration |
25 | 27 | from labelbox.schema import role |
26 | 28 | from labelbox.schema.labeling_frontend import LabelingFrontend |
@@ -939,3 +941,296 @@ def get_model_run(self, model_run_id: str) -> ModelRun: |
939 | 941 | A ModelRun object. |
940 | 942 | """ |
941 | 943 | return self._get_single(Entity.ModelRun, model_run_id) |
| 944 | + |
| 945 | + def assign_global_keys_to_data_rows( |
| 946 | + self, |
| 947 | + global_key_to_data_row_inputs: List[Dict[str, str]], |
| 948 | + timeout_seconds=60) -> Dict[str, Union[str, List[Any]]]: |
| 949 | + """ |
| 950 | + Assigns global keys to data rows. |
| 951 | + |
| 952 | + Args: |
| 953 | + A list of dicts containing data_row_id and global_key. |
| 954 | + Returns: |
| 955 | + Dictionary containing 'status', 'results' and 'errors'. |
| 956 | +
|
| 957 | + 'Status' contains the outcome of this job. It can be one of |
| 958 | + 'Success', 'Partial Success', or 'Failure'. |
| 959 | +
|
| 960 | + 'Results' contains the successful global_key assignments, including |
| 961 | + global_keys that have been sanitized to Labelbox standards. |
| 962 | +
|
| 963 | + 'Errors' contains global_key assignments that failed, along with |
| 964 | + the reasons for failure. |
| 965 | + Examples: |
| 966 | + >>> global_key_data_row_inputs = [ |
| 967 | + {"data_row_id": "cl7asgri20yvo075b4vtfedjb", "global_key": "key1"}, |
| 968 | + {"data_row_id": "cl7asgri10yvg075b4pz176ht", "global_key": "key2"}, |
| 969 | + ] |
| 970 | + >>> job_result = client.assign_global_keys_to_data_rows(global_key_data_row_inputs) |
| 971 | + >>> print(job_result['status']) |
| 972 | + Partial Success |
| 973 | + >>> print(job_result['results']) |
| 974 | + [{'data_row_id': 'cl7tv9wry00hlka6gai588ozv', 'global_key': 'gk', 'sanitized': False}] |
| 975 | + >>> print(job_result['errors']) |
| 976 | + [{'data_row_id': 'cl7tpjzw30031ka6g4evqdfoy', 'global_key': 'gk"', 'error': 'Invalid global key'}] |
| 977 | + """ |
| 978 | + |
| 979 | + def _format_successful_rows(rows: Dict[str, str], |
| 980 | + sanitized: bool) -> List[Dict[str, str]]: |
| 981 | + return [{ |
| 982 | + 'data_row_id': r['dataRowId'], |
| 983 | + 'global_key': r['globalKey'], |
| 984 | + 'sanitized': sanitized |
| 985 | + } for r in rows] |
| 986 | + |
| 987 | + def _format_failed_rows(rows: Dict[str, str], |
| 988 | + error_msg: str) -> List[Dict[str, str]]: |
| 989 | + return [{ |
| 990 | + 'data_row_id': r['dataRowId'], |
| 991 | + 'global_key': r['globalKey'], |
| 992 | + 'error': error_msg |
| 993 | + } for r in rows] |
| 994 | + |
| 995 | + # Validate input dict |
| 996 | + validation_errors = [] |
| 997 | + for input in global_key_to_data_row_inputs: |
| 998 | + if "data_row_id" not in input or "global_key" not in input: |
| 999 | + validation_errors.append(input) |
| 1000 | + if len(validation_errors) > 0: |
| 1001 | + raise ValueError( |
| 1002 | + f"Must provide a list of dicts containing both `data_row_id` and `global_key`. The following dict(s) are invalid: {validation_errors}." |
| 1003 | + ) |
| 1004 | + |
| 1005 | + # Start assign global keys to data rows job |
| 1006 | + query_str = """mutation assignGlobalKeysToDataRowsPyApi($globalKeyDataRowLinks: [AssignGlobalKeyToDataRowInput!]!) { |
| 1007 | + assignGlobalKeysToDataRows(data: {assignInputs: $globalKeyDataRowLinks}) { |
| 1008 | + jobId |
| 1009 | + } |
| 1010 | + } |
| 1011 | + """ |
| 1012 | + params = { |
| 1013 | + 'globalKeyDataRowLinks': [{ |
| 1014 | + utils.camel_case(key): value for key, value in input.items() |
| 1015 | + } for input in global_key_to_data_row_inputs] |
| 1016 | + } |
| 1017 | + assign_global_keys_to_data_rows_job = self.execute(query_str, params) |
| 1018 | + |
| 1019 | + # Query string for retrieving job status and result, if job is done |
| 1020 | + result_query_str = """query assignGlobalKeysToDataRowsResultPyApi($jobId: ID!) { |
| 1021 | + assignGlobalKeysToDataRowsResult(jobId: {id: $jobId}) { |
| 1022 | + jobStatus |
| 1023 | + data { |
| 1024 | + sanitizedAssignments { |
| 1025 | + dataRowId |
| 1026 | + globalKey |
| 1027 | + } |
| 1028 | + invalidGlobalKeyAssignments { |
| 1029 | + dataRowId |
| 1030 | + globalKey |
| 1031 | + } |
| 1032 | + unmodifiedAssignments { |
| 1033 | + dataRowId |
| 1034 | + globalKey |
| 1035 | + } |
| 1036 | + accessDeniedAssignments { |
| 1037 | + dataRowId |
| 1038 | + globalKey |
| 1039 | + } |
| 1040 | + }}} |
| 1041 | + """ |
| 1042 | + result_params = { |
| 1043 | + "jobId": |
| 1044 | + assign_global_keys_to_data_rows_job["assignGlobalKeysToDataRows" |
| 1045 | + ]["jobId"] |
| 1046 | + } |
| 1047 | + |
| 1048 | + # Poll job status until finished, then retrieve results |
| 1049 | + sleep_time = 2 |
| 1050 | + start_time = time.time() |
| 1051 | + while True: |
| 1052 | + res = self.execute(result_query_str, result_params) |
| 1053 | + if res["assignGlobalKeysToDataRowsResult"][ |
| 1054 | + "jobStatus"] == "COMPLETE": |
| 1055 | + results, errors = [], [] |
| 1056 | + res = res['assignGlobalKeysToDataRowsResult']['data'] |
| 1057 | + # Successful assignments |
| 1058 | + results.extend( |
| 1059 | + _format_successful_rows(rows=res['sanitizedAssignments'], |
| 1060 | + sanitized=True)) |
| 1061 | + results.extend( |
| 1062 | + _format_successful_rows(rows=res['unmodifiedAssignments'], |
| 1063 | + sanitized=False)) |
| 1064 | + # Failed assignments |
| 1065 | + errors.extend( |
| 1066 | + _format_failed_rows(rows=res['invalidGlobalKeyAssignments'], |
| 1067 | + error_msg="Invalid global key")) |
| 1068 | + errors.extend( |
| 1069 | + _format_failed_rows(rows=res['accessDeniedAssignments'], |
| 1070 | + error_msg="Access denied to Data Row")) |
| 1071 | + |
| 1072 | + if not errors: |
| 1073 | + status = CollectionJobStatus.SUCCESS.value |
| 1074 | + elif errors and results: |
| 1075 | + status = CollectionJobStatus.PARTIAL_SUCCESS.value |
| 1076 | + else: |
| 1077 | + status = CollectionJobStatus.FAILURE.value |
| 1078 | + |
| 1079 | + if errors: |
| 1080 | + logger.warning( |
| 1081 | + "There are errors present. Please look at 'errors' in the returned dict for more details" |
| 1082 | + ) |
| 1083 | + |
| 1084 | + return { |
| 1085 | + "status": status, |
| 1086 | + "results": results, |
| 1087 | + "errors": errors, |
| 1088 | + } |
| 1089 | + elif res["assignGlobalKeysToDataRowsResult"][ |
| 1090 | + "jobStatus"] == "FAILED": |
| 1091 | + raise labelbox.exceptions.LabelboxError( |
| 1092 | + "Job assign_global_keys_to_data_rows failed.") |
| 1093 | + current_time = time.time() |
| 1094 | + if current_time - start_time > timeout_seconds: |
| 1095 | + raise labelbox.exceptions.TimeoutError( |
| 1096 | + "Timed out waiting for assign_global_keys_to_data_rows job to complete." |
| 1097 | + ) |
| 1098 | + time.sleep(sleep_time) |
| 1099 | + |
| 1100 | + def get_data_row_ids_for_global_keys( |
| 1101 | + self, |
| 1102 | + global_keys: List[str], |
| 1103 | + timeout_seconds=60) -> Dict[str, Union[str, List[Any]]]: |
| 1104 | + """ |
| 1105 | + Gets data row ids for a list of global keys. |
| 1106 | +
|
| 1107 | + Args: |
| 1108 | + A list of global keys |
| 1109 | + Returns: |
| 1110 | + Dictionary containing 'status', 'results' and 'errors'. |
| 1111 | +
|
| 1112 | + 'Status' contains the outcome of this job. It can be one of |
| 1113 | + 'Success', 'Partial Success', or 'Failure'. |
| 1114 | +
|
| 1115 | + 'Results' contains a list of data row ids successfully fetchced. It may |
| 1116 | + not necessarily contain all data rows requested. |
| 1117 | +
|
| 1118 | + 'Errors' contains a list of global_keys that could not be fetched, along |
| 1119 | + with the failure reason |
| 1120 | + Examples: |
| 1121 | + >>> job_result = client.get_data_row_ids_for_global_keys(["key1","key2"]) |
| 1122 | + >>> print(job_result['status']) |
| 1123 | + Partial Success |
| 1124 | + >>> print(job_result['results']) |
| 1125 | + ['cl7tv9wry00hlka6gai588ozv', 'cl7tv9wxg00hpka6gf8sh81bj'] |
| 1126 | + >>> print(job_result['errors']) |
| 1127 | + [{'global_key': 'asdf', 'error': 'Data Row not found'}] |
| 1128 | + """ |
| 1129 | + |
| 1130 | + def _format_failed_rows(rows: List[str], |
| 1131 | + error_msg: str) -> List[Dict[str, str]]: |
| 1132 | + return [{'global_key': r, 'error': error_msg} for r in rows] |
| 1133 | + |
| 1134 | + # Start get data rows for global keys job |
| 1135 | + query_str = """query getDataRowsForGlobalKeysPyApi($globalKeys: [ID!]!) { |
| 1136 | + dataRowsForGlobalKeys(where: {ids: $globalKeys}) { jobId}} |
| 1137 | + """ |
| 1138 | + params = {"globalKeys": global_keys} |
| 1139 | + data_rows_for_global_keys_job = self.execute(query_str, params) |
| 1140 | + |
| 1141 | + # Query string for retrieving job status and result, if job is done |
| 1142 | + result_query_str = """query getDataRowsForGlobalKeysResultPyApi($jobId: ID!) { |
| 1143 | + dataRowsForGlobalKeysResult(jobId: {id: $jobId}) { data { |
| 1144 | + fetchedDataRows { id } |
| 1145 | + notFoundGlobalKeys |
| 1146 | + accessDeniedGlobalKeys |
| 1147 | + deletedDataRowGlobalKeys |
| 1148 | + } jobStatus}} |
| 1149 | + """ |
| 1150 | + result_params = { |
| 1151 | + "jobId": |
| 1152 | + data_rows_for_global_keys_job["dataRowsForGlobalKeys"]["jobId"] |
| 1153 | + } |
| 1154 | + |
| 1155 | + # Poll job status until finished, then retrieve results |
| 1156 | + sleep_time = 2 |
| 1157 | + start_time = time.time() |
| 1158 | + while True: |
| 1159 | + res = self.execute(result_query_str, result_params) |
| 1160 | + if res["dataRowsForGlobalKeysResult"]['jobStatus'] == "COMPLETE": |
| 1161 | + data = res["dataRowsForGlobalKeysResult"]['data'] |
| 1162 | + results, errors = [], [] |
| 1163 | + results.extend([row['id'] for row in data['fetchedDataRows']]) |
| 1164 | + errors.extend( |
| 1165 | + _format_failed_rows(data['notFoundGlobalKeys'], |
| 1166 | + "Data Row not found")) |
| 1167 | + errors.extend( |
| 1168 | + _format_failed_rows(data['accessDeniedGlobalKeys'], |
| 1169 | + "Access denied to Data Row")) |
| 1170 | + errors.extend( |
| 1171 | + _format_failed_rows(data['deletedDataRowGlobalKeys'], |
| 1172 | + "Data Row deleted")) |
| 1173 | + if not errors: |
| 1174 | + status = CollectionJobStatus.SUCCESS.value |
| 1175 | + elif errors and results: |
| 1176 | + status = CollectionJobStatus.PARTIAL_SUCCESS.value |
| 1177 | + else: |
| 1178 | + status = CollectionJobStatus.FAILURE.value |
| 1179 | + |
| 1180 | + if errors: |
| 1181 | + logger.warning( |
| 1182 | + "There are errors present. Please look at 'errors' in the returned dict for more details" |
| 1183 | + ) |
| 1184 | + |
| 1185 | + return {"status": status, "results": results, "errors": errors} |
| 1186 | + elif res["dataRowsForGlobalKeysResult"]['jobStatus'] == "FAILED": |
| 1187 | + raise labelbox.exceptions.LabelboxError( |
| 1188 | + "Job dataRowsForGlobalKeys failed.") |
| 1189 | + current_time = time.time() |
| 1190 | + if current_time - start_time > timeout_seconds: |
| 1191 | + raise labelbox.exceptions.TimeoutError( |
| 1192 | + "Timed out waiting for get_data_rows_for_global_keys job to complete." |
| 1193 | + ) |
| 1194 | + time.sleep(sleep_time) |
| 1195 | + |
| 1196 | + def get_data_rows_for_global_keys( |
| 1197 | + self, |
| 1198 | + global_keys: List[str], |
| 1199 | + timeout_seconds=60) -> Dict[str, Union[str, List[Any]]]: |
| 1200 | + """ |
| 1201 | + Gets data rows for a list of global keys. |
| 1202 | +
|
| 1203 | + Args: |
| 1204 | + A list of global keys |
| 1205 | + Returns: |
| 1206 | + Dictionary containing 'status', 'results' and 'errors'. |
| 1207 | +
|
| 1208 | + 'Status' contains the outcome of this job. It can be one of |
| 1209 | + 'Success', 'Partial Success', or 'Failure'. |
| 1210 | +
|
| 1211 | + 'Results' contains a list of `DataRow` instances successfully fetchced. It may |
| 1212 | + not necessarily contain all data rows requested. |
| 1213 | +
|
| 1214 | + 'Errors' contains a list of global_keys that could not be fetched, along |
| 1215 | + with the failure reason |
| 1216 | + Examples: |
| 1217 | + >>> job_result = client.get_data_rows_for_global_keys(["key1","key2"]) |
| 1218 | + >>> print(job_result['status']) |
| 1219 | + Partial Success |
| 1220 | + >>> print(job_result['results']) |
| 1221 | + [<DataRow ID: cl7tvvybc00icka6ggipyh8tj>, <DataRow ID: cl7tvvyfp00igka6gblrw2idc>] |
| 1222 | + >>> print(job_result['errors']) |
| 1223 | + [{'global_key': 'asdf', 'error': 'Data Row not found'}] |
| 1224 | + """ |
| 1225 | + job_result = self.get_data_row_ids_for_global_keys( |
| 1226 | + global_keys, timeout_seconds) |
| 1227 | + |
| 1228 | + # Query for data row by data_row_id to ensure we get all fields populated in DataRow instances |
| 1229 | + data_rows = [] |
| 1230 | + for data_row_id in job_result['results']: |
| 1231 | + # TODO: Need to optimize this to run over a collection of data_row_ids |
| 1232 | + data_rows.append(self.get_data_row(data_row_id)) |
| 1233 | + |
| 1234 | + job_result['results'] = data_rows |
| 1235 | + |
| 1236 | + return job_result |
0 commit comments