Skip to content

Commit 432d628

Browse files
author
Matt Sokoloff
committed
minor cleanup
1 parent ce7e1c0 commit 432d628

File tree

1 file changed

+47
-43
lines changed

1 file changed

+47
-43
lines changed

labelbox/schema/bulk_import_request.py

Lines changed: 47 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,22 @@
1-
import enum
21
import json
32
import logging
43
import time
54
from pathlib import Path
6-
import typing
75
from uuid import UUID, uuid4
86
from pydantic import BaseModel, validator
97
import pydantic
108
import backoff
119
import ndjson
12-
from pydantic.types import conlist, constr
13-
from pydantic import Required
14-
from pydantic.dataclasses import dataclass
1510
import labelbox
1611
import requests
1712
from labelbox import utils
18-
1913
from labelbox.orm import query
2014
from labelbox.orm.db_object import DbObject
2115
from labelbox.orm.model import Field
2216
from labelbox.orm.model import Relationship
2317
from labelbox.schema.enums import BulkImportRequestState
2418
from pydantic import ValidationError
25-
from typing import Any, Generic, List, Optional, BinaryIO, Dict, Iterable, Tuple, TypeVar, Union, Type, Set
19+
from typing import Any, List, Optional, BinaryIO, Dict, Iterable, Tuple, Union, Type, Set
2620
from typing_extensions import TypedDict, Literal
2721

2822
NDJSON_MIME_TYPE = "application/x-ndjson"
@@ -343,15 +337,15 @@ def _validate_ndjson(lines: Iterable[Dict[str, Any]], project) -> None:
343337
uids: Set[str] = set()
344338
for idx, line in enumerate(lines):
345339
try:
346-
annotation = NDAnnotation(data=line)
347-
annotation.validate(data_row_ids, feature_schemas)
348-
uuid = str(annotation.data.uuid)
340+
annotation = NDAnnotation(**line)
341+
annotation.validate_instance(data_row_ids, feature_schemas)
342+
uuid = str(annotation.uuid)
349343
if uuid in uids:
350344
raise labelbox.exceptions.UuidError(
351345
f'{uuid} already used in this import job, '
352346
'must be unique for the project.')
353347
uids.add(uuid)
354-
except (ValidationError, ValueError, KeyError) as e:
348+
except (ValidationError, ValueError, TypeError, KeyError) as e:
355349
raise labelbox.exceptions.NDJsonError(
356350
f"Invalid NDJson on line {idx}") from e
357351

@@ -482,21 +476,22 @@ def build(cls: Any, data) -> "NDBase":
482476
matched = NDText
483477
else:
484478
raise TypeError(
485-
f"Unexpected type for answer. Found {data['answer']}. Expected a string or a dict"
479+
f"Unexpected type for answer field. Found {data['answer']}. Expected a string or a dict"
486480
)
487481
return matched(**data)
488482
else:
489483
raise KeyError(
490-
f"Expected classes with values {data} to have keys matching one of the following : {top_level_fields}"
484+
f"Invalid annotation. Must have one of the following keys : {top_level_fields}. Found {data}."
491485
)
492486

493487
@classmethod
494488
def schema(cls):
495-
#TODO: Double check this to return subclasses
496-
#results.append()
489+
results = {'definitions': {}}
497490
for cl in cls.get_union_types():
498-
print(cl.schema())
499-
#return cl.schema()
491+
schema = cl.schema()
492+
results['definitions'].update(schema.pop('definitions'))
493+
results[cl.__name__] = schema
494+
return results
500495

501496

502497
class DataRow(BaseModel):
@@ -530,6 +525,10 @@ def validate_feature_schemas(self, valid_feature_schemas):
530525
f"Schema id {self.schemaId} does not map to the assigned tool {valid_feature_schemas[self.schemaId]['tool']}"
531526
)
532527

528+
def validate_instance(self, valid_datarows, valid_feature_schemas):
529+
self.validate_feature_schemas(valid_feature_schemas)
530+
self.validate_datarow(valid_datarows)
531+
533532
class Config:
534533
#Users shouldn't to add extra data to the payload
535534
extra = 'forbid'
@@ -592,6 +591,7 @@ def validate_feature_schemas(self, valid_feature_schemas):
592591
)
593592

594593

