Skip to content

Commit 8a8b7a8

Browse files
author
Matt Sokoloff
committed
Merge branch 'develop' of https://github.com/Labelbox/labelbox-python into ms/validation-part2
2 parents cabc82c + 829884a commit 8a8b7a8

File tree

13 files changed

+482
-378
lines changed

13 files changed

+482
-378
lines changed

Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
FROM python:3.7
22

3-
RUN pip install pytest
3+
RUN pip install pytest pytest-cases
44

55

66
WORKDIR /usr/src/labelbox

labelbox/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
name = "labelbox"
2-
__version__ = "2.4.10"
2+
__version__ = "2.4.11"
33

44
from labelbox.client import Client
55
from labelbox.schema.bulk_import_request import BulkImportRequest
@@ -15,4 +15,4 @@
1515
from labelbox.schema.asset_metadata import AssetMetadata
1616
from labelbox.schema.webhook import Webhook
1717
from labelbox.schema.prediction import Prediction, PredictionModel
18-
from labelbox.schema.ontology import Ontology
18+
from labelbox.schema.ontology import Ontology

labelbox/exceptions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,6 @@ class UuidError(LabelboxError):
106106
pass
107107

108108

109-
class NDJsonError(LabelboxError):
110-
"""Raised when an ndjson line is invalid."""
109+
class MALValidationError(LabelboxError):
110+
"""Raised when user input is invalid for MAL imports."""
111111
...

labelbox/schema/bulk_import_request.py

Lines changed: 56 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,24 @@
11
import json
2-
import logging
32
import time
4-
from pathlib import Path
53
from uuid import UUID, uuid4
6-
from pydantic import BaseModel, validator
4+
5+
import logging
6+
from pathlib import Path
77
import pydantic
88
import backoff
99
import ndjson
10-
import labelbox
1110
import requests
11+
from pydantic import BaseModel, validator
12+
from typing_extensions import Literal
13+
from typing import (Any, List, Optional, BinaryIO, Dict, Iterable, Tuple, Union,
14+
Type, Set)
15+
16+
import labelbox
1217
from labelbox import utils
1318
from labelbox.orm import query
1419
from labelbox.orm.db_object import DbObject
15-
from labelbox.orm.model import Field
16-
from labelbox.orm.model import Relationship
20+
from labelbox.orm.model import Field, Relationship
1721
from labelbox.schema.enums import BulkImportRequestState
18-
from pydantic import ValidationError
19-
from typing import Any, List, Optional, BinaryIO, Dict, Iterable, Tuple, Union, Type, Set
20-
from typing_extensions import TypedDict, Literal
2122

