Skip to content

Commit 918d489

Browse files
author
Matt Sokoloff
committed
mypy changes
1 parent e6e0cfe commit 918d489

File tree

2 files changed

+64
-39
lines changed

2 files changed

+64
-39
lines changed

labelbox/schema/bulk_import_request.py

Lines changed: 63 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import backoff
1010
import ndjson
1111
from pydantic.types import conlist, constr
12+
from pydantic import Required
1213
import requests
1314
from labelbox import utils
1415
import labelbox.exceptions
@@ -18,7 +19,7 @@
1819
from labelbox.orm.model import Relationship
1920
from labelbox.schema.enums import BulkImportRequestState
2021
from pydantic import ValidationError
21-
from typing import Any, List, Optional, BinaryIO, Dict, Iterable, Tuple, Union
22+
from typing import Any, List, Optional, BinaryIO, Dict, Iterable, Tuple, Union, Type, Set
2223
from typing_extensions import TypedDict, Literal
2324

2425
NDJSON_MIME_TYPE = "application/x-ndjson"
@@ -336,7 +337,7 @@ def _validate_ndjson(lines: Iterable[Dict[str, Any]], project) -> None:
336337
for data_row in dataset.data_rows()
337338
}
338339
feature_schemas = get_mal_schemas(project.ontology())
339-
uids = set()
340+
uids: Set[str] = set()
340341
for idx, line in enumerate(lines):
341342
try:
342343
annotation = NDAnnotation(data=line)
@@ -346,7 +347,7 @@ def _validate_ndjson(lines: Iterable[Dict[str, Any]], project) -> None:
346347
raise labelbox.exceptions.UuidError(
347348
f'{uuid} already used in this import job, '
348349
'must be unique for the project.')
349-
uids.add(uuid)
350+
uids.add(str(uuid))
350351
except (ValidationError, ValueError, KeyError) as e:
351352
raise labelbox.exceptions.NDJsonError(
352353
f"Invalid NDJson on line {idx}") from e
@@ -403,7 +404,7 @@ def get_mal_schemas(ontology):
403404
return valid_feature_schemas
404405

405406

406-
LabelboxID = constr(min_length=25, max_length=25, strict=True)
407+
LabelboxID: str = pydantic.Field(..., min_length=25, max_length=25)
407408

408409

409410
class Bbox(TypedDict):
@@ -418,20 +419,25 @@ class Point(TypedDict):
418419
y: float
419420

420421

422+
class FrameLocation(TypedDict):
423+
end: int
424+
start: int
425+
426+
421427
class VideoSupported(BaseModel):
422428
#Note that frames are only allowed as top level inferences for video
423-
frames: Optional[List[TypedDict("frames", {"end": int, "start": int})]]
429+
frames: Optional[List[FrameLocation]]
424430

425431

426432
class UnionConstructor:
427-
types: List["NDBase"]
433+
types: Iterable[Type["NDBase"]]
428434

429435
@classmethod
430436
def __get_validators__(cls):
431437
yield cls.build
432438

433439
@classmethod
434-
def build(cls, data):
440+
def build(cls, data) -> "NDBase":
435441
if isinstance(data, BaseModel):
436442
data = data.dict()
437443

@@ -455,7 +461,7 @@ def build(cls, data):
455461
elif isinstance(data['answer'], str):
456462
matched = NDText
457463
else:
458-
raise ValidationError(
464+
raise TypeError(
459465
f"Unexpected type for answer. Found {data['answer']}. Expected a string or a dict"
460466
)
461467
return matched(**data)
@@ -465,16 +471,24 @@ def build(cls, data):
465471
)
466472

467473

468-
class NDBase(BaseModel):
474+
class DataRow(BaseModel):
475+
id: str
476+
477+
478+
class NDFeatureSchema(BaseModel):
479+
schemaId: str = LabelboxID
480+
481+
482+
class NDBase(NDFeatureSchema):
469483
ontology_type: str
470-
schemaId: LabelboxID
484+
schemaId: str = LabelboxID
471485
uuid: UUID
472-
dataRow: TypedDict('dataRow', {'id': LabelboxID})
486+
dataRow: DataRow
473487

474488
def validate_datarow(self, valid_datarows):
475-
if self.dataRow['id'] not in valid_datarows:
489+
if self.dataRow.id not in valid_datarows:
476490
raise ValueError(
477-
f"datarow {self.dataRow['id']} is not attached to the specified project"
491+
f"datarow {self.dataRow.id} is not attached to the specified project"
478492
)
479493

480494
def validate_feature_schemas(self, valid_feature_schemas):
@@ -493,7 +507,7 @@ class Config:
493507
extra = 'forbid'
494508

