99import ndjson
1010import requests
1111from pydantic import BaseModel , validator
12- from pydantic import ValidationError
13- from typing_extensions import TypedDict , Literal
12+ from typing_extensions import Literal
1413from typing import (Any , List , Optional , BinaryIO , Dict , Iterable , Tuple , Union ,
1514 Type , Set )
1615
@@ -189,6 +188,8 @@ def create_from_url(cls,
189188 project_id (str): id of project for which predictions will be imported
190189 name (str): name of BulkImportRequest
191190 url (str): publicly accessible URL pointing to ndjson file containing predictions
191+ validate (bool): a flag indicating if there should be a validation
192+ if `url` is valid ndjson
192193 Returns:
193194 BulkImportRequest object
194195 """
@@ -241,11 +242,13 @@ def create_from_objects(cls,
241242 }
242243 }``
243244
244- Args:x
245+ Args:
245246 client (Client): a Labelbox client
246247 project_id (str): id of project for which predictions will be imported
247248 name (str): name of BulkImportRequest
248249 predictions (Iterable[dict]): iterable of dictionaries representing predictions
250+ validate (bool): a flag indicating if there should be a validation
251+ if `predictions` is valid ndjson
249252 Returns:
250253 BulkImportRequest object
251254 """
@@ -313,7 +316,8 @@ def create_from_local_file(cls,
313316 return cls (client , response_data ["createBulkImportRequest" ])
314317
315318
316- def _validate_ndjson (lines : Iterable [Dict [str , Any ]], project ) -> None :
319+ def _validate_ndjson (lines : Iterable [Dict [str , Any ]],
320+ project : "labelbox.Project" ) -> None :
317321 """
318322 Client side validation of an ndjson object.
319323
@@ -328,7 +332,7 @@ def _validate_ndjson(lines: Iterable[Dict[str, Any]], project) -> None:
328332 project (Project): id of project for which predictions will be imported
329333
330334 Raises:
331- ValidationError : Raise for invalid NDJson
335+ MALValidationError : Raise for invalid NDJson
332336 UuidError: Duplicate UUID in upload
333337 """
334338 data_row_ids = {
@@ -347,8 +351,8 @@ def _validate_ndjson(lines: Iterable[Dict[str, Any]], project) -> None:
347351 f'{ uuid } already used in this import job, '
348352 'must be unique for the project.' )
349353 uids .add (uuid )
350- except (ValidationError , ValueError , TypeError , KeyError ) as e :
351- raise labelbox .exceptions .ValidationError (
354+ except (pydantic . ValidationError , ValueError , TypeError , KeyError ) as e :
355+ raise labelbox .exceptions .MALValidationError (
352356 f"Invalid NDJson on line { idx } " ) from e
353357
354358
@@ -406,19 +410,19 @@ def get_mal_schemas(ontology):
406410LabelboxID : str = pydantic .Field (..., min_length = 25 , max_length = 25 )
407411
408412
409- class Bbox (TypedDict ):
413+ class Bbox (BaseModel ):
410414 top : float
411415 left : float
412416 height : float
413417 width : float
414418
415419
416- class Point (TypedDict ):
420+ class Point (BaseModel ):
417421 x : float
418422 y : float
419423
420424
421- class FrameLocation (TypedDict ):
425+ class FrameLocation (BaseModel ):
422426 end : int
423427 start : int
424428
@@ -428,7 +432,9 @@ class VideoSupported(BaseModel):
428432 frames : Optional [List [FrameLocation ]]
429433
430434
431- class UnionConstructor :
435+ #Base class for a special kind of union.
436+ # Compatible with pydantic. Improves error messages over a traditional union
437+ class SpecialUnion :
432438
433439 def __new__ (cls , ** kwargs ):
434440 return cls .build (kwargs )
@@ -439,8 +445,8 @@ def __get_validators__(cls):
439445
440446 @classmethod
441447 def get_union_types (cls ):
442- if not issubclass (cls , UnionConstructor ):
443- raise TypeError ("{} must be a subclass of UnionConstructor " )
448+ if not issubclass (cls , SpecialUnion ):
449+ raise TypeError ("{} must be a subclass of SpecialUnion " )
444450
445451 union_types = [x for x in cls .__orig_bases__ if hasattr (x , "__args__" )]
446452 if len (union_types ) < 1 :
@@ -453,7 +459,16 @@ def get_union_types(cls):
453459 return union_types [0 ].__args__ [0 ].__args__
454460
455461 @classmethod
456- def build (cls : Any , data ) -> "NDBase" :
462+ def build (cls : Any , data : Union [dict , BaseModel ]) -> "NDBase" :
463+ """
464+ Checks through all objects in the union to see which matches the input data.
465+ Args:
466+ data (Union[dict, BaseModel]) : The data for constructing one of the objects in the union
467+ raises:
468+ KeyError: data does not contain the determinant fields for any of the types supported by this SpecialUnion
469+ ValidationError: Error while trying to construct a specific object in the union
470+
471+ """
457472 if isinstance (data , BaseModel ):
458473 data = data .dict ()
459474
@@ -506,7 +521,6 @@ class NDFeatureSchema(BaseModel):
506521
507522class NDBase (NDFeatureSchema ):
508523 ontology_type : str
509- schemaId : str = LabelboxID
510524 uuid : UUID
511525 dataRow : DataRow
512526
@@ -553,7 +567,7 @@ class NDText(NDBase):
553567 #No feature schema to check
554568
555569
556- class NDCheckList (VideoSupported , NDBase ):
570+ class NDChecklist (VideoSupported , NDBase ):
557571 ontology_type : Literal ["checklist" ] = "checklist"
558572 answers : List [NDFeatureSchema ] = pydantic .Field (determinant = True )
559573
@@ -566,7 +580,7 @@ def validate_answers(cls, value, field):
566580
567581 def validate_feature_schemas (self , valid_feature_schemas ):
568582 #Test top level feature schema for this tool
569- super (NDCheckList , self ).validate_feature_schemas (valid_feature_schemas )
583+ super (NDChecklist , self ).validate_feature_schemas (valid_feature_schemas )
570584 #Test the feature schemas provided to the answer field
571585 if len (set ([answer .schemaId for answer in self .answers ])) != len (
572586 self .answers ):
@@ -595,9 +609,9 @@ def validate_feature_schemas(self, valid_feature_schemas):
595609
596610#A union with custom construction logic to improve error messages
597611class NDClassification (
598- UnionConstructor ,
612+ SpecialUnion ,
599613 Type [Union [NDText , NDRadio , # type: ignore
600- NDCheckList ]]):
614+ NDChecklist ]]):
601615 ...
602616
603617
@@ -665,7 +679,7 @@ class NDPoint(NDBaseTool):
665679 #Could check if points are positive
666680
667681
668- class EntityLocation (TypedDict ):
682+ class EntityLocation (BaseModel ):
669683 start : int
670684 end : int
671685
@@ -676,6 +690,9 @@ class NDTextEntity(NDBaseTool):
676690
677691 @validator ('location' )
678692 def is_valid_location (cls , v ):
693+ if isinstance (v , BaseModel ):
694+ v = v .dict ()
695+
679696 if len (v ) < 2 :
680697 raise ValueError (
681698 f"A line must have at least 2 points to be valid. Found { v } " )
@@ -687,7 +704,7 @@ def is_valid_location(cls, v):
687704 return v
688705
689706
690- class MaskFeatures (TypedDict ):
707+ class MaskFeatures (BaseModel ):
691708 instanceURI : str
692709 colorRGB : Union [List [int ], Tuple [int , int , int ]]
693710
@@ -698,6 +715,9 @@ class NDMask(NDBaseTool):
698715
699716 @validator ('mask' )
700717 def is_valid_mask (cls , v ):
718+ if isinstance (v , BaseModel ):
719+ v = v .dict ()
720+
701721 colors = v ['colorRGB' ]
702722 #Does the dtype matter? Can it be a float?
703723 if not isinstance (colors , (tuple , list )):
@@ -715,15 +735,17 @@ def is_valid_mask(cls, v):
715735
716736#A union with custom construction logic to improve error messages
717737class NDTool (
718- UnionConstructor ,
738+ SpecialUnion ,
719739 Type [Union [NDMask , # type: ignore
720740 NDTextEntity , NDPoint , NDRectangle , NDPolyline ,
721741 NDPolygon ,]]):
722742 ...
723743
724744
725- class NDAnnotation (UnionConstructor ,
726- Type [Union [NDTool , NDClassification ]]): # type: ignore
745+ class NDAnnotation (
746+ SpecialUnion ,
747+ Type [Union [NDTool , # type: ignore
748+ NDClassification ]]):
727749
728750 @classmethod
729751 def build (cls : Any , data ) -> "NDBase" :
0 commit comments