Skip to content

Commit 34495c6

Browse files
author
Matt Sokoloff
committed
Merge branch 'ms/validation' of https://github.com/Labelbox/labelbox-python into ms/validation-part2
2 parents 9637c81 + 432d628 commit 34495c6

File tree

3 files changed

+133
-82
lines changed

3 files changed

+133
-82
lines changed

labelbox/schema/bulk_import_request.py

Lines changed: 130 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import enum
21
import json
32
import logging
43
import time
@@ -8,17 +7,16 @@
87
import pydantic
98
import backoff
109
import ndjson
11-
from pydantic.types import conlist, constr
10+
import labelbox
1211
import requests
1312
from labelbox import utils
14-
import labelbox.exceptions
1513
from labelbox.orm import query
1614
from labelbox.orm.db_object import DbObject
1715
from labelbox.orm.model import Field
1816
from labelbox.orm.model import Relationship
1917
from labelbox.schema.enums import BulkImportRequestState
2018
from pydantic import ValidationError
21-
from typing import Any, List, Optional, BinaryIO, Dict, Iterable, Tuple, Union
19+
from typing import Any, List, Optional, BinaryIO, Dict, Iterable, Tuple, Union, Type, Set
2220
from typing_extensions import TypedDict, Literal
2321

2422
NDJSON_MIME_TYPE = "application/x-ndjson"
@@ -336,18 +334,18 @@ def _validate_ndjson(lines: Iterable[Dict[str, Any]], project) -> None:
336334
for data_row in dataset.data_rows()
337335
}
338336
feature_schemas = get_mal_schemas(project.ontology())
339-
uids = set()
337+
uids: Set[str] = set()
340338
for idx, line in enumerate(lines):
341339
try:
342-
annotation = NDAnnotation(data=line)
343-
annotation.validate(data_row_ids, feature_schemas)
344-
uuid = annotation.data.uuid
340+
annotation = NDAnnotation(**line)
341+
annotation.validate_instance(data_row_ids, feature_schemas)
342+
uuid = str(annotation.uuid)
345343
if uuid in uids:
346344
raise labelbox.exceptions.UuidError(
347345
f'{uuid} already used in this import job, '
348346
'must be unique for the project.')
349347
uids.add(uuid)
350-
except (ValidationError, ValueError, KeyError) as e:
348+
except (ValidationError, ValueError, TypeError, KeyError) as e:
351349
raise labelbox.exceptions.NDJsonError(
352350
f"Invalid NDJson on line {idx}") from e
353351

@@ -403,7 +401,7 @@ def get_mal_schemas(ontology):
403401
return valid_feature_schemas
404402

405403

406-
LabelboxID = constr(min_length=25, max_length=25, strict=True)
404+
LabelboxID: str = pydantic.Field(..., min_length=25, max_length=25)
407405

408406

409407
class Bbox(TypedDict):
@@ -418,28 +416,50 @@ class Point(TypedDict):
418416
y: float
419417

420418

419+
class FrameLocation(TypedDict):
420+
end: int
421+
start: int
422+
423+
421424
class VideoSupported(BaseModel):
422425
#Note that frames are only allowed as top level inferences for video
423-
frames: Optional[List[TypedDict("frames", {"end": int, "start": int})]]
426+
frames: Optional[List[FrameLocation]]
424427

425428

426429
class UnionConstructor:
427-
types: List["NDBase"]
430+
431+
def __new__(cls, **kwargs):
432+
return cls.build(kwargs)
428433

429434
@classmethod
430435
def __get_validators__(cls):
431436
yield cls.build
432437

433438
@classmethod
434-
def build(cls, data):
439+
def get_union_types(cls):
440+
if not issubclass(cls, UnionConstructor):
441+
raise TypeError("{} must be a subclass of UnionConstructor")
442+
443+
union_types = [x for x in cls.__orig_bases__ if hasattr(x, "__args__")]
444+
if len(union_types) < 1:
445+
raise TypeError(
446+
"Class {cls} should inherit from a union of objects to build")
447+
if len(union_types) > 1:
448+
raise TypeError(
449+
f"Class {cls} should inherit from exactly one union of objects to build. Found {union_types}"
450+
)
451+
return union_types[0].__args__[0].__args__
452+
453+
@classmethod
454+
def build(cls: Any, data) -> "NDBase":
435455
if isinstance(data, BaseModel):
436456
data = data.dict()
437457

