1- import enum
21import json
32import logging
43import time
87import pydantic
98import backoff
109import ndjson
11- from pydantic . types import conlist , constr
10+ import labelbox
1211import requests
1312from labelbox import utils
14- import labelbox .exceptions
1513from labelbox .orm import query
1614from labelbox .orm .db_object import DbObject
1715from labelbox .orm .model import Field
1816from labelbox .orm .model import Relationship
1917from labelbox .schema .enums import BulkImportRequestState
2018from pydantic import ValidationError
21- from typing import Any , List , Optional , BinaryIO , Dict , Iterable , Tuple , Union
19+ from typing import Any , List , Optional , BinaryIO , Dict , Iterable , Tuple , Union , Type , Set
2220from typing_extensions import TypedDict , Literal
2321
2422NDJSON_MIME_TYPE = "application/x-ndjson"
@@ -336,18 +334,18 @@ def _validate_ndjson(lines: Iterable[Dict[str, Any]], project) -> None:
336334 for data_row in dataset .data_rows ()
337335 }
338336 feature_schemas = get_mal_schemas (project .ontology ())
339- uids = set ()
337+ uids : Set [ str ] = set ()
340338 for idx , line in enumerate (lines ):
341339 try :
342- annotation = NDAnnotation (data = line )
343- annotation .validate (data_row_ids , feature_schemas )
344- uuid = annotation .data . uuid
340+ annotation = NDAnnotation (** line )
341+ annotation .validate_instance (data_row_ids , feature_schemas )
342+ uuid = str ( annotation .uuid )
345343 if uuid in uids :
346344 raise labelbox .exceptions .UuidError (
347345 f'{ uuid } already used in this import job, '
348346 'must be unique for the project.' )
349347 uids .add (uuid )
350- except (ValidationError , ValueError , KeyError ) as e :
348+ except (ValidationError , ValueError , TypeError , KeyError ) as e :
351349 raise labelbox .exceptions .NDJsonError (
352350 f"Invalid NDJson on line { idx } " ) from e
353351
@@ -403,7 +401,7 @@ def get_mal_schemas(ontology):
403401 return valid_feature_schemas
404402
405403
# Shared field default for Labelbox ids: a required value of exactly 25
# characters. It is annotated as ``str`` so that ``field: str = LabelboxID``
# type-checks (presumably to keep mypy happy -- TODO confirm).
LabelboxID: str = pydantic.Field(..., min_length=25, max_length=25)
407405
408406
409407class Bbox (TypedDict ):
@@ -418,28 +416,50 @@ class Point(TypedDict):
418416 y : float
419417
420418
# Frame range for a video inference, stored as ``start``/``end`` frame numbers.
class FrameLocation(TypedDict):
    end: int
    start: int
422+
423+
# Mixin for annotation types that may carry video frame ranges.
class VideoSupported(BaseModel):
    # Note that frames are only allowed as top level inferences for video
    frames: Optional[List[FrameLocation]]
424427
425428
426429class UnionConstructor :
427- types : List ["NDBase" ]
430+
def __new__(cls, **kwargs):
    """Build the matching union member instead of instantiating ``cls``.

    The keyword arguments are handed to ``cls.build`` as one dict, and
    whatever ``build`` returns becomes the result of calling the class.
    """
    payload = kwargs
    return cls.build(payload)
428433
@classmethod
def __get_validators__(cls):
    """Yield ``cls.build`` as the single validator for this custom type."""
    builder = cls.build
    yield builder
