Skip to content

Commit a746812

Browse files
author
Matt Sokoloff
committed
recommended changes
1 parent e1012f5 commit a746812

File tree

3 files changed

+74
-52
lines changed

3 files changed

+74
-52
lines changed

labelbox/exceptions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,6 @@ class UuidError(LabelboxError):
106106
pass
107107

108108

109-
class ValidationError(LabelboxError):
109+
class MALValidationError(LabelboxError):
110110
"""Raised when user input is invalid for MAL imports."""
111111
...

labelbox/schema/bulk_import_request.py

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99
import ndjson
1010
import requests
1111
from pydantic import BaseModel, validator
12-
from pydantic import ValidationError
13-
from typing_extensions import TypedDict, Literal
12+
from typing_extensions import Literal
1413
from typing import (Any, List, Optional, BinaryIO, Dict, Iterable, Tuple, Union,
1514
Type, Set)
1615

@@ -189,6 +188,8 @@ def create_from_url(cls,
189188
project_id (str): id of project for which predictions will be imported
190189
name (str): name of BulkImportRequest
191190
url (str): publicly accessible URL pointing to ndjson file containing predictions
191+
validate (bool): a flag indicating if there should be a validation
192+
if `url` is valid ndjson
192193
Returns:
193194
BulkImportRequest object
194195
"""
@@ -241,11 +242,13 @@ def create_from_objects(cls,
241242
}
242243
}``
243244
244-
Args:x
245+
Args:
245246
client (Client): a Labelbox client
246247
project_id (str): id of project for which predictions will be imported
247248
name (str): name of BulkImportRequest
248249
predictions (Iterable[dict]): iterable of dictionaries representing predictions
250+
validate (bool): a flag indicating if there should be a validation
251+
if `predictions` is valid ndjson
249252
Returns:
250253
BulkImportRequest object
251254
"""
@@ -313,7 +316,8 @@ def create_from_local_file(cls,
313316
return cls(client, response_data["createBulkImportRequest"])
314317

315318

316-
def _validate_ndjson(lines: Iterable[Dict[str, Any]], project) -> None:
319+
def _validate_ndjson(lines: Iterable[Dict[str, Any]],
320+
project: "labelbox.Project") -> None:
317321
"""
318322
Client side validation of an ndjson object.
319323
@@ -328,7 +332,7 @@ def _validate_ndjson(lines: Iterable[Dict[str, Any]], project) -> None:
328332
project (Project): id of project for which predictions will be imported
329333
330334
Raises:
331-
ValidationError: Raise for invalid NDJson
335+
MALValidationError: Raise for invalid NDJson
332336
UuidError: Duplicate UUID in upload
333337
"""
334338
data_row_ids = {
@@ -347,8 +351,8 @@ def _validate_ndjson(lines: Iterable[Dict[str, Any]], project) -> None:
347351
f'{uuid} already used in this import job, '
348352
'must be unique for the project.')
349353
uids.add(uuid)
350-
except (ValidationError, ValueError, TypeError, KeyError) as e:
351-
raise labelbox.exceptions.ValidationError(
354+
except (pydantic.ValidationError, ValueError, TypeError, KeyError) as e:
355+
raise labelbox.exceptions.MALValidationError(
352356
f"Invalid NDJson on line {idx}") from e
353357

354358

@@ -406,19 +410,19 @@ def get_mal_schemas(ontology):
406410
LabelboxID: str = pydantic.Field(..., min_length=25, max_length=25)
407411

408412

409-
class Bbox(TypedDict):
413+
class Bbox(BaseModel):
410414
top: float
411415
left: float
412416
height: float
413417
width: float
414418

415419

416-
class Point(TypedDict):
420+
class Point(BaseModel):
417421
x: float
418422
y: float
419423

420424

421-
class FrameLocation(TypedDict):
425+
class FrameLocation(BaseModel):
422426
end: int
423427
start: int
424428

@@ -428,7 +432,9 @@ class VideoSupported(BaseModel):
428432
frames: Optional[List[FrameLocation]]
429433

430434

431-
class UnionConstructor:
435+
#Base class for a special kind of union.
436+
# Compatible with pydantic. Improves error messages over a traditional union
437+
class SpecialUnion:
432438

433439
def __new__(cls, **kwargs):
434440
return cls.build(kwargs)
@@ -439,8 +445,8 @@ def __get_validators__(cls):
439445

440446
@classmethod
441447
def get_union_types(cls):
442-
if not issubclass(cls, UnionConstructor):
443-
raise TypeError("{} must be a subclass of UnionConstructor")
448+
if not issubclass(cls, SpecialUnion):
449+
raise TypeError("{} must be a subclass of SpecialUnion")
444450

445451
union_types = [x for x in cls.__orig_bases__ if hasattr(x, "__args__")]
446452
if len(union_types) < 1:
@@ -453,7 +459,16 @@ def get_union_types(cls):
453459
return union_types[0].__args__[0].__args__
454460

455461
@classmethod
456-
def build(cls: Any, data) -> "NDBase":
462+
def build(cls: Any, data: Union[dict, BaseModel]) -> "NDBase":
463+
"""
464+
Checks through all objects in the union to see which matches the input data.
465+
Args:
466+
data (Union[dict, BaseModel]) : The data for constructing one of the objects in the union
467+
raises:
468+
KeyError: data does not contain the determinant fields for any of the types supported by this SpecialUnion
469+
ValidationError: Error while trying to construct a specific object in the union
470+
471+
"""
457472
if isinstance(data, BaseModel):
458473
data = data.dict()
459474

@@ -506,7 +521,6 @@ class NDFeatureSchema(BaseModel):
506521

507522
class NDBase(NDFeatureSchema):
508523
ontology_type: str
509-
schemaId: str = LabelboxID
510524
uuid: UUID
511525
dataRow: DataRow
512526

@@ -553,7 +567,7 @@ class NDText(NDBase):
553567
#No feature schema to check
554568

555569

556-
class NDCheckList(VideoSupported, NDBase):
570+
class NDChecklist(VideoSupported, NDBase):
557571
ontology_type: Literal["checklist"] = "checklist"
558572
answers: List[NDFeatureSchema] = pydantic.Field(determinant=True)
559573

@@ -566,7 +580,7 @@ def validate_answers(cls, value, field):
566580

567581
def validate_feature_schemas(self, valid_feature_schemas):
568582
#Test top level feature schema for this tool
569-
super(NDCheckList, self).validate_feature_schemas(valid_feature_schemas)
583+
super(NDChecklist, self).validate_feature_schemas(valid_feature_schemas)
570584
#Test the feature schemas provided to the answer field
571585
if len(set([answer.schemaId for answer in self.answers])) != len(
572586
self.answers):
@@ -595,9 +609,9 @@ def validate_feature_schemas(self, valid_feature_schemas):
595609

596610
#A union with custom construction logic to improve error messages
597611
class NDClassification(
598-
UnionConstructor,
612+
SpecialUnion,
599613
Type[Union[NDText, NDRadio, # type: ignore
600-
NDCheckList]]):
614+
NDChecklist]]):
601615
...
602616