495509
@staticmethod
496-
def determinants(parent_cls) -> None:
510+
def determinants(parent_cls) -> List[str]:
497511
#This is a hack for better error messages
498512
return [
499513
k for k, v in parent_cls.__fields__.items()
@@ -512,42 +526,47 @@ class NDText(NDBase):
512526

513527
class NDCheckList(VideoSupported, NDBase):
514528
ontology_type: Literal["checklist"] = "checklist"
515-
answers: conlist(TypedDict('schemaId', {'schemaId': LabelboxID}),
516-
min_items=1) = pydantic.Field(determinant=True)
529+
answers: List[NDFeatureSchema] = pydantic.Field(determinant=True)
530+
531+
@validator('answers', pre=True)
532+
def validate_answers(cls, value, field):
533+
#constr not working with mypy.
534+
if not len(value):
535+
raise ValueError("Checklist answers should not be empty")
536+
return value
517537

518538
def validate_feature_schemas(self, valid_feature_schemas):
519539
#Test top level feature schema for this tool
520540
super(NDCheckList, self).validate_feature_schemas(valid_feature_schemas)
521541
#Test the feature schemas provided to the answer field
522-
if len(set([answer['schemaId'] for answer in self.answers])) != len(
542+
if len(set([answer.schemaId for answer in self.answers])) != len(
523543
self.answers):
524544
raise ValueError(
525545
f"Duplicated featureSchema found for checklist {self.uuid}")
526546
for answer in self.answers:
527547
options = valid_feature_schemas[self.schemaId]['options']
528-
if answer['schemaId'] not in options:
548+
if answer.schemaId not in options:
529549
raise ValueError(
530550
f"Feature schema provided to {self.ontology_type} invalid. Expected on of {options}. Found {answer}"
531551
)
532552

533553

534554
class NDRadio(VideoSupported, NDBase):
535555
ontology_type: Literal["radio"] = "radio"
536-
answer: TypedDict(
537-
'schemaId', {'schemaId': LabelboxID}) = pydantic.Field(determinant=True)
556+
answer: NDFeatureSchema = pydantic.Field(determinant=True)
538557

539558
def validate_feature_schemas(self, valid_feature_schemas):
540559
super(NDRadio, self).validate_feature_schemas(valid_feature_schemas)
541560
options = valid_feature_schemas[self.schemaId]['options']
542-
if self.answer['schemaId'] not in options:
561+
if self.answer.schemaId not in options:
543562
raise ValueError(
544-
f"Feature schema provided to {self.ontology_type} invalid. Expected on of {options}. Found {self.answer['schemaId']}"
563+
f"Feature schema provided to {self.ontology_type} invalid. Expected on of {options}. Found {self.answer.schemaId}"
545564
)
546565

547566

548567
class NDClassification(UnionConstructor):
549568
#Represents both subclasses and top level classifications
550-
types = [NDText, NDRadio, NDCheckList]
569+
types: Iterable[Type[NDBase]] = {NDText, NDRadio, NDCheckList}
551570

552571

553572
###### Tools ######
@@ -614,12 +633,14 @@ class NDPoint(BaseTool):
614633
#Could check if points are positive
615634

616635

636+
class EntityLocation(TypedDict):
637+
start: int
638+
end: int
639+
640+
617641
class NDTextEntity(BaseTool):
618642
ontology_type: Literal["named-entity"] = "named-entity"
619-
location: TypedDict("TextLocation", {
620-
'start': int,
621-
'end': int
622-
}) = pydantic.Field(determinant=True)
643+
location: EntityLocation = pydantic.Field(determinant=True)
623644

624645
@validator('location')
625646
def is_valid_location(cls, v):
@@ -634,13 +655,14 @@ def is_valid_location(cls, v):
634655
return v
635656

636657

658+
class MaskFeatures(TypedDict):
659+
instanceURI: str
660+
colorRGB: Union[List[int], Tuple[int, int, int]]
661+
662+
637663
class NDMask(BaseTool):
638664
ontology_type: Literal["superpixel"] = "superpixel"
639-
mask: TypedDict(
640-
"mask", {
641-
'instanceURI': constr(min_length=5, strict=True),
642-
'colorRGB': Tuple[int, int, int]
643-
}) = pydantic.Field(determinant=True)
665+
mask: MaskFeatures = pydantic.Field(determinant=True)
644666

645667
@validator('mask')
646668
def is_valid_mask(cls, v):
@@ -661,15 +683,19 @@ def is_valid_mask(cls, v):
661683

662684
class NDTool(UnionConstructor):
663685
#Tools and top level classifications
664-
types = [
665-
NDMask, NDTextEntity, NDPoint, NDRectangle, NDPolyline, NDPolygon,
666-
*NDClassification.types
667-
]
686+
types: Iterable[Type[NDBase]] = {
687+
NDMask,
688+
NDTextEntity,
689+
NDPoint,
690+
NDRectangle,
691+
NDPolyline,
692+
NDPolygon,
693+
}
668694

669695

670696
#### Top level annotation. Can be used to construct and validate any annotation
671697
class NDAnnotation(BaseModel):
672-
data: Union[NDTool, NDClassification]
698+
data: NDBase
673699

674700
@validator('data', pre=True)
675701
def validate_data(cls, value):

tests/integration/test_ndjon_validation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,7 @@ def test_incorrect_mask(segmentation_inference, configured_project):
142142
with pytest.raises(NDJsonError):
143143
_validate_ndjson([seg], configured_project)
144144

145-
seg['mask']['colorRGB'] = [0, 0, 0]
146-
seg['mask']['instanceURI'] = 1
145+
seg['mask']['colorRGB'] = [0, 0]
147146
with pytest.raises(NDJsonError):
148147
_validate_ndjson([seg], configured_project)
149148

0 commit comments

Comments
 (0)