432437
@classmethod
def get_union_types(cls):
    """Return the member classes of the union ``cls`` was declared with.

    ``cls`` is expected to be declared as
    ``class X(UnionConstructor, Type[Union[A, B, ...]])``; the result is the
    tuple ``(A, B, ...)``.

    Raises:
        TypeError: if ``cls`` is not a ``UnionConstructor`` subclass, or if it
            does not have exactly one subscripted base to read members from.
    """
    if not issubclass(cls, UnionConstructor):
        raise TypeError(f"{cls} must be a subclass of UnionConstructor")

    # ``__orig_bases__`` is only set on classes declared with a subscripted
    # generic base, so fall back to "no bases" and report a TypeError below
    # instead of leaking an AttributeError.
    union_types = [
        base for base in getattr(cls, "__orig_bases__", ())
        if hasattr(base, "__args__")
    ]
    if not union_types:
        raise TypeError(
            f"Class {cls} should inherit from a union of objects to build")
    if len(union_types) > 1:
        raise TypeError(
            f"Class {cls} should inherit from exactly one union of objects to build. Found {union_types}"
        )
    # Type[Union[A, B]] -> Union[A, B] -> (A, B)
    return union_types[0].__args__[0].__args__
452+
453+ @classmethod
454+ def build (cls : Any , data ) -> "NDBase" :
435455 if isinstance (data , BaseModel ):
436456 data = data .dict ()
437457
438458 top_level_fields = []
439459 max_match = 0
440460 matched = None
441461
442- for type_ in cls .types :
462+ for type_ in cls .get_union_types () :
443463 determinate_fields = type_ .Config .determinants (type_ )
444464 top_level_fields .append (determinate_fields )
445465 matches = sum ([val in determinate_fields for val in data ])
@@ -455,26 +475,43 @@ def build(cls, data):
455475 elif isinstance (data ['answer' ], str ):
456476 matched = NDText
457477 else :
458- raise ValidationError (
459- f"Unexpected type for answer. Found { data ['answer' ]} . Expected a string or a dict"
478+ raise TypeError (
479+ f"Unexpected type for answer field . Found { data ['answer' ]} . Expected a string or a dict"
460480 )
461481 return matched (** data )
462482 else :
463483 raise KeyError (
464- f"Expected classes with values { data } to have keys matching one of the following : { top_level_fields } "
484+ f"Invalid annotation. Must have one of the following keys : { top_level_fields } . Found { data } . "
465485 )
466486
@classmethod
def schema(cls):
    """Return a combined schema for every member of the union.

    The result has one entry per member class, keyed by the class name,
    plus a shared ``'definitions'`` mapping merged from all members.
    """
    results = {'definitions': {}}
    for member in cls.get_union_types():
        member_schema = member.schema()
        # A member without nested definitions has no 'definitions' key, so
        # default to an empty mapping rather than raising KeyError.
        results['definitions'].update(member_schema.pop('definitions', {}))
        results[member.__name__] = member_schema
    return results
495+
496+
# Reference to the data row an annotation is attached to.
class DataRow(BaseModel):
    id: str
499+
500+
# Anything identified by a feature schema id; the ``LabelboxID`` default
# carries the field constraints for that id.
class NDFeatureSchema(BaseModel):
    schemaId: str = LabelboxID
503+
467504
468- class NDBase (BaseModel ):
505+ class NDBase (NDFeatureSchema ):
469506 ontology_type : str
470- schemaId : LabelboxID
507+ schemaId : str = LabelboxID
471508 uuid : UUID
472- dataRow : TypedDict ( 'dataRow' , { 'id' : LabelboxID })
509+ dataRow : DataRow
473510
def validate_datarow(self, valid_datarows):
    """Check that this annotation's data row is one of ``valid_datarows``.

    Raises:
        ValueError: if ``self.dataRow.id`` is not in ``valid_datarows``.
    """
    row_id = self.dataRow.id
    if row_id in valid_datarows:
        return
    raise ValueError(
        f"datarow {row_id} is not attached to the specified project")
479516
480517 def validate_feature_schemas (self , valid_feature_schemas ):
@@ -488,6 +525,10 @@ def validate_feature_schemas(self, valid_feature_schemas):
488525 f"Schema id { self .schemaId } does not map to the assigned tool { valid_feature_schemas [self .schemaId ]['tool' ]} "
489526 )
490527
def validate_instance(self, valid_datarows, valid_feature_schemas):
    """Run both project-level checks on this annotation.

    The feature schema check runs first, then the data row check; the first
    failing check propagates its exception.
    """
    self.validate_feature_schemas(valid_feature_schemas)
    self.validate_datarow(valid_datarows)