2223
NDJSON_MIME_TYPE = "application/x-ndjson"
2324
logger = logging.getLogger(__name__)
@@ -187,6 +188,8 @@ def create_from_url(cls,
187188
project_id (str): id of project for which predictions will be imported
188189
name (str): name of BulkImportRequest
189190
url (str): publicly accessible URL pointing to ndjson file containing predictions
191+
validate (bool): a flag indicating if there should be a validation
192+
if `url` is valid ndjson
190193
Returns:
191194
BulkImportRequest object
192195
"""
@@ -219,7 +222,7 @@ def create_from_objects(cls,
219222
client,
220223
project_id: str,
221224
name: str,
222-
predictions: Iterable[dict],
225+
predictions: Iterable[Dict],
223226
validate=True) -> 'BulkImportRequest':
224227
"""
225228
Creates a `BulkImportRequest` from an iterable of dictionaries.
@@ -239,11 +242,13 @@ def create_from_objects(cls,
239242
}
240243
}``
241244
242-
Args:x
245+
Args:
243246
client (Client): a Labelbox client
244247
project_id (str): id of project for which predictions will be imported
245248
name (str): name of BulkImportRequest
246249
predictions (Iterable[dict]): iterable of dictionaries representing predictions
250+
validate (bool): a flag indicating if there should be a validation
251+
if `predictions` is valid ndjson
247252
Returns:
248253
BulkImportRequest object
249254
"""
@@ -311,7 +316,8 @@ def create_from_local_file(cls,
311316
return cls(client, response_data["createBulkImportRequest"])
312317

313318

314-
def _validate_ndjson(lines: Iterable[Dict[str, Any]], project) -> None:
319+
def _validate_ndjson(lines: Iterable[Dict[str, Any]],
320+
project: "labelbox.Project") -> None:
315321
"""
316322
Client side validation of an ndjson object.
317323
@@ -326,7 +332,7 @@ def _validate_ndjson(lines: Iterable[Dict[str, Any]], project) -> None:
326332
project (Project): id of project for which predictions will be imported
327333
328334
Raises:
329-
NDJsonError: Raise for invalid NDJson
335+
MALValidationError: Raise for invalid NDJson
330336
UuidError: Duplicate UUID in upload
331337
"""
332338
data_row_ids = {
@@ -345,8 +351,8 @@ def _validate_ndjson(lines: Iterable[Dict[str, Any]], project) -> None:
345351
f'{uuid} already used in this import job, '
346352
'must be unique for the project.')
347353
uids.add(uuid)
348-
except (ValidationError, ValueError, TypeError, KeyError) as e:
349-
raise labelbox.exceptions.NDJsonError(
354+
except (pydantic.ValidationError, ValueError, TypeError, KeyError) as e:
355+
raise labelbox.exceptions.MALValidationError(
350356
f"Invalid NDJson on line {idx}") from e
351357

352358

@@ -404,19 +410,19 @@ def get_mal_schemas(ontology):
404410
LabelboxID: str = pydantic.Field(..., min_length=25, max_length=25)
405411

406412

407-
class Bbox(TypedDict):
413+
class Bbox(BaseModel):
408414
top: float
409415
left: float
410416
height: float
411417
width: float
412418

413419

414-
class Point(TypedDict):
420+
class Point(BaseModel):
415421
x: float
416422
y: float
417423

418424

419-
class FrameLocation(TypedDict):
425+
class FrameLocation(BaseModel):
420426
end: int
421427
start: int
422428

@@ -426,7 +432,9 @@ class VideoSupported(BaseModel):
426432
frames: Optional[List[FrameLocation]]
427433

428434

429-
class UnionConstructor:
435+
#Base class for a special kind of union.
436+
# Compatible with pydantic. Improves error messages over a traditional union
437+
class SpecialUnion:
430438

431439
def __new__(cls, **kwargs):
432440
return cls.build(kwargs)
@@ -437,8 +445,8 @@ def __get_validators__(cls):
437445

438446
@classmethod
439447
def get_union_types(cls):
440-
if not issubclass(cls, UnionConstructor):
441-
raise TypeError("{} must be a subclass of UnionConstructor")
448+
if not issubclass(cls, SpecialUnion):
449+
raise TypeError("{} must be a subclass of SpecialUnion")
442450

443451
union_types = [x for x in cls.__orig_bases__ if hasattr(x, "__args__")]
444452
if len(union_types) < 1:
@@ -451,7 +459,16 @@ def get_union_types(cls):
451459
return union_types[0].__args__[0].__args__
452460

453461
@classmethod
454-
def build(cls: Any, data) -> "NDBase":
462+
def build(cls: Any, data: Union[dict, BaseModel]) -> "NDBase":
463+
"""
464+
Checks through all objects in the union to see which matches the input data.
465+
Args:
466+
data (Union[dict, BaseModel]) : The data for constructing one of the objects in the union
467+
raises:
468+
KeyError: data does not contain the determinant fields for any of the types supported by this SpecialUnion
469+
ValidationError: Error while trying to construct a specific object in the union
470+
471+
"""
455472
if isinstance(data, BaseModel):
456473
data = data.dict()
457474

@@ -504,7 +521,6 @@ class NDFeatureSchema(BaseModel):
504521

505522
class NDBase(NDFeatureSchema):
506523
ontology_type: str
507-
schemaId: str = LabelboxID
508524
uuid: UUID
509525
dataRow: DataRow
510526

@@ -551,7 +567,7 @@ class NDText(NDBase):
551567
#No feature schema to check
552568

553569

554-
class NDCheckList(VideoSupported, NDBase):
570+
class NDChecklist(VideoSupported, NDBase):
555571
ontology_type: Literal["checklist"] = "checklist"
556572
answers: List[NDFeatureSchema] = pydantic.Field(determinant=True)
557573

@@ -564,7 +580,7 @@ def validate_answers(cls, value, field):
564580

565581
def validate_feature_schemas(self, valid_feature_schemas):
566582
#Test top level feature schema for this tool
567-
super(NDCheckList, self).validate_feature_schemas(valid_feature_schemas)
583+
super(NDChecklist, self).validate_feature_schemas(valid_feature_schemas)
568584
#Test the feature schemas provided to the answer field
569585
if len(set([answer.schemaId for answer in self.answers])) != len(
570586
self.answers):
@@ -593,9 +609,9 @@ def validate_feature_schemas(self, valid_feature_schemas):
593609

594610
#A union with custom construction logic to improve error messages
595611
class NDClassification(
596-
UnionConstructor,
612+
SpecialUnion,
597613
Type[Union[NDText, NDRadio, # type: ignore
598-
NDCheckList]]):
614+
NDChecklist]]):
599615
...
600616

601617

@@ -663,7 +679,7 @@ class NDPoint(NDBaseTool):
663679
#Could check if points are positive
664680

665681

666-
class EntityLocation(TypedDict):
682+
class EntityLocation(BaseModel):
667683
start: int
668684
end: int
669685

@@ -674,6 +690,9 @@ class NDTextEntity(NDBaseTool):
674690

675691
@validator('location')
676692
def is_valid_location(cls, v):
693+
if isinstance(v, BaseModel):
694+
v = v.dict()
695+
677696
if len(v) < 2:
678697
raise ValueError(
679698
f"A line must have at least 2 points to be valid. Found {v}")
@@ -685,7 +704,7 @@ def is_valid_location(cls, v):
685704
return v
686705

687706

688-
class MaskFeatures(TypedDict):
707+
class MaskFeatures(BaseModel):
689708
instanceURI: str
690709
colorRGB: Union[List[int], Tuple[int, int, int]]
691710

@@ -696,6 +715,9 @@ class NDMask(NDBaseTool):
696715

697716
@validator('mask')
698717
def is_valid_mask(cls, v):
718+
if isinstance(v, BaseModel):
719+
v = v.dict()
720+
699721
colors = v['colorRGB']
700722
#Does the dtype matter? Can it be a float?
701723
if not isinstance(colors, (tuple, list)):
@@ -713,15 +735,17 @@ def is_valid_mask(cls, v):
713735

714736
#A union with custom construction logic to improve error messages
715737
class NDTool(
716-
UnionConstructor,
738+
SpecialUnion,
717739
Type[Union[NDMask, # type: ignore
718740
NDTextEntity, NDPoint, NDRectangle, NDPolyline,
719741
NDPolygon,]]):
720742
...
721743

722744

723-
class NDAnnotation(UnionConstructor,
724-
Type[Union[NDTool, NDClassification]]): # type: ignore
745+
class NDAnnotation(
746+
SpecialUnion,
747+
Type[Union[NDTool, # type: ignore
748+
NDClassification]]):
725749

726750
@classmethod
727751
def build(cls: Any, data) -> "NDBase":

labelbox/schema/project.py

Lines changed: 67 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66
from pathlib import Path
77
import time
8-
from typing import Union, Iterable
8+
from typing import Dict, List, Union, Iterable
99
from urllib.parse import urlparse
1010

1111
from labelbox import utils
@@ -177,6 +177,69 @@ def export_labels(self, timeout_seconds=60):
177177
self.uid)
178178
time.sleep(sleep_time)
179179

180+
def upsert_instructions(self, instructions_file: str):
181+
"""
182+
* Uploads instructions to the UI. Running more than once will replace the instructions
183+
184+
Args:
185+
instructions_file (str): Path to a local file.
186+
* Must be either a pdf, text, or html file.
187+
188+
Raises:
189+
ValueError:
190+
* project must be setup
191+
* instructions file must end with one of ".text", ".txt", ".pdf", ".html"
192+
"""
193+
194+
if self.setup_complete is None:
195+
raise ValueError(
196+
"Cannot attach instructions to a project that has not been set up."
197+
)
198+
199+
frontend = self.labeling_frontend()
200+
frontendId = frontend.uid
201+
202+
if frontend.name != "Editor":
203+
logger.warn(
204+
f"This function has only been tested to work with the Editor front end. Found %s",
205+
frontend.name)
206+
207+
supported_instruction_formats = (".text", ".txt", ".pdf", ".html")
208+
if not instructions_file.endswith(supported_instruction_formats):
209+
raise ValueError(
210+
f"instructions_file must end with one of {supported_instruction_formats}. Found {instructions_file}"
211+
)
212+
213+
lfo = list(self.labeling_frontend_options())[-1]
214+
instructions_url = self.client.upload_file(instructions_file)
215+
customization_options = json.loads(lfo.customization_options)
216+
customization_options['projectInstructions'] = instructions_url
217+
option_id = lfo.uid
218+
219+
self.client.execute(
220+
"""mutation UpdateFrontendWithExistingOptionsPyApi (
221+
$frontendId: ID!,
222+
$optionsId: ID!,
223+
$name: String!,
224+
$description: String!,
225+
$customizationOptions: String!
226+
) {
227+
updateLabelingFrontend(
228+
where: {id: $frontendId},
229+
data: {name: $name, description: $description}
230+
) {id}
231+
updateLabelingFrontendOptions(
232+
where: {id: $optionsId},
233+
data: {customizationOptions: $customizationOptions}
234+
) {id}
235+
}""", {
236+
"frontendId": frontendId,
237+
"optionsId": option_id,
238+
"name": frontend.name,
239+
"description": "Video, image, and text annotation",
240+
"customizationOptions": json.dumps(customization_options)
241+
})
242+
180243
def labeler_performance(self):
181244
""" Returns the labeler performances for this Project.
182245
@@ -486,8 +549,8 @@ def enable_model_assisted_labeling(self, toggle: bool = True) -> bool:
486549
def upload_annotations(
487550
self,
488551
name: str,
489-
annotations: Union[str, Union[str, Path], Iterable[dict]],
490-
validate=True) -> 'BulkImportRequest': # type: ignore
552+
annotations: Union[str, Path, Iterable[Dict]],
553+
validate: bool = True) -> 'BulkImportRequest': # type: ignore
491554
""" Uploads annotations to a new Editor project.
492555
493556
Args:
@@ -497,7 +560,7 @@ def upload_annotations(
497560
ndjson file
498561
OR local path to an ndjson file
499562
OR iterable of annotation rows
500-
validate (str):
563+
validate (bool):
501564
Whether or not to validate the payload before uploading.
502565
Returns:
503566
BulkImportRequest

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ requests==2.22.0
22
ndjson==0.3.1
33
backoff==1.10.0
44
google-api-core>=1.22.1
5+
pydantic==1.8

0 commit comments

Comments
 (0)