|
1 | | -# type: ignore |
2 | | -from datetime import datetime, timezone |
3 | 1 | import json |
4 | | -from typing import List, Dict |
5 | | -from collections import defaultdict |
6 | | - |
7 | 2 | import logging |
8 | 3 | import mimetypes |
9 | 4 | import os |
| 5 | +from collections import defaultdict |
| 6 | +from datetime import datetime, timezone |
| 7 | +from typing import List, Dict |
10 | 8 |
|
11 | | -from google.api_core import retry |
12 | 9 | import requests |
13 | 10 | import requests.exceptions |
| 11 | +from google.api_core import retry |
14 | 12 |
|
15 | 13 | import labelbox.exceptions |
16 | | -from labelbox import utils |
17 | 14 | from labelbox import __version__ as SDK_VERSION |
| 15 | +from labelbox import utils |
18 | 16 | from labelbox.orm import query |
19 | 17 | from labelbox.orm.db_object import DbObject |
20 | 18 | from labelbox.orm.model import Entity |
21 | 19 | from labelbox.pagination import PaginatedCollection |
| 20 | +from labelbox.schema import role |
22 | 21 | from labelbox.schema.data_row_metadata import DataRowMetadataOntology |
23 | 22 | from labelbox.schema.iam_integration import IAMIntegration |
24 | | -from labelbox.schema import role |
25 | 23 | from labelbox.schema.ontology import Tool, Classification |
26 | 24 |
|
27 | 25 | logger = logging.getLogger(__name__) |
@@ -354,7 +352,7 @@ def upload_data(self, |
354 | 352 | data=request_data, |
355 | 353 | files={ |
356 | 354 | "1": (filename, content, content_type) if |
357 | | - (filename and content_type) else content |
| 355 | + (filename and content_type) else content |
358 | 356 | }) |
359 | 357 |
|
360 | 358 | if response.status_code == 502: |
@@ -518,7 +516,7 @@ def _create(self, db_object_type, data): |
518 | 516 | # Also convert Labelbox object values to their UIDs. |
519 | 517 | data = { |
520 | 518 | db_object_type.attribute(attr) if isinstance(attr, str) else attr: |
521 | | - value.uid if isinstance(value, DbObject) else value |
| 519 | + value.uid if isinstance(value, DbObject) else value |
522 | 520 | for attr, value in data.items() |
523 | 521 | } |
524 | 522 |
|
@@ -702,8 +700,8 @@ def get_data_row_ids_for_external_ids( |
702 | 700 | for i in range(0, len(external_ids), max_ids_per_request): |
703 | 701 | for row in self.execute( |
704 | 702 | query_str, |
705 | | - {'externalId_in': external_ids[i:i + max_ids_per_request] |
706 | | - })['externalIdsToDataRowIds']: |
| 703 | + {'externalId_in': external_ids[i:i + max_ids_per_request] |
| 704 | + })['externalIdsToDataRowIds']: |
707 | 705 | result[row['externalId']].append(row['dataRowId']) |
708 | 706 | return result |
709 | 707 |
|
@@ -896,3 +894,57 @@ def create_feature_schema(self, normalized): |
896 | 894 | # But the features are the same so we just grab the feature schema id |
897 | 895 | res['id'] = res['normalized']['featureSchemaId'] |
898 | 896 | return Entity.FeatureSchema(self, res) |
| 897 | + |
| 898 | + def get_batch(self, batch_id: str): |
| 899 | + """Gets a single Batch using its ID |
| 900 | +
|
| 901 | + Args: |
| 902 | + batch_id: Id of the batch |
| 903 | +
|
| 904 | + Returns: |
| 905 | + The sought Batch |
| 906 | + """ |
| 907 | + |
| 908 | + return self._get_single(Entity.Batch, batch_id) |
| 909 | + |
| 910 | + def create_batch(self, name: str, project, data_rows: List[str], priority: int): |
| 911 | + """Create a batch of data rows to send to a project |
| 912 | +
|
| 913 | + >>> data_rows = ['<data-row-id>', ...] |
| 914 | + >>> project = client.get("<project-id>") |
| 915 | + >>> client.create_batch(name="low-confidence-images", project=project, data_rows=data_rows) |
| 916 | +
|
| 917 | + Args: |
| 918 | + name: Descriptive name for the batch, must be unique per project |
| 919 | + project: The project to send the batch to |
| 920 | + data_rows: A list of data rows ids |
| 921 | + priority: the default priority for the datarows, lowest priority by default |
| 922 | +
|
| 923 | + Returns: |
| 924 | + The created batch |
| 925 | + """ |
| 926 | + |
| 927 | + if isinstance(project, Entity.Project): |
| 928 | + project_id = project.uid |
| 929 | + elif isinstance(project, str): |
| 930 | + project_id = project |
| 931 | + else: |
| 932 | + raise ValueError("You must pass a project id or a Project") |
| 933 | + |
| 934 | + data_row_ids = [] |
| 935 | + for dr in data_rows: |
| 936 | + pass |
| 937 | + |
| 938 | + query_str = """mutation createBatchPyApi($name: String!, $dataRowIds: [ID!]!, $priority: Int!){ |
| 939 | + createBatch(){ |
| 940 | + %s |
| 941 | + } |
| 942 | + }""" |
| 943 | + |
| 944 | + result = self.execute(query_str, { |
| 945 | + "name": name, |
| 946 | + "projectId": project_id, |
| 947 | + "dataRowIds": data_row_ids, |
| 948 | + "priority": priority |
| 949 | + }) |
| 950 | + return Entity.Batch(self, result['createModel']) |
0 commit comments