531+
491532 class Config :
492533 #Users shouldn't to add extra data to the payload
493534 extra = 'forbid'
@@ -512,49 +553,57 @@ class NDText(NDBase):
512553
# Checklist classification: one or more selected options for one feature.
class NDCheckList(VideoSupported, NDBase):
    ontology_type: Literal["checklist"] = "checklist"
    answers: List[NDFeatureSchema] = pydantic.Field(determinant=True)

    @validator('answers', pre=True)
    def validate_answers(cls, value, field):
        """Reject an empty ``answers`` value before the field is parsed."""
        # Explicit check because conlist(min_items=1) does not work with mypy.
        if not len(value):
            raise ValueError("Checklist answers should not be empty")
        return value

    def validate_feature_schemas(self, valid_feature_schemas):
        """Validate the checklist's schema id and every answer's schema id.

        Raises:
            ValueError: if two answers share a schema id, or an answer's
                schema id is not among the options registered for this
                checklist's ``schemaId``.
        """
        # Test top level feature schema for this tool
        super().validate_feature_schemas(valid_feature_schemas)

        # Test the feature schemas provided to the answer field
        answer_ids = [answer.schemaId for answer in self.answers]
        if len(set(answer_ids)) != len(answer_ids):
            raise ValueError(
                f"Duplicated featureSchema found for checklist {self.uuid}")

        # The options depend only on this checklist, so look them up once.
        options = valid_feature_schemas[self.schemaId]['options']
        for answer in self.answers:
            if answer.schemaId not in options:
                raise ValueError(
                    f"Feature schema provided to {self.ontology_type} invalid. Expected one of {options}. Found {answer}"
                )
532579
533580
# Radio classification: exactly one selected option for one feature.
class NDRadio(VideoSupported, NDBase):
    ontology_type: Literal["radio"] = "radio"
    answer: NDFeatureSchema = pydantic.Field(determinant=True)

    def validate_feature_schemas(self, valid_feature_schemas):
        """Validate the radio's schema id and the selected answer's schema id.

        Raises:
            ValueError: if the answer's schema id is not among the options
                registered for this radio's ``schemaId``.
        """
        super().validate_feature_schemas(valid_feature_schemas)
        options = valid_feature_schemas[self.schemaId]['options']
        selected = self.answer.schemaId
        if selected not in options:
            raise ValueError(
                f"Feature schema provided to {self.ontology_type} invalid. Expected one of {options}. Found {selected}"
            )
546592
547593
# A union with custom construction logic to improve error messages
class NDClassification(
        UnionConstructor,
        Type[Union[  # type: ignore
            NDText, NDRadio, NDCheckList]]):
    """Union of the classification types: text, radio and checklist."""
551600
552601
553602###### Tools ######
554603
555604
556- class BaseTool (NDBase ):
557- classifications : List [" NDClassification" ] = []
605+ class NDBaseTool (NDBase ):
606+ classifications : List [NDClassification ] = []
558607
559608 #This is indepdent of our problem
560609 def validate_feature_schemas (self , valid_feature_schemas ):
@@ -578,7 +627,7 @@ def validate_subclasses(cls, value, field):
578627 return results
579628
580629
581- class NDPolygon (BaseTool ):
630+ class NDPolygon (NDBaseTool ):
582631 ontology_type : Literal ["polygon" ] = "polygon"
583632 polygon : List [Point ] = pydantic .Field (determinant = True )
584633
@@ -590,7 +639,7 @@ def is_geom_valid(cls, v):
590639 return v
591640
592641
593- class NDPolyline (BaseTool ):
642+ class NDPolyline (NDBaseTool ):
594643 ontology_type : Literal ["line" ] = "line"
595644 line : List [Point ] = pydantic .Field (determinant = True )
596645
@@ -602,24 +651,26 @@ def is_geom_valid(cls, v):
602651 return v
603652
604653
# Rectangle tool annotation; ``bbox`` is the field that identifies this type.
class NDRectangle(NDBaseTool):
    ontology_type: Literal["rectangle"] = "rectangle"
    bbox: Bbox = pydantic.Field(determinant=True)
    #Could check if points are positive
