33import logging
44import time
55from pathlib import Path
6+ import typing
67from uuid import UUID , uuid4
78from pydantic import BaseModel , validator
89import pydantic
910import backoff
1011import ndjson
1112from pydantic .types import conlist , constr
1213from pydantic import Required
14+ from pydantic .dataclasses import dataclass
15+ import labelbox
1316import requests
1417from labelbox import utils
15- import labelbox . exceptions
18+
1619from labelbox .orm import query
1720from labelbox .orm .db_object import DbObject
1821from labelbox .orm .model import Field
1922from labelbox .orm .model import Relationship
2023from labelbox .schema .enums import BulkImportRequestState
2124from pydantic import ValidationError
22- from typing import Any , List , Optional , BinaryIO , Dict , Iterable , Tuple , Union , Type , Set
25+ from typing import Any , Generic , List , Optional , BinaryIO , Dict , Iterable , Tuple , TypeVar , Union , Type , Set
2326from typing_extensions import TypedDict , Literal
2427
2528NDJSON_MIME_TYPE = "application/x-ndjson"
@@ -430,22 +433,39 @@ class VideoSupported(BaseModel):
430433
431434
432435class UnionConstructor :
433- types : Iterable [Type ["NDBase" ]]
436+
437+ def __new__ (cls , ** kwargs ):
438+ return cls .build (kwargs )
434439
435440 @classmethod
436441 def __get_validators__ (cls ):
437442 yield cls .build
438443
439444 @classmethod
440- def build (cls , data ) -> "NDBase" :
445+ def get_union_types (cls ):
446+ if not issubclass (cls , UnionConstructor ):
447+ raise TypeError ("{} must be a subclass of UnionConstructor" )
448+
449+ union_types = [x for x in cls .__orig_bases__ if hasattr (x , "__args__" )]
450+ if len (union_types ) < 1 :
451+ raise TypeError (
452+ "Class {cls} should inherit from a union of objects to build" )
453+ if len (union_types ) > 1 :
454+ raise TypeError (
455+ f"Class { cls } should inherit from exactly one union of objects to build. Found { union_types } "
456+ )
457+ return union_types [0 ].__args__ [0 ].__args__
458+
459+ @classmethod
460+ def build (cls : Any , data ) -> "NDBase" :
441461 if isinstance (data , BaseModel ):
442462 data = data .dict ()
443463
444464 top_level_fields = []
445465 max_match = 0
446466 matched = None
447467
448- for type_ in cls .types :
468+ for type_ in cls .get_union_types () :
449469 determinate_fields = type_ .Config .determinants (type_ )
450470 top_level_fields .append (determinate_fields )
451471 matches = sum ([val in determinate_fields for val in data ])
@@ -470,6 +490,11 @@ def build(cls, data) -> "NDBase":
470490 f"Expected classes with values { data } to have keys matching one of the following : { top_level_fields } "
471491 )
472492
493+ @classmethod
494+ def schema (cls ):
495+ for cl in cls .get_union_types ():
496+ print (cl .schema ())
497+
473498
474499class DataRow (BaseModel ):
475500 id : str
@@ -564,16 +589,17 @@ def validate_feature_schemas(self, valid_feature_schemas):
564589 )
565590
566591
567- class NDClassification (UnionConstructor ):
568- #Represents both subclasses and top level classifications
569- types : Iterable [Type [NDBase ]] = {NDText , NDRadio , NDCheckList }
592+ class NDClassification (UnionConstructor ,
593+ Type [Union [NDText , NDRadio ,
594+ NDCheckList ]]): # type: ignore
595+ ...
570596
571597
572598###### Tools ######
573599
574600
575601class BaseTool (NDBase ):
576- classifications : List [" NDClassification" ] = []
602+ classifications : List [NDClassification ] = []
577603
578604 #This is indepdent of our problem
579605 def validate_feature_schemas (self , valid_feature_schemas ):
@@ -681,16 +707,12 @@ def is_valid_mask(cls, v):
681707 return v
682708
683709
684- class NDTool (UnionConstructor ):
685- #Tools and top level classifications
686- types : Iterable [Type [NDBase ]] = {
687- NDMask ,
688- NDTextEntity ,
689- NDPoint ,
690- NDRectangle ,
691- NDPolyline ,
692- NDPolygon ,
693- }
710+ class NDTool (
711+ UnionConstructor ,
712+ Type [Union [NDMask , # type: ignore
713+ NDTextEntity , NDPoint , NDRectangle , NDPolyline ,
714+ NDPolygon ,]]):
715+ ...
694716
695717
696718#### Top level annotation. Can be used to construct and validate any annotation
@@ -704,10 +726,10 @@ def validate_data(cls, value):
704726 #Catch keyerror to clean up error messages
705727 #Only raise if they both fail
706728 try :
707- return NDTool . build ( value )
729+ return NDTool ( ** value )
708730 except KeyError as e1 :
709731 try :
710- return NDClassification . build ( value )
732+ return NDClassification ( ** value )
711733 except KeyError as e2 :
712734 raise ValueError (
713735 f'Unable to construct tool or classification.\n Tool: { e1 } \n Classification: { e2 } '
0 commit comments