Skip to content

Commit f997496

Browse files
authored
Vb/support pydantic2 plt 145 (#1412)
2 parents 35770e2 + ce3ef4c commit f997496

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+302
-256
lines changed

labelbox/data/annotation_types/base_annotation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import abc
22
from uuid import UUID
33
from typing import Any, Dict, Optional
4-
from pydantic import PrivateAttr
4+
from labelbox import pydantic_compat
55

66
from .feature import FeatureSchema
77

88

99
class BaseAnnotation(FeatureSchema, abc.ABC):
1010
""" Base annotation class. Shouldn't be directly instantiated
1111
"""
12-
_uuid: Optional[UUID] = PrivateAttr()
12+
_uuid: Optional[UUID] = pydantic_compat.PrivateAttr()
1313
extra: Dict[str, Any] = {}
1414

1515
def __init__(self, **data):

labelbox/data/annotation_types/classification/classification.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
except:
1010
from typing_extensions import Literal
1111

12-
from pydantic import BaseModel, validator
12+
from labelbox import pydantic_compat
1313
from ..feature import FeatureSchema
1414

1515

1616
# TODO: Replace when pydantic adds support for unions that don't coerce types
17-
class _TempName(ConfidenceMixin, BaseModel):
17+
class _TempName(ConfidenceMixin, pydantic_compat.BaseModel):
1818
name: str
1919

2020
def dict(self, *args, **kwargs):
@@ -47,7 +47,7 @@ def dict(self, *args, **kwargs) -> Dict[str, str]:
4747
return res
4848

4949

50-
class Radio(ConfidenceMixin, CustomMetricsMixin, BaseModel):
50+
class Radio(ConfidenceMixin, CustomMetricsMixin, pydantic_compat.BaseModel):
5151
""" A classification with only one selected option allowed
5252
5353
>>> Radio(answer = ClassificationAnswer(name = "dog"))
@@ -66,7 +66,7 @@ class Checklist(_TempName):
6666
answer: List[ClassificationAnswer]
6767

6868

69-
class Text(ConfidenceMixin, CustomMetricsMixin, BaseModel):
69+
class Text(ConfidenceMixin, CustomMetricsMixin, pydantic_compat.BaseModel):
7070
""" Free form text
7171
7272
>>> Text(answer = "some text answer")

labelbox/data/annotation_types/data/base_data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from abc import ABC
22
from typing import Optional, Dict, List, Any
33

4-
from pydantic import BaseModel
4+
from labelbox import pydantic_compat
55

66

7-
class BaseData(BaseModel, ABC):
7+
class BaseData(pydantic_compat.BaseModel, ABC):
88
"""
99
Base class for objects representing data.
1010
This class shouldn't directly be used

labelbox/data/annotation_types/data/raster.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,17 @@
55

66
from PIL import Image
77
from google.api_core import retry
8-
from pydantic import BaseModel
9-
from pydantic import root_validator
108
from requests.exceptions import ConnectTimeout
119
import requests
1210
import numpy as np
1311

12+
from labelbox import pydantic_compat
1413
from labelbox.exceptions import InternalServerError
1514
from .base_data import BaseData
1615
from ..types import TypedArray
1716

1817

19-
class RasterData(BaseModel, ABC):
18+
class RasterData(pydantic_compat.BaseModel, ABC):
2019
"""Represents an image or segmentation mask.
2120
"""
2221
im_bytes: Optional[bytes] = None
@@ -156,7 +155,7 @@ def create_url(self, signer: Callable[[bytes], str]) -> str:
156155
"One of url, im_bytes, file_path, arr must not be None.")
157156
return self.url
158157

159-
@root_validator()
158+
@pydantic_compat.root_validator()
160159
def validate_args(cls, values):
161160
file_path = values.get("file_path")
162161
im_bytes = values.get("im_bytes")

labelbox/data/annotation_types/data/text.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import requests
44
from requests.exceptions import ConnectTimeout
55
from google.api_core import retry
6-
from pydantic import root_validator
76

7+
from labelbox import pydantic_compat
88
from labelbox.exceptions import InternalServerError
99
from labelbox.typing_imports import Literal
1010
from labelbox.utils import _NoCoercionMixin
@@ -90,7 +90,7 @@ def create_url(self, signer: Callable[[bytes], str]) -> None:
9090
"One of url, im_bytes, file_path, numpy must not be None.")
9191
return self.url
9292

93-
@root_validator
93+
@pydantic_compat.root_validator
9494
def validate_date(cls, values):
9595
file_path = values.get("file_path")
9696
text = values.get("text")

labelbox/data/annotation_types/data/tiled_image.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212
from PIL import Image
1313
from pyproj import Transformer
1414
from pygeotile.point import Point as PygeoPoint
15-
from pydantic import BaseModel, validator
16-
from pydantic.class_validators import root_validator
15+
from labelbox import pydantic_compat
1716

1817
from labelbox.data.annotation_types import Rectangle, Point, Line, Polygon
1918
from .base_data import BaseData
@@ -41,7 +40,7 @@ class EPSG(Enum):
4140
EPSG3857 = 3857
4241

4342

44-
class TiledBounds(BaseModel):
43+
class TiledBounds(pydantic_compat.BaseModel):
4544
""" Bounds for a tiled image asset related to the relevant epsg.
4645
4746
Bounds should be Point objects.
@@ -55,7 +54,7 @@ class TiledBounds(BaseModel):
5554
epsg: EPSG
5655
bounds: List[Point]
5756

58-
@validator('bounds')
57+
@pydantic_compat.validator('bounds')
5958
def validate_bounds_not_equal(cls, bounds):
6059
first_bound = bounds[0]
6160
second_bound = bounds[1]
@@ -67,7 +66,7 @@ def validate_bounds_not_equal(cls, bounds):
6766
return bounds
6867

6968
#validate bounds are within lat,lng range if they are EPSG4326
70-
@root_validator
69+
@pydantic_compat.root_validator
7170
def validate_bounds_lat_lng(cls, values):
7271
epsg = values.get('epsg')
7372
bounds = values.get('bounds')
@@ -83,7 +82,7 @@ def validate_bounds_lat_lng(cls, values):
8382
return values
8483

8584

86-
class TileLayer(BaseModel):
85+
class TileLayer(pydantic_compat.BaseModel):
8786
""" Url that contains the tile layer. Must be in the format:
8887
8988
https://c.tile.openstreetmap.org/{z}/{x}/{y}.png
@@ -99,7 +98,7 @@ class TileLayer(BaseModel):
9998
def asdict(self) -> Dict[str, str]:
10099
return {"tileLayerUrl": self.url, "name": self.name}
101100

102-
@validator('url')
101+
@pydantic_compat.validator('url')
103102
def validate_url(cls, url):
104103
xyz_format = "/{z}/{x}/{y}"
105104
if xyz_format not in url:
@@ -344,7 +343,7 @@ def _validate_num_tiles(self, xstart: float, ystart: float, xend: float,
344343
f"Max allowed tiles are {max_tiles}"
345344
f"Increase max tiles or reduce zoom level.")
346345

347-
@validator('zoom_levels')
346+
@pydantic_compat.validator('zoom_levels')
348347
def validate_zoom_levels(cls, zoom_levels):
349348
if zoom_levels[0] > zoom_levels[1]:
350349
raise ValueError(
@@ -353,7 +352,7 @@ def validate_zoom_levels(cls, zoom_levels):
353352
return zoom_levels
354353

355354

356-
class EPSGTransformer(BaseModel):
355+
class EPSGTransformer(pydantic_compat.BaseModel):
357356
"""Transformer class between different EPSG's. Useful when wanting to project
358357
in different formats.
359358
"""

labelbox/data/annotation_types/data/video.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88
import cv2
99
import numpy as np
1010
from google.api_core import retry
11-
from pydantic import root_validator
1211

1312
from .base_data import BaseData
1413
from ..types import TypedArray
1514

15+
from labelbox import pydantic_compat
16+
1617
logger = logging.getLogger(__name__)
1718

1819

@@ -147,7 +148,7 @@ def frames_to_video(self,
147148
out.release()
148149
return file_path
149150

150-
@root_validator
151+
@pydantic_compat.root_validator
151152
def validate_data(cls, values):
152153
file_path = values.get("file_path")
153154
url = values.get("url")

labelbox/data/annotation_types/feature.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from typing import Optional
22

3-
from pydantic import BaseModel, root_validator
3+
from labelbox import pydantic_compat
44

55
from .types import Cuid
66

77

8-
class FeatureSchema(BaseModel):
8+
class FeatureSchema(pydantic_compat.BaseModel):
99
"""
1010
Class that represents a feature schema.
1111
Could be a annotation, a subclass, or an option.
@@ -14,7 +14,7 @@ class FeatureSchema(BaseModel):
1414
name: Optional[str] = None
1515
feature_schema_id: Optional[Cuid] = None
1616

17-
@root_validator
17+
@pydantic_compat.root_validator
1818
def must_set_one(cls, values):
1919
if values['feature_schema_id'] is None and values['name'] is None:
2020
raise ValueError(

labelbox/data/annotation_types/geometry/geometry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33

44
import geojson
55
import numpy as np
6-
from pydantic import BaseModel
6+
from labelbox import pydantic_compat
77

88
from shapely import geometry as geom
99

1010

11-
class Geometry(BaseModel, ABC):
11+
class Geometry(pydantic_compat.BaseModel, ABC):
1212
"""Abstract base class for geometry objects
1313
"""
1414
extra: Dict[str, Any] = {}

labelbox/data/annotation_types/geometry/line.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
import geojson
44
import numpy as np
55
import cv2
6-
from pydantic import validator
76

87
from shapely.geometry import LineString as SLineString
98

109
from .point import Point
1110
from .geometry import Geometry
1211

12+
from labelbox import pydantic_compat
13+
1314

1415
class Line(Geometry):
1516
"""Line annotation
@@ -64,7 +65,7 @@ def draw(self,
6465
color=color,
6566
thickness=thickness)
6667

67-
@validator('points')
68+
@pydantic_compat.validator('points')
6869
def is_geom_valid(cls, points):
6970
if len(points) < 2:
7071
raise ValueError(

0 commit comments

Comments
 (0)