603617

@@ -665,7 +679,7 @@ class NDPoint(NDBaseTool):
665679
#Could check if points are positive
666680

667681

668-
class EntityLocation(TypedDict):
682+
class EntityLocation(BaseModel):
669683
start: int
670684
end: int
671685

@@ -676,6 +690,9 @@ class NDTextEntity(NDBaseTool):
676690

677691
@validator('location')
678692
def is_valid_location(cls, v):
693+
if isinstance(v, BaseModel):
694+
v = v.dict()
695+
679696
if len(v) < 2:
680697
raise ValueError(
681698
f"A line must have at least 2 points to be valid. Found {v}")
@@ -687,7 +704,7 @@ def is_valid_location(cls, v):
687704
return v
688705

689706

690-
class MaskFeatures(TypedDict):
707+
class MaskFeatures(BaseModel):
691708
instanceURI: str
692709
colorRGB: Union[List[int], Tuple[int, int, int]]
693710

@@ -698,6 +715,9 @@ class NDMask(NDBaseTool):
698715

699716
@validator('mask')
700717
def is_valid_mask(cls, v):
718+
if isinstance(v, BaseModel):
719+
v = v.dict()
720+
701721
colors = v['colorRGB']
702722
#Does the dtype matter? Can it be a float?
703723
if not isinstance(colors, (tuple, list)):
@@ -715,15 +735,17 @@ def is_valid_mask(cls, v):
715735

716736
#A union with custom construction logic to improve error messages
717737
class NDTool(
718-
UnionConstructor,
738+
SpecialUnion,
719739
Type[Union[NDMask, # type: ignore
720740
NDTextEntity, NDPoint, NDRectangle, NDPolyline,
721741
NDPolygon,]]):
722742
...
723743

724744

725-
class NDAnnotation(UnionConstructor,
726-
Type[Union[NDTool, NDClassification]]): # type: ignore
745+
class NDAnnotation(
746+
SpecialUnion,
747+
Type[Union[NDTool, # type: ignore
748+
NDClassification]]):
727749

728750
@classmethod
729751
def build(cls: Any, data) -> "NDBase":

0 commit comments

Comments
 (0)