Skip to content

Commit 577adb8

Browse files
author
Matt Sokoloff
committed
full schema support
1 parent ab76944 commit 577adb8

File tree

1 file changed

+43
-21
lines changed

1 file changed

+43
-21
lines changed

labelbox/schema/bulk_import_request.py

Lines changed: 43 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,26 @@
33
import logging
44
import time
55
from pathlib import Path
6+
import typing
67
from uuid import UUID, uuid4
78
from pydantic import BaseModel, validator
89
import pydantic
910
import backoff
1011
import ndjson
1112
from pydantic.types import conlist, constr
1213
from pydantic import Required
14+
from pydantic.dataclasses import dataclass
15+
import labelbox
1316
import requests
1417
from labelbox import utils
15-
import labelbox.exceptions
18+
1619
from labelbox.orm import query
1720
from labelbox.orm.db_object import DbObject
1821
from labelbox.orm.model import Field
1922
from labelbox.orm.model import Relationship
2023
from labelbox.schema.enums import BulkImportRequestState
2124
from pydantic import ValidationError
22-
from typing import Any, List, Optional, BinaryIO, Dict, Iterable, Tuple, Union, Type, Set
25+
from typing import Any, Generic, List, Optional, BinaryIO, Dict, Iterable, Tuple, TypeVar, Union, Type, Set
2326
from typing_extensions import TypedDict, Literal
2427

2528
NDJSON_MIME_TYPE = "application/x-ndjson"
@@ -430,22 +433,39 @@ class VideoSupported(BaseModel):
430433

431434

432435
class UnionConstructor:
433-
types: Iterable[Type["NDBase"]]
436+
437+
def __new__(cls, **kwargs):
438+
return cls.build(kwargs)
434439

435440
@classmethod
436441
def __get_validators__(cls):
437442
yield cls.build
438443

439444
@classmethod
440-
def build(cls, data) -> "NDBase":
445+
def get_union_types(cls):
446+
if not issubclass(cls, UnionConstructor):
447+
raise TypeError("{} must be a subclass of UnionConstructor")
448+
449+
union_types = [x for x in cls.__orig_bases__ if hasattr(x, "__args__")]
450+
if len(union_types) < 1:
451+
raise TypeError(
452+
"Class {cls} should inherit from a union of objects to build")
453+
if len(union_types) > 1:
454+
raise TypeError(
455+
f"Class {cls} should inherit from exactly one union of objects to build. Found {union_types}"
456+
)
457+
return union_types[0].__args__[0].__args__
458+
459+
@classmethod
460+
def build(cls: Any, data) -> "NDBase":
441461
if isinstance(data, BaseModel):
442462
data = data.dict()
443463

444464
top_level_fields = []
445465
max_match = 0
446466
matched = None
447467

448-
for type_ in cls.types:
468+
for type_ in cls.get_union_types():
449469
determinate_fields = type_.Config.determinants(type_)
450470
top_level_fields.append(determinate_fields)
451471
matches = sum([val in determinate_fields for val in data])
@@ -470,6 +490,11 @@ def build(cls, data) -> "NDBase":
470490
f"Expected classes with values {data} to have keys matching one of the following : {top_level_fields}"
471491
)
472492

493+
@classmethod
494+
def schema(cls):
495+
for cl in cls.get_union_types():
496+
print(cl.schema())
497+
473498

474499
class DataRow(BaseModel):
475500
id: str
@@ -564,16 +589,17 @@ def validate_feature_schemas(self, valid_feature_schemas):
564589
)
565590

566591

567-
class NDClassification(UnionConstructor):
568-
#Represents both subclasses and top level classifications
569-
types: Iterable[Type[NDBase]] = {NDText, NDRadio, NDCheckList}
592+
class NDClassification(UnionConstructor,
593+
Type[Union[NDText, NDRadio,
594+
NDCheckList]]): # type: ignore
595+
...
570596

571597

572598
###### Tools ######
573599

574600

575601
class BaseTool(NDBase):
576-
classifications: List["NDClassification"] = []
602+
classifications: List[NDClassification] = []
577603

578604
#This is indepdent of our problem
579605
def validate_feature_schemas(self, valid_feature_schemas):
@@ -681,16 +707,12 @@ def is_valid_mask(cls, v):
681707
return v
682708

683709

684-
class NDTool(UnionConstructor):
685-
#Tools and top level classifications
686-
types: Iterable[Type[NDBase]] = {
687-
NDMask,
688-
NDTextEntity,
689-
NDPoint,
690-
NDRectangle,
691-
NDPolyline,
692-
NDPolygon,
693-
}
710+
class NDTool(
711+
UnionConstructor,
712+
Type[Union[NDMask, # type: ignore
713+
NDTextEntity, NDPoint, NDRectangle, NDPolyline,
714+
NDPolygon,]]):
715+
...
694716

695717

696718
#### Top level annotation. Can be used to construct and validate any annotation
@@ -704,10 +726,10 @@ def validate_data(cls, value):
704726
#Catch keyerror to clean up error messages
705727
#Only raise if they both fail
706728
try:
707-
return NDTool.build(value)
729+
return NDTool(**value)
708730
except KeyError as e1:
709731
try:
710-
return NDClassification.build(value)
732+
return NDClassification(**value)
711733
except KeyError as e2:
712734
raise ValueError(
713735
f'Unable to construct tool or classification.\nTool: {e1}\nClassification: {e2}'

0 commit comments

Comments
 (0)