594+
#A union with custom construction logic to improve error messages
595595
class NDClassification(
596596
UnionConstructor,
597597
Type[Union[NDText, NDRadio, # type: ignore
@@ -601,7 +601,8 @@ class NDClassification(
601601

602602
###### Tools ######
603603

604-
class BaseTool(NDBase):
604+
605+
class NDBaseTool(NDBase):
605606
classifications: List[NDClassification] = []
606607

607608
#This is indepdent of our problem
@@ -626,7 +627,7 @@ def validate_subclasses(cls, value, field):
626627
return results
627628

628629

629-
class NDPolygon(BaseTool):
630+
class NDPolygon(NDBaseTool):
630631
ontology_type: Literal["polygon"] = "polygon"
631632
polygon: List[Point] = pydantic.Field(determinant=True)
632633

@@ -638,7 +639,7 @@ def is_geom_valid(cls, v):
638639
return v
639640

640641

641-
class NDPolyline(BaseTool):
642+
class NDPolyline(NDBaseTool):
642643
ontology_type: Literal["line"] = "line"
643644
line: List[Point] = pydantic.Field(determinant=True)
644645

@@ -650,13 +651,13 @@ def is_geom_valid(cls, v):
650651
return v
651652

652653

653-
class NDRectangle(BaseTool):
654+
class NDRectangle(NDBaseTool):
654655
ontology_type: Literal["rectangle"] = "rectangle"
655656
bbox: Bbox = pydantic.Field(determinant=True)
656657
#Could check if points are positive
657658

658659

659-
class NDPoint(BaseTool):
660+
class NDPoint(NDBaseTool):
660661
ontology_type: Literal["point"] = "point"
661662
point: Point = pydantic.Field(determinant=True)
662663
#Could check if points are positive
@@ -667,7 +668,7 @@ class EntityLocation(TypedDict):
667668
end: int
668669

669670

670-
class NDTextEntity(BaseTool):
671+
class NDTextEntity(NDBaseTool):
671672
ontology_type: Literal["named-entity"] = "named-entity"
672673
location: EntityLocation = pydantic.Field(determinant=True)
673674

@@ -689,7 +690,7 @@ class MaskFeatures(TypedDict):
689690
colorRGB: Union[List[int], Tuple[int, int, int]]
690691

691692

692-
class NDMask(BaseTool):
693+
class NDMask(NDBaseTool):
693694
ontology_type: Literal["superpixel"] = "superpixel"
694695
mask: MaskFeatures = pydantic.Field(determinant=True)
695696

@@ -710,6 +711,7 @@ def is_valid_mask(cls, v):
710711
return v
711712

712713

714+
#A union with custom construction logic to improve error messages
713715
class NDTool(
714716
UnionConstructor,
715717
Type[Union[NDMask, # type: ignore
@@ -718,26 +720,28 @@ class NDTool(
718720
...
719721

720722

721-
#### Top level annotation. Can be used to construct and validate any annotation
722-
class NDAnnotation(BaseModel):
723-
data: NDBase
723+
class NDAnnotation(UnionConstructor,
724+
Type[Union[NDTool, NDClassification]]): # type: ignore
724725

725-
@validator('data', pre=True)
726-
def validate_data(cls, value):
727-
if not isinstance(value, dict):
726+
@classmethod
727+
def build(cls: Any, data) -> "NDBase":
728+
if not isinstance(data, dict):
728729
raise ValueError('value must be dict')
729-
#Catch keyerror to clean up error messages
730-
#Only raise if they both fail
731-
try:
732-
return NDTool(**value)
733-
except KeyError as e1:
730+
errors = []
731+
for cl in cls.get_union_types():
734732
try:
735-
return NDClassification(**value)
736-
except KeyError as e2:
737-
raise ValueError(
738-
f'Unable to construct tool or classification.\nTool: {e1}\nClassification: {e2}'
739-
)
733+
return cl(**data)
734+
except KeyError as e:
735+
errors.append(f"{cl.__name__}: {e}")
736+
737+
raise ValueError('Unable to construct any annotation.\n{}'.format(
738+
"\n".join(errors)))
740739

741-
def validate(self, valid_datarows, valid_feature_schemas):
742-
self.data.validate_feature_schemas(valid_feature_schemas)
743-
self.data.validate_datarow(valid_datarows)
740+
@classmethod
741+
def schema(cls):
742+
data = {'definitions': {}}
743+
for type_ in cls.get_union_types():
744+
schema_ = type_.schema()
745+
data['definitions'].update(schema_.pop('definitions'))
746+
data[type_.__name__] = schema_
747+
return data

0 commit comments

Comments
 (0)