438458
top_level_fields = []
439459
max_match = 0
440460
matched = None
441461

442-
for type_ in cls.types:
462+
for type_ in cls.get_union_types():
443463
determinate_fields = type_.Config.determinants(type_)
444464
top_level_fields.append(determinate_fields)
445465
matches = sum([val in determinate_fields for val in data])
@@ -455,26 +475,43 @@ def build(cls, data):
455475
elif isinstance(data['answer'], str):
456476
matched = NDText
457477
else:
458-
raise ValidationError(
459-
f"Unexpected type for answer. Found {data['answer']}. Expected a string or a dict"
478+
raise TypeError(
479+
f"Unexpected type for answer field. Found {data['answer']}. Expected a string or a dict"
460480
)
461481
return matched(**data)
462482
else:
463483
raise KeyError(
464-
f"Expected classes with values {data} to have keys matching one of the following : {top_level_fields}"
484+
f"Invalid annotation. Must have one of the following keys : {top_level_fields}. Found {data}."
465485
)
466486

487+
@classmethod
488+
def schema(cls):
489+
results = {'definitions': {}}
490+
for cl in cls.get_union_types():
491+
schema = cl.schema()
492+
results['definitions'].update(schema.pop('definitions'))
493+
results[cl.__name__] = schema
494+
return results
495+
496+
497+
class DataRow(BaseModel):
498+
id: str
499+
500+
501+
class NDFeatureSchema(BaseModel):
502+
schemaId: str = LabelboxID
503+
467504

468-
class NDBase(BaseModel):
505+
class NDBase(NDFeatureSchema):
469506
ontology_type: str
470-
schemaId: LabelboxID
507+
schemaId: str = LabelboxID
471508
uuid: UUID
472-
dataRow: TypedDict('dataRow', {'id': LabelboxID})
509+
dataRow: DataRow
473510

474511
def validate_datarow(self, valid_datarows):
475-
if self.dataRow['id'] not in valid_datarows:
512+
if self.dataRow.id not in valid_datarows:
476513
raise ValueError(
477-
f"datarow {self.dataRow['id']} is not attached to the specified project"
514+
f"datarow {self.dataRow.id} is not attached to the specified project"
478515
)
479516

480517
def validate_feature_schemas(self, valid_feature_schemas):
@@ -488,6 +525,10 @@ def validate_feature_schemas(self, valid_feature_schemas):
488525
f"Schema id {self.schemaId} does not map to the assigned tool {valid_feature_schemas[self.schemaId]['tool']}"
489526
)
490527

528+
def validate_instance(self, valid_datarows, valid_feature_schemas):
529+
self.validate_feature_schemas(valid_feature_schemas)
530+
self.validate_datarow(valid_datarows)
531+
491532
class Config:
492533
#Users shouldn't add extra data to the payload
493534
extra = 'forbid'
@@ -512,49 +553,57 @@ class NDText(NDBase):
512553

513554
class NDCheckList(VideoSupported, NDBase):
514555
ontology_type: Literal["checklist"] = "checklist"
515-
answers: conlist(TypedDict('schemaId', {'schemaId': LabelboxID}),
516-
min_items=1) = pydantic.Field(determinant=True)
556+
answers: List[NDFeatureSchema] = pydantic.Field(determinant=True)
557+
558+
@validator('answers', pre=True)
559+
def validate_answers(cls, value, field):
560+
#conlist not working with mypy, so the non-empty check is done here.
561+
if not len(value):
562+
raise ValueError("Checklist answers should not be empty")
563+
return value
517564

