Commit 5bf9888

OlegWock and m1so authored
fix: Multiple fixes for handling different data types in pandas columns analysis (#19)
* fix: Multiple fixes for handling different data types in pandas column analysis
* Fix tests
* chore: Improve test coverage
* fix: Don't show bytes as base64 encoded strings
* fix: Use the same logic for Spark
* chore: Incorporate PR review suggestions

Co-authored-by: Michal Baumgartner <michal.baumgartner@deepnote.com>

1 parent 67a97b4 · commit 5bf9888

File tree

7 files changed: +678 additions, -45 deletions

.cursorrules

Lines changed: 6 additions & 0 deletions
@@ -94,6 +94,12 @@ Additional for integration tests:
 # Run local tests
 ./bin/test-local
 
+# Run a specific test file
+./bin/test-local tests/unit/test_file.py
+
+# ... or specific test from file
+./bin/test-local tests/unit/test_file.py::TestClass::test_method
+
 # Run specific test type
 export TEST_TYPE="unit|integration"
 export TOOLKIT_VERSION="local-build"

deepnote_toolkit/ocelots/pandas/analyze.py

Lines changed: 18 additions & 30 deletions
@@ -6,6 +6,11 @@
 import pandas as pd
 
 from deepnote_toolkit.ocelots.constants import DEEPNOTE_INDEX_COLUMN
+from deepnote_toolkit.ocelots.pandas.utils import (
+    is_numeric_or_temporal,
+    is_type_datetime_or_timedelta,
+    safe_convert_to_string,
+)
 from deepnote_toolkit.ocelots.types import ColumnsStatsRecord, ColumnStats
 
 
@@ -24,7 +29,10 @@ def _get_categories(np_array):
     # special treatment for empty values
     num_nans = pandas_series.isna().sum().item()
 
-    counter = Counter(pandas_series.dropna().astype(str))
+    try:
+        counter = Counter(pandas_series.dropna().astype(str))
+    except (TypeError, UnicodeDecodeError, AttributeError):
+        counter = Counter(pandas_series.dropna().apply(safe_convert_to_string))
 
     max_items = 3
     if num_nans > 0:
@@ -46,33 +54,9 @@ def _get_categories(np_array):
     return [{"name": name, "count": count} for name, count in categories]
 
 
-def _is_type_numeric(dtype):
-    """
-    Returns True if dtype is numeric, False otherwise
-
-    Numeric means either a number (int, float, complex) or a datetime or timedelta.
-    It means e.g. that a range of these values can be plotted on a histogram.
-    """
-
-    # datetime doesn't play nice with np.issubdtype, so we need to check explicitly
-    if pd.api.types.is_datetime64_any_dtype(dtype) or pd.api.types.is_timedelta64_dtype(
-        dtype
-    ):
-        return True
-
-    try:
-        return np.issubdtype(dtype, np.number)
-    except TypeError:
-        # np.issubdtype crashes on categorical column dtype, and also on others, e.g. geopandas types
-        return False
-
-
 def _get_histogram(pd_series):
     try:
-        if pd.api.types.is_datetime64_any_dtype(
-            pd_series
-        ) or pd.api.types.is_timedelta64_dtype(pd_series):
-            # convert datetime or timedelta to an integer so that a histogram can be created
+        if is_type_datetime_or_timedelta(pd_series):
             np_array = np.array(pd_series.dropna().astype(int))
         else:
             # let's drop infinite values because they break histograms
@@ -104,11 +88,15 @@ def _calculate_min_max(column):
     """
     Calculate min and max values for a given column.
     """
-    if _is_type_numeric(column.dtype):
+    if not is_numeric_or_temporal(column.dtype):
+        return None, None
+
+    try:
         min_value = str(min(column.dropna())) if len(column.dropna()) > 0 else None
         max_value = str(max(column.dropna())) if len(column.dropna()) > 0 else None
         return min_value, max_value
-    return None, None
+    except (TypeError, ValueError):
+        return None, None
 
 
 def analyze_columns(
@@ -167,7 +155,7 @@ def analyze_columns(
             unique_count=_count_unique(column), nan_count=column.isnull().sum().item()
         )
 
-        if _is_type_numeric(column.dtype):
+        if is_numeric_or_temporal(column.dtype):
            min_value, max_value = _calculate_min_max(column)
            columns[i].stats.min = min_value
            columns[i].stats.max = max_value
@@ -187,7 +175,7 @@ def analyze_columns(
     for i in range(max_columns_to_analyze, len(df.columns)):
         # Ignore columns that are not numeric
         column = df.iloc[:, i]
-        if not _is_type_numeric(column.dtype):
+        if not is_numeric_or_temporal(column.dtype):
             continue
 
         column_name = columns[i].name
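
Note: the new try/except around Counter(...astype(str)) matters for object columns whose elements cannot be stringified in bulk. A minimal sketch of the fallback path — the Unstringable class is hypothetical, invented here only to force the failure:

from collections import Counter

import pandas as pd

from deepnote_toolkit.ocelots.pandas.utils import safe_convert_to_string


class Unstringable:
    # Hypothetical value whose str() raises, mimicking exotic column contents
    def __str__(self):
        raise TypeError("cannot stringify")


series = pd.Series([Unstringable(), Unstringable(), None])

try:
    counter = Counter(series.dropna().astype(str))
except (TypeError, UnicodeDecodeError, AttributeError):
    # Per-element fallback that never raises
    counter = Counter(series.dropna().apply(safe_convert_to_string))

print(counter)  # Counter({'<unconvertible>': 2})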

deepnote_toolkit/ocelots/pandas/utils.py

Lines changed: 49 additions & 9 deletions
@@ -5,6 +5,19 @@
 from deepnote_toolkit.ocelots.constants import MAX_STRING_CELL_LENGTH
 
 
+def safe_convert_to_string(value):
+    """
+    Safely convert a value to string, handling cases where str() might fail.
+
+    Note: For bytes, this returns Python's standard string representation (e.g., b'hello')
+    rather than base64 encoding, which is more human-readable.
+    """
+    try:
+        return str(value)
+    except Exception:
+        return "<unconvertible>"
+
+
 # like fillna, but only fills NaT (not a time) values in datetime columns with the specified value
 def fill_nat(df, value):
     df_datetime_columns = df.select_dtypes(
@@ -76,36 +89,63 @@ def deduplicate_columns(df):
 # Cast dataframe contents to strings and trim them to avoid sending too much data
 def cast_objects_to_string(df):
     def to_string_truncated(elem):
-        elem_string = str(elem)
+        elem_string = safe_convert_to_string(elem)
         return (
             (elem_string[: MAX_STRING_CELL_LENGTH - 1] + "…")
             if len(elem_string) > MAX_STRING_CELL_LENGTH
             else elem_string
         )
 
     for column in df:
-        if not _is_type_number(df[column].dtype):
+        if not is_pure_numeric(df[column].dtype):
             # if the dtype is not a number, we want to convert it to string and truncate
             df[column] = df[column].apply(to_string_truncated)
 
     return df
 
 
-def _is_type_number(dtype):
+def is_type_datetime_or_timedelta(series_or_dtype):
     """
-    Returns True if dtype is a number, False otherwise. Datetime and timedelta will return False.
+    Returns True if the series or dtype is datetime or timedelta, False otherwise.
+    """
+    return pd.api.types.is_datetime64_any_dtype(
+        series_or_dtype
+    ) or pd.api.types.is_timedelta64_dtype(series_or_dtype)
+
 
-    The primary intent of this is to recognize a value that will converted to a JSON number during serialization.
+def is_numeric_or_temporal(dtype):
     """
+    Returns True if dtype is numeric or temporal (datetime/timedelta), False otherwise.
 
-    if pd.api.types.is_datetime64_any_dtype(dtype) or pd.api.types.is_timedelta64_dtype(
-        dtype
-    ):
+    This includes numbers (int, float), datetime, and timedelta types.
+    Use this to determine if values can be plotted on a histogram or have min/max calculated.
+    """
+    if is_type_datetime_or_timedelta(dtype):
+        return True
+
+    try:
+        return np.issubdtype(dtype, np.number) and not np.issubdtype(
+            dtype, np.complexfloating
+        )
+    except TypeError:
+        # np.issubdtype crashes on categorical column dtype, and also on others, e.g. geopandas types
+        return False
+
+
+def is_pure_numeric(dtype):
+    """
+    Returns True if dtype is a pure number (int, float), False otherwise.
+
+    Use this to determine if a value will be serialized as a JSON number.
+    """
+    if is_type_datetime_or_timedelta(dtype):
         # np.issubdtype(dtype, np.number) returns True for timedelta, which we don't want
         return False
 
     try:
-        return np.issubdtype(dtype, np.number)
+        return np.issubdtype(dtype, np.number) and not np.issubdtype(
+            dtype, np.complexfloating
+        )
     except TypeError:
         # np.issubdtype crashes on categorical column dtype, and also on others, e.g. geopandas types
         return False
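
Note: the split between is_numeric_or_temporal (histogram and min/max eligibility) and is_pure_numeric (JSON-number serialization) is easiest to see side by side. A quick sketch of the expected behaviour, assuming deepnote_toolkit is importable:

import pandas as pd

from deepnote_toolkit.ocelots.pandas.utils import (
    is_numeric_or_temporal,
    is_pure_numeric,
    safe_convert_to_string,
)

df = pd.DataFrame(
    {
        "ints": [1, 2],
        "when": pd.to_datetime(["2023-01-01", "2023-01-02"]),
        "cplx": [1 + 2j, 3 + 4j],
        "blob": [b"hello", b"world"],
    }
)

for name in df.columns:
    dtype = df[name].dtype
    print(name, is_numeric_or_temporal(dtype), is_pure_numeric(dtype))
# ints True True   -> histogram + JSON number
# when True False  -> histogram, but serialized as a string
# cplx False False -> complex is now excluded from both
# blob False False -> object column, stringified instead

print(safe_convert_to_string(b"hello"))  # b'hello', not base64 "aGVsbG8="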

deepnote_toolkit/ocelots/pyspark/implementation.py

Lines changed: 21 additions & 6 deletions
@@ -232,6 +232,22 @@ def to_records(self, mode: Literal["json", "python"]) -> List[Dict[str, Any]]:
             StructField,
         )
 
+        def binary_to_string_repr(
+            binary_data: Optional[Union[bytes, bytearray]]
+        ) -> Optional[str]:
+            """Convert binary data to Python string representation (e.g., b'hello').
+
+            Args:
+                binary_data: Binary data as bytes or bytearray. PySpark passes BinaryType
+                    as bytearray by default, but Spark 4.1+ with
+                    spark.sql.execution.pyspark.binaryAsBytes=true passes bytes instead.
+            """
+            if binary_data is None:
+                return None
+            return str(bytes(binary_data))
+
+        binary_udf = F.udf(binary_to_string_repr, StringType())
+
         def select_column(field: StructField) -> Column:
             col = F.col(field.name)
             # Numbers are already JSON-serialise, except Decimal
@@ -240,11 +256,12 @@ def select_column(field: StructField) -> Column:
             ):
                 return col
 
-            # We slice binary field before encoding to avoid encoding potentially big blob. Round slicing to
-            # 4 bytes to avoid breaking multi-byte sequences
+            # We slice binary field before converting to string representation
             if isinstance(field.dataType, BinaryType):
-                sliced = F.substring(field, 1, keep_bytes)
-                return F.base64(sliced)
+                # Each byte becomes up to 4 chars (\xNN) in string repr, plus b'' overhead
+                max_binary_bytes = (MAX_STRING_CELL_LENGTH - 3) // 4
+                sliced = F.substring(F.col(field.name), 1, max_binary_bytes)
+                return binary_udf(sliced)
 
             # String just needs to be trimmed
             if isinstance(field.dataType, StringType):
@@ -253,8 +270,6 @@ def select_column(field: StructField) -> Column:
             # Everything else gets stringified (Decimal, Date, Timestamp, Struct, …)
             return F.substring(col.cast("string"), 1, MAX_STRING_CELL_LENGTH)
 
-        keep_bytes = (MAX_STRING_CELL_LENGTH // 4) * 3
-
         if mode == "python":
             return [row.asDict() for row in self._df.collect()]
         elif mode == "json":
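
Note: the (MAX_STRING_CELL_LENGTH - 3) // 4 budget follows from the worst case of Python's bytes repr: each byte can expand to a four-character \xNN escape, and the b'' wrapper adds three more characters. A standalone check of that bound — the constant's value here is illustrative, not the one in ocelots.constants:

MAX_STRING_CELL_LENGTH = 100  # illustrative value for the sketch

max_binary_bytes = (MAX_STRING_CELL_LENGTH - 3) // 4  # 24 bytes here

# Worst case: every byte needs a \xNN escape (4 chars) plus the b'' wrapper (3 chars)
worst = bytes([0xFF] * max_binary_bytes)
assert len(str(worst)) == 4 * max_binary_bytes + 3
assert len(str(worst)) <= MAX_STRING_CELL_LENGTH

print(str(b"hello"))  # b'hello' -- printable ASCII stays compact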

tests/unit/helpers/testing_dataframes.py

Lines changed: 2 additions & 0 deletions
@@ -261,12 +261,14 @@ def create_dataframe_with_duplicate_column_names():
                 datetime.datetime(2023, 1, 1, 12, 0, 0),
                 datetime.datetime(2023, 1, 2, 12, 0, 0),
             ],
+            "binary": [b"hello", b"world"],
         }
     ),
     "pyspark_schema": pst.StructType(
         [
             pst.StructField("list", pst.ArrayType(pst.IntegerType()), True),
             pst.StructField("datetime", pst.TimestampType(), True),
+            pst.StructField("binary", pst.BinaryType(), True),
         ]
     ),
 },
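
Note: with the binary fixture above, the pandas path and the Spark UDF should now agree on how binary cells render. A minimal sketch of the pandas side, assuming deepnote_toolkit is importable:

import pandas as pd

from deepnote_toolkit.ocelots.pandas.utils import cast_objects_to_string

df = pd.DataFrame({"binary": [b"hello", b"world"]})
print(cast_objects_to_string(df)["binary"].tolist())
# ["b'hello'", "b'world'"] -- the same repr binary_to_string_repr produces in Spark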

0 commit comments