11import json
2- import logging
32import time
4- from pathlib import Path
53from uuid import UUID , uuid4
6- from pydantic import BaseModel , validator
4+
5+ import logging
6+ from pathlib import Path
77import pydantic
88import backoff
99import ndjson
10- import labelbox
1110import requests
11+ from pydantic import BaseModel , validator
12+ from typing_extensions import Literal
13+ from typing import (Any , List , Optional , BinaryIO , Dict , Iterable , Tuple , Union ,
14+ Type , Set )
15+
16+ import labelbox
1217from labelbox import utils
1318from labelbox .orm import query
1419from labelbox .orm .db_object import DbObject
15- from labelbox .orm .model import Field
16- from labelbox .orm .model import Relationship
20+ from labelbox .orm .model import Field , Relationship
1721from labelbox .schema .enums import BulkImportRequestState
18- from pydantic import ValidationError
19- from typing import Any , List , Optional , BinaryIO , Dict , Iterable , Tuple , Union , Type , Set
20- from typing_extensions import TypedDict , Literal
2122
2223NDJSON_MIME_TYPE = "application/x-ndjson"
2324logger = logging .getLogger (__name__ )
@@ -187,6 +188,8 @@ def create_from_url(cls,
187188 project_id (str): id of project for which predictions will be imported
188189 name (str): name of BulkImportRequest
189190 url (str): publicly accessible URL pointing to ndjson file containing predictions
191+ validate (bool): a flag indicating whether to validate that the
192+ file at `url` is valid ndjson
190193 Returns:
191194 BulkImportRequest object
192195 """
@@ -219,7 +222,7 @@ def create_from_objects(cls,
219222 client ,
220223 project_id : str ,
221224 name : str ,
222- predictions : Iterable [dict ],
225+ predictions : Iterable [Dict ],
223226 validate = True ) -> 'BulkImportRequest' :
224227 """
225228 Creates a `BulkImportRequest` from an iterable of dictionaries.
@@ -239,11 +242,13 @@ def create_from_objects(cls,
239242 }
240243 }``
241244
242- Args:x
245+ Args:
243246 client (Client): a Labelbox client
244247 project_id (str): id of project for which predictions will be imported
245248 name (str): name of BulkImportRequest
246249 predictions (Iterable[dict]): iterable of dictionaries representing predictions
250+ validate (bool): a flag indicating whether to validate that
251+ `predictions` is valid ndjson
247252 Returns:
248253 BulkImportRequest object
249254 """
@@ -311,7 +316,8 @@ def create_from_local_file(cls,
311316 return cls (client , response_data ["createBulkImportRequest" ])
312317
313318
314- def _validate_ndjson (lines : Iterable [Dict [str , Any ]], project ) -> None :
319+ def _validate_ndjson (lines : Iterable [Dict [str , Any ]],
320+ project : "labelbox.Project" ) -> None :
315321 """
316322 Client side validation of an ndjson object.
317323
@@ -326,7 +332,7 @@ def _validate_ndjson(lines: Iterable[Dict[str, Any]], project) -> None:
326332 project (Project): project for which predictions will be imported
327333
328334 Raises:
329- NDJsonError : Raise for invalid NDJson
335+ MALValidationError : Raise for invalid NDJson
330336 UuidError: Duplicate UUID in upload
331337 """
332338 data_row_ids = {
@@ -345,8 +351,8 @@ def _validate_ndjson(lines: Iterable[Dict[str, Any]], project) -> None:
345351 f'{ uuid } already used in this import job, '
346352 'must be unique for the project.' )
347353 uids .add (uuid )
348- except (ValidationError , ValueError , TypeError , KeyError ) as e :
349- raise labelbox .exceptions .NDJsonError (
354+ except (pydantic . ValidationError , ValueError , TypeError , KeyError ) as e :
355+ raise labelbox .exceptions .MALValidationError (
350356 f"Invalid NDJson on line { idx } " ) from e
351357
352358
@@ -404,19 +410,19 @@ def get_mal_schemas(ontology):
404410LabelboxID : str = pydantic .Field (..., min_length = 25 , max_length = 25 )
405411
406412
407- class Bbox (TypedDict ):
413+ class Bbox (BaseModel ):
408414 top : float
409415 left : float
410416 height : float
411417 width : float
412418
413419
414- class Point (TypedDict ):
420+ class Point (BaseModel ):
415421 x : float
416422 y : float
417423
418424
419- class FrameLocation (TypedDict ):
425+ class FrameLocation (BaseModel ):
420426 end : int
421427 start : int
422428
@@ -426,7 +432,9 @@ class VideoSupported(BaseModel):
426432 frames : Optional [List [FrameLocation ]]
427433
428434
429- class UnionConstructor :
435+ # Base class for a special kind of union that is compatible with pydantic.
436+ # Improves error messages over a traditional union.
437+ class SpecialUnion :
430438
431439 def __new__ (cls , ** kwargs ):
432440 return cls .build (kwargs )
@@ -437,8 +445,8 @@ def __get_validators__(cls):
437445
438446 @classmethod
439447 def get_union_types (cls ):
440- if not issubclass (cls , UnionConstructor ):
441- raise TypeError ("{} must be a subclass of UnionConstructor " )
448+ if not issubclass (cls , SpecialUnion ):
449+ raise TypeError ("{} must be a subclass of SpecialUnion " )
442450
443451 union_types = [x for x in cls .__orig_bases__ if hasattr (x , "__args__" )]
444452 if len (union_types ) < 1 :
@@ -451,7 +459,16 @@ def get_union_types(cls):
451459 return union_types [0 ].__args__ [0 ].__args__
452460
453461 @classmethod
454- def build (cls : Any , data ) -> "NDBase" :
462+ def build (cls : Any , data : Union [dict , BaseModel ]) -> "NDBase" :
463+ """
464+ Checks through all objects in the union to see which matches the input data.
465+ Args:
466+ data (Union[dict, BaseModel]) : The data for constructing one of the objects in the union
467+ Raises:
468+ KeyError: data does not contain the determinant fields for any of the types supported by this SpecialUnion
469+ ValidationError: Error while trying to construct a specific object in the union
470+
471+ """
455472 if isinstance (data , BaseModel ):
456473 data = data .dict ()
457474
@@ -504,7 +521,6 @@ class NDFeatureSchema(BaseModel):
504521
505522class NDBase (NDFeatureSchema ):
506523 ontology_type : str
507- schemaId : str = LabelboxID
508524 uuid : UUID
509525 dataRow : DataRow
510526
@@ -551,7 +567,7 @@ class NDText(NDBase):
551567 #No feature schema to check
552568
553569
554- class NDCheckList (VideoSupported , NDBase ):
570+ class NDChecklist (VideoSupported , NDBase ):
555571 ontology_type : Literal ["checklist" ] = "checklist"
556572 answers : List [NDFeatureSchema ] = pydantic .Field (determinant = True )
557573
@@ -564,7 +580,7 @@ def validate_answers(cls, value, field):
564580
565581 def validate_feature_schemas (self , valid_feature_schemas ):
566582 #Test top level feature schema for this tool
567- super (NDCheckList , self ).validate_feature_schemas (valid_feature_schemas )
583+ super (NDChecklist , self ).validate_feature_schemas (valid_feature_schemas )
568584 #Test the feature schemas provided to the answer field
569585 if len (set ([answer .schemaId for answer in self .answers ])) != len (
570586 self .answers ):
@@ -593,9 +609,9 @@ def validate_feature_schemas(self, valid_feature_schemas):
593609
594610#A union with custom construction logic to improve error messages
595611class NDClassification (
596- UnionConstructor ,
612+ SpecialUnion ,
597613 Type [Union [NDText , NDRadio , # type: ignore
598- NDCheckList ]]):
614+ NDChecklist ]]):
599615 ...
600616
601617
@@ -663,7 +679,7 @@ class NDPoint(NDBaseTool):
663679 #Could check if points are positive
664680
665681
666- class EntityLocation (TypedDict ):
682+ class EntityLocation (BaseModel ):
667683 start : int
668684 end : int
669685
@@ -674,6 +690,9 @@ class NDTextEntity(NDBaseTool):
674690
675691 @validator ('location' )
676692 def is_valid_location (cls , v ):
693+ if isinstance (v , BaseModel ):
694+ v = v .dict ()
695+
677696 if len (v ) < 2 :
678697 raise ValueError (
679698 f"A line must have at least 2 points to be valid. Found { v } " )
@@ -685,7 +704,7 @@ def is_valid_location(cls, v):
685704 return v
686705
687706
688- class MaskFeatures (TypedDict ):
707+ class MaskFeatures (BaseModel ):
689708 instanceURI : str
690709 colorRGB : Union [List [int ], Tuple [int , int , int ]]
691710
@@ -696,6 +715,9 @@ class NDMask(NDBaseTool):
696715
697716 @validator ('mask' )
698717 def is_valid_mask (cls , v ):
718+ if isinstance (v , BaseModel ):
719+ v = v .dict ()
720+
699721 colors = v ['colorRGB' ]
700722 #Does the dtype matter? Can it be a float?
701723 if not isinstance (colors , (tuple , list )):
@@ -713,15 +735,17 @@ def is_valid_mask(cls, v):
713735
714736#A union with custom construction logic to improve error messages
715737class NDTool (
716- UnionConstructor ,
738+ SpecialUnion ,
717739 Type [Union [NDMask , # type: ignore
718740 NDTextEntity , NDPoint , NDRectangle , NDPolyline ,
719741 NDPolygon ,]]):
720742 ...
721743
722744
723- class NDAnnotation (UnionConstructor ,
724- Type [Union [NDTool , NDClassification ]]): # type: ignore
745+ class NDAnnotation (
746+ SpecialUnion ,
747+ Type [Union [NDTool , # type: ignore
748+ NDClassification ]]):
725749
726750 @classmethod
727751 def build (cls : Any , data ) -> "NDBase" :
0 commit comments