518565
def validate_feature_schemas(self, valid_feature_schemas):
519566
#Test top level feature schema for this tool
520567
super(NDCheckList, self).validate_feature_schemas(valid_feature_schemas)
521568
#Test the feature schemas provided to the answer field
522-
if len(set([answer['schemaId'] for answer in self.answers])) != len(
569+
if len(set([answer.schemaId for answer in self.answers])) != len(
523570
self.answers):
524571
raise ValueError(
525572
f"Duplicated featureSchema found for checklist {self.uuid}")
526573
for answer in self.answers:
527574
options = valid_feature_schemas[self.schemaId]['options']
528-
if answer['schemaId'] not in options:
575+
if answer.schemaId not in options:
529576
raise ValueError(
530577
f"Feature schema provided to {self.ontology_type} invalid. Expected on of {options}. Found {answer}"
531578
)
532579

533580

534581
class NDRadio(VideoSupported, NDBase):
535582
ontology_type: Literal["radio"] = "radio"
536-
answer: TypedDict(
537-
'schemaId', {'schemaId': LabelboxID}) = pydantic.Field(determinant=True)
583+
answer: NDFeatureSchema = pydantic.Field(determinant=True)
538584

539585
def validate_feature_schemas(self, valid_feature_schemas):
540586
super(NDRadio, self).validate_feature_schemas(valid_feature_schemas)
541587
options = valid_feature_schemas[self.schemaId]['options']
542-
if self.answer['schemaId'] not in options:
588+
if self.answer.schemaId not in options:
543589
raise ValueError(
544-
f"Feature schema provided to {self.ontology_type} invalid. Expected on of {options}. Found {self.answer['schemaId']}"
590+
f"Feature schema provided to {self.ontology_type} invalid. Expected on of {options}. Found {self.answer.schemaId}"
545591
)
546592

547593

548-
class NDClassification(UnionConstructor):
549-
#Represents both subclasses and top level classifications
550-
types = [NDText, NDRadio, NDCheckList]
594+
#A union with custom construction logic to improve error messages
595+
class NDClassification(
596+
UnionConstructor,
597+
Type[Union[NDText, NDRadio, # type: ignore
598+
NDCheckList]]):
599+
...
551600

552601

553602
###### Tools ######
554603

555604

556-
class BaseTool(NDBase):
557-
classifications: List["NDClassification"] = []
605+
class NDBaseTool(NDBase):
606+
classifications: List[NDClassification] = []
558607

559608
#This is independent of our problem
560609
def validate_feature_schemas(self, valid_feature_schemas):
@@ -578,7 +627,7 @@ def validate_subclasses(cls, value, field):
578627
return results
579628

580629

581-
class NDPolygon(BaseTool):
630+
class NDPolygon(NDBaseTool):
582631
ontology_type: Literal["polygon"] = "polygon"
583632
polygon: List[Point] = pydantic.Field(determinant=True)
584633

@@ -590,7 +639,7 @@ def is_geom_valid(cls, v):
590639
return v
591640

592641

593-
class NDPolyline(BaseTool):
642+
class NDPolyline(NDBaseTool):
594643
ontology_type: Literal["line"] = "line"
595644
line: List[Point] = pydantic.Field(determinant=True)
596645

@@ -602,24 +651,26 @@ def is_geom_valid(cls, v):
602651
return v
603652

604653

605-
class NDRectangle(BaseTool):
654+
class NDRectangle(NDBaseTool):
606655
ontology_type: Literal["rectangle"] = "rectangle"
607656
bbox: Bbox = pydantic.Field(determinant=True)
608657
#Could check if points are positive
609658

610659

611-
class NDPoint(BaseTool):
660+
class NDPoint(NDBaseTool):
612661
ontology_type: Literal["point"] = "point"
613662
point: Point = pydantic.Field(determinant=True)
614663
#Could check if points are positive
615664

616665

617-
class NDTextEntity(BaseTool):
666+
class EntityLocation(TypedDict):
667+
start: int
668+
end: int
669+
670+
671+
class NDTextEntity(NDBaseTool):
618672
ontology_type: Literal["named-entity"] = "named-entity"
619-
location: TypedDict("TextLocation", {
620-
'start': int,
621-
'end': int
622-
}) = pydantic.Field(determinant=True)
673+
location: EntityLocation = pydantic.Field(determinant=True)
623674

624675
@validator('location')
625676
def is_valid_location(cls, v):
@@ -634,13 +685,14 @@ def is_valid_location(cls, v):
634685
return v
635686

636687

637-
class NDMask(BaseTool):
688+
class MaskFeatures(TypedDict):
689+
instanceURI: str
690+
colorRGB: Union[List[int], Tuple[int, int, int]]
691+
692+
693+
class NDMask(NDBaseTool):
638694
ontology_type: Literal["superpixel"] = "superpixel"
639-
mask: TypedDict(
640-
"mask", {
641-
'instanceURI': constr(min_length=5, strict=True),
642-
'colorRGB': Tuple[int, int, int]
643-
}) = pydantic.Field(determinant=True)
695+
mask: MaskFeatures = pydantic.Field(determinant=True)
644696

645697
@validator('mask')
646698
def is_valid_mask(cls, v):
@@ -659,34 +711,37 @@ def is_valid_mask(cls, v):
659711
return v
660712

661713

662-
class NDTool(UnionConstructor):
663-
#Tools and top level classifications
664-
types = [
665-
NDMask, NDTextEntity, NDPoint, NDRectangle, NDPolyline, NDPolygon,
666-
*NDClassification.types
667-
]
714+
#A union with custom construction logic to improve error messages
715+
class NDTool(
716+
UnionConstructor,
717+
Type[Union[NDMask, # type: ignore
718+
NDTextEntity, NDPoint, NDRectangle, NDPolyline,
719+
NDPolygon,]]):
720+
...
668721

669722

670-
#### Top level annotation. Can be used to construct and validate any annotation
671-
class NDAnnotation(BaseModel):
672-
data: Union[NDTool, NDClassification]
723+
class NDAnnotation(UnionConstructor,
724+
Type[Union[NDTool, NDClassification]]): # type: ignore
673725

674-
@validator('data', pre=True)
675-
def validate_data(cls, value):
676-
if not isinstance(value, dict):
726+
@classmethod
727+
def build(cls: Any, data) -> "NDBase":
728+
if not isinstance(data, dict):
677729
raise ValueError('value must be dict')
678-
#Catch keyerror to clean up error messages
679-
#Only raise if they both fail
680-
try:
681-
return NDTool.build(value)
682-
except KeyError as e1:
730+
errors = []
731+
for cl in cls.get_union_types():
683732
try:
684-
return NDClassification.build(value)
685-
except KeyError as e2:
686-
raise ValueError(
687-
f'Unable to construct tool or classification.\nTool: {e1}\nClassification: {e2}'
688-
)
733+
return cl(**data)
734+
except KeyError as e:
735+
errors.append(f"{cl.__name__}: {e}")
736+
737+
raise ValueError('Unable to construct any annotation.\n{}'.format(
738+
"\n".join(errors)))
689739

690-
def validate(self, valid_datarows, valid_feature_schemas):
691-
self.data.validate_feature_schemas(valid_feature_schemas)
692-
self.data.validate_datarow(valid_datarows)
740+
@classmethod
741+
def schema(cls):
742+
data = {'definitions': {}}
743+
for type_ in cls.get_union_types():
744+
schema_ = type_.schema()
745+
data['definitions'].update(schema_.pop('definitions'))
746+
data[type_.__name__] = schema_
747+
return data

setup.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,8 @@
2222
install_requires=[
2323
"backoff==1.10.0",
2424
"backports-datetime-fromisoformat==1.0.0; python_version < '3.7.0'",
25-
"dataclasses==0.7; python_version < '3.7.0'",
26-
"ndjson==0.3.1",
27-
"requests>=2.22.0",
28-
"google-api-core>=1.22.1",
29-
"pydantic"
25+
"dataclasses==0.7; python_version < '3.7.0'", "ndjson==0.3.1",
26+
"requests>=2.22.0", "google-api-core>=1.22.1", "pydantic"
3027
],
3128
classifiers=[
3229
'Development Status :: 3 - Alpha',

0 commit comments

Comments
 (0)