609658
610659
# Point tool annotation; ``point`` is the field that identifies this type.
class NDPoint(NDBaseTool):
    ontology_type: Literal["point"] = "point"
    point: Point = pydantic.Field(determinant=True)
    #Could check if points are positive
615664
616665
617- class NDTextEntity (BaseTool ):
# ``start``/``end`` offsets locating a named entity.
class EntityLocation(TypedDict):
    start: int
    end: int
669+
670+
671+ class NDTextEntity (NDBaseTool ):
618672 ontology_type : Literal ["named-entity" ] = "named-entity"
619- location : TypedDict ("TextLocation" , {
620- 'start' : int ,
621- 'end' : int
622- }) = pydantic .Field (determinant = True )
673+ location : EntityLocation = pydantic .Field (determinant = True )
623674
624675 @validator ('location' )
625676 def is_valid_location (cls , v ):
@@ -634,13 +685,14 @@ def is_valid_location(cls, v):
634685 return v
635686
636687
637- class NDMask (BaseTool ):
# Payload of a mask annotation: the mask's URI and its RGB color, given as
# either a list of ints or a 3-tuple of ints.
class MaskFeatures(TypedDict):
    instanceURI: str
    colorRGB: Union[List[int], Tuple[int, int, int]]
691+
692+
693+ class NDMask (NDBaseTool ):
638694 ontology_type : Literal ["superpixel" ] = "superpixel"
639- mask : TypedDict (
640- "mask" , {
641- 'instanceURI' : constr (min_length = 5 , strict = True ),
642- 'colorRGB' : Tuple [int , int , int ]
643- }) = pydantic .Field (determinant = True )
695+ mask : MaskFeatures = pydantic .Field (determinant = True )
644696
645697 @validator ('mask' )
646698 def is_valid_mask (cls , v ):
@@ -659,34 +711,37 @@ def is_valid_mask(cls, v):
659711 return v
660712
661713
# A union with custom construction logic to improve error messages
class NDTool(
        UnionConstructor,
        Type[Union[  # type: ignore
            NDMask, NDTextEntity, NDPoint, NDRectangle, NDPolyline,
            NDPolygon]]):
    """Union of the tool types: mask, entity, point, rectangle, line, polygon."""
668721
669722
class NDAnnotation(UnionConstructor,
                   Type[Union[NDTool, NDClassification]]):  # type: ignore
    """Top-level annotation: built as either a tool or a classification."""

    @classmethod
    def build(cls: Any, data) -> "NDBase":
        """Construct the first union member that accepts ``data``.

        Each member is tried in declaration order. A ``KeyError`` from a
        member is recorded and the next member is tried; any other exception
        propagates immediately.

        Raises:
            ValueError: if ``data`` is not a dict, or if every member raised
                ``KeyError`` (the message lists each member's error).
        """
        if not isinstance(data, dict):
            raise ValueError('value must be dict')

        failures = []
        for candidate in cls.get_union_types():
            try:
                return candidate(**data)
            except KeyError as err:
                failures.append(f"{candidate.__name__}: {err}")

        details = "\n".join(failures)
        raise ValueError(
            'Unable to construct any annotation.\n{}'.format(details))

    @classmethod
    def schema(cls):
        """Return the members' schemas keyed by class name.

        Every member's ``'definitions'`` entry is merged into one shared
        ``'definitions'`` mapping in the result.
        """
        combined = {'definitions': {}}
        for member in cls.get_union_types():
            member_schema = member.schema()
            combined['definitions'].update(member_schema.pop('definitions'))
            combined[member.__name__] = member_schema
        return combined
0 commit comments