Skip to content

Commit 34774a1

Browse files
authored
Merge pull request #161 from dreadnode/fix/data-types-imports
fix: data type imports and typing
2 parents 8d42d4a + 9ebf8b3 commit 34774a1

File tree

3 files changed

+61
-6
lines changed

3 files changed

+61
-6
lines changed

dreadnode/data_types/image.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,32 @@
44
from pathlib import Path
55

66
import numpy as np
7-
import PIL
8-
import PIL.Image
97

108
from dreadnode.data_types.base import DataType
119

10+
if t.TYPE_CHECKING:
11+
import numpy as np
12+
1213
ImageDataType = t.Union[t.Any, "np.ndarray[t.Any, t.Any]"]
1314
ImageDataOrPathType = str | Path | bytes | ImageDataType
1415

1516

17+
def check_imports() -> None:
18+
try:
19+
import PIL # type: ignore[import-not-found]
20+
except ImportError as e:
21+
raise ImportError(
22+
"Image processing requires Pillow. Install with: pip install dreadnode[multimodal]"
23+
) from e
24+
25+
try:
26+
import numpy as np # type: ignore[import-not-found]
27+
except ImportError as e:
28+
raise ImportError(
29+
"Image processing requires NumPy. Install with: pip install dreadnode[multimodal]"
30+
) from e
31+
32+
1633
class Image(DataType):
1734
"""
1835
Image media type for Dreadnode logging.
@@ -45,6 +62,7 @@ def __init__(
4562
caption: Optional caption for the image
4663
format: Optional format to use when saving (png, jpg, etc.)
4764
"""
65+
check_imports()
4866
self._data = data
4967
self._mode = mode
5068
self._caption = caption
@@ -66,6 +84,8 @@ def _process_image_data(self) -> tuple[bytes, str, str | None, int | None, int |
6684
Returns:
6785
A tuple of (image_bytes, image_format, mode, width, height)
6886
"""
87+
import numpy as np # type: ignore[import-not-found]
88+
import PIL.Image # type: ignore[import-not-found]
6989

7090
if isinstance(self._data, str | Path) and Path(self._data).exists():
7191
return self._process_file_path()
@@ -85,6 +105,7 @@ def _process_file_path(self) -> tuple[bytes, str, str | None, int | None, int |
85105
Returns:
86106
A tuple of (image_bytes, image_format, mode, width, height)
87107
"""
108+
import PIL.Image # type: ignore[import-not-found]
88109

89110
path_str = str(self._data)
90111
image_bytes = Path(path_str).read_bytes()
@@ -102,6 +123,7 @@ def _process_pil_image(self) -> tuple[bytes, str, str | None, int | None, int |
102123
Returns:
103124
A tuple of (image_bytes, image_format, mode, width, height)
104125
"""
126+
import PIL.Image # type: ignore[import-not-found]
105127

106128
if not isinstance(self._data, PIL.Image.Image):
107129
raise TypeError(f"Expected PIL.Image, got {type(self._data)}")
@@ -139,6 +161,8 @@ def _process_numpy_array(self) -> tuple[bytes, str, str | None, int | None, int
139161
Returns:
140162
A tuple of (image_bytes, image_format, mode, width, height)
141163
"""
164+
import numpy as np # type: ignore[import-not-found]
165+
import PIL.Image # type: ignore[import-not-found]
142166

143167
buffer = io.BytesIO()
144168
image_format = self._format or "png"
@@ -168,6 +192,7 @@ def _process_raw_bytes(self) -> tuple[bytes, str, str | None, int | None, int |
168192
Returns:
169193
A tuple of (image_bytes, image_format, mode, width, height)
170194
"""
195+
import PIL.Image # type: ignore[import-not-found]
171196

172197
if not isinstance(self._data, bytes):
173198
raise TypeError(f"Expected bytes, got {type(self._data)}")
@@ -192,6 +217,7 @@ def _process_base64_string(self) -> tuple[bytes, str, str | None, int | None, in
192217
Returns:
193218
A tuple of (image_bytes, image_format, mode, width, height)
194219
"""
220+
import PIL.Image # type: ignore[import-not-found]
195221

196222
if not isinstance(self._data, str):
197223
raise TypeError(f"Expected str, got {type(self._data)}")
@@ -228,6 +254,8 @@ def _generate_metadata(
228254
self, image_format: str, mode: str | None, width: int | None, height: int | None
229255
) -> dict[str, str | int | None]:
230256
"""Generate metadata for the image."""
257+
import numpy as np # type: ignore[import-not-found]
258+
import PIL.Image # type: ignore[import-not-found]
231259

232260
metadata: dict[str, str | int | None] = {
233261
"extension": image_format.lower(),
@@ -286,6 +314,7 @@ def _ensure_valid_image_array(
286314
self, array: "np.ndarray[t.Any, np.dtype[t.Any]]"
287315
) -> "np.ndarray[t.Any, np.dtype[t.Any]]":
288316
"""Convert numpy array to a format suitable for PIL."""
317+
import numpy as np # type: ignore[import-not-found]
289318

290319
grayscale_dim = 2
291320
rgb_dim = 3

dreadnode/data_types/table.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,33 @@
33
from pathlib import Path
44
from typing import ClassVar
55

6-
import numpy as np
7-
import pandas as pd
8-
96
from dreadnode.data_types.base import DataType
107

8+
if t.TYPE_CHECKING:
9+
import numpy as np
10+
import pandas as pd
11+
1112
TableDataType = t.Union[
1213
"pd.DataFrame", dict[t.Any, t.Any], list[t.Any], str, Path, "np.ndarray[t.Any, t.Any]"
1314
]
1415

1516

17+
def check_imports() -> None:
18+
try:
19+
import pandas as pd # type: ignore[import-not-found]
20+
except ImportError as e:
21+
raise ImportError(
22+
"Image processing requires Pandas. Install with: pip install dreadnode[multimodal]"
23+
) from e
24+
25+
try:
26+
import numpy as np # type: ignore[import-not-found]
27+
except ImportError as e:
28+
raise ImportError(
29+
"Image processing requires NumPy. Install with: pip install dreadnode[multimodal]"
30+
) from e
31+
32+
1633
class Table(DataType):
1734
"""
1835
Table data type for Dreadnode logging.
@@ -47,6 +64,7 @@ def __init__(
4764
format: Optional format to use when saving (csv, parquet, json)
4865
index: Include index in the output
4966
"""
67+
check_imports()
5068
self._data = data
5169
self._caption = caption
5270
self._format = format or "csv" # Default to CSV
@@ -77,6 +95,8 @@ def _to_dataframe(self) -> "pd.DataFrame":
7795
Returns:
7896
A pandas DataFrame representation of the input data
7997
"""
98+
import numpy as np # type: ignore[import-not-found]
99+
import pandas as pd # type: ignore[import-not-found]
80100

81101
if isinstance(self._data, pd.DataFrame):
82102
return self._data
@@ -131,6 +151,8 @@ def _generate_metadata(self, data_frame: "pd.DataFrame") -> dict[str, t.Any]:
131151
Returns:
132152
A dictionary of metadata
133153
"""
154+
import numpy as np # type: ignore[import-not-found]
155+
import pandas as pd # type: ignore[import-not-found]
134156

135157
metadata = {
136158
"extension": self._format,

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,11 @@ python_version = "3.10"
109109
exclude = "tests"
110110

111111
[[tool.mypy.overrides]]
112-
module = ["dreadnode.scorers.*", "dreadnode.transforms.*"]
112+
module = [
113+
"dreadnode.data_types.*",
114+
"dreadnode.scorers.*",
115+
"dreadnode.transforms.*",
116+
]
113117
disable_error_code = ["unused-ignore", "import-untyped"]
114118

115119
[tool.ty.environment]

0 commit comments

Comments
 (0)