Skip to content

Commit b9b365e

Browse files
Copilot and borchero authored
fix: Properly type-check non-existent Schema attributes (#224)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: borchero <22455425+borchero@users.noreply.github.com>
Co-authored-by: Oliver Borchert <oliver.borchert@quantco.com>
1 parent 3f601d1 commit b9b365e

File tree

4 files changed

+30
-13
lines changed

4 files changed

+30
-13
lines changed

dataframely/_base_schema.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from abc import ABCMeta
99
from copy import copy
1010
from dataclasses import dataclass, field
11-
from typing import Any
11+
from typing import TYPE_CHECKING, Any
1212

1313
import polars as pl
1414

@@ -164,12 +164,16 @@ def __new__(
164164

165165
return cls
166166

167-
def __getattribute__(cls, name: str) -> Any:
168-
val = super().__getattribute__(name)
169-
# Dynamically set the name of the column if it is a `Column` instance.
170-
if isinstance(val, Column):
171-
val._name = val.alias or name
172-
return val
167+
if not TYPE_CHECKING:
168+
# Only define __getattribute__ at runtime to allow type checkers to properly
169+
# validate attribute access. When TYPE_CHECKING is True, type checkers will use
170+
# the default metaclass behavior which correctly identifies non-existent attributes.
171+
def __getattribute__(cls, name: str) -> Any:
172+
val = super().__getattribute__(name)
173+
# Dynamically set the name of the column if it is a `Column` instance.
174+
if isinstance(val, Column):
175+
val._name = val.alias or name
176+
return val
173177

174178
@staticmethod
175179
def _get_metadata_recursively(kls: type[object]) -> Metadata:
@@ -199,9 +203,9 @@ def _get_metadata(source: dict[str, Any]) -> Metadata:
199203
def __repr__(cls) -> str:
200204
parts = [f'[Schema "{cls.__name__}"]']
201205
parts.append(textwrap.indent("Columns:", prefix=" " * 2))
202-
for name, col in cls.columns().items():
206+
for name, col in cls.columns().items(): # type: ignore[attr-defined]
203207
parts.append(textwrap.indent(f'- "{name}": {col!r}', prefix=" " * 4))
204-
if validation_rules := cls._schema_validation_rules():
208+
if validation_rules := cls._schema_validation_rules(): # type: ignore[attr-defined]
205209
parts.append(textwrap.indent("Rules:", prefix=" " * 2))
206210
for name, rule in validation_rules.items():
207211
parts.append(textwrap.indent(f'- "{name}": {rule!r}', prefix=" " * 4))

dataframely/_pydantic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@
2020
def _dict_to_df(schema_type: type[BaseSchema], data: dict) -> pl.DataFrame:
2121
return pl.from_dict(
2222
data,
23-
schema=schema_type.to_polars_schema(),
23+
schema=schema_type.to_polars_schema(), # type: ignore[attr-defined]
2424
)
2525

2626

2727
def _validate_df_schema(schema_type: type[_S], df: pl.DataFrame) -> DataFrame[_S]:
28-
if not schema_type.is_valid(df):
28+
if not schema_type.is_valid(df): # type: ignore[attr-defined]
2929
raise ValueError("DataFrame violates schema")
3030
return df # type: ignore
3131

dataframely/filter_result.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def _sink(
319319
backend.sink_failure_info(
320320
lf=self._lf,
321321
serialized_rules=json.dumps(self._rule_columns),
322-
serialized_schema=self.schema.serialize(),
322+
serialized_schema=self.schema.serialize(), # type: ignore[attr-defined]
323323
**kwargs,
324324
)
325325

@@ -333,7 +333,7 @@ def _write(
333333
backend.write_failure_info(
334334
df=self._df,
335335
serialized_rules=json.dumps(self._rule_columns),
336-
serialized_schema=self.schema.serialize(),
336+
serialized_schema=self.schema.serialize(), # type: ignore[attr-defined]
337337
**kwargs,
338338
)
339339

tests/test_typing.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,16 @@ def test_collection_concat() -> None:
8080
c1 = MyCollection.create_empty()
8181
c2 = MyCollection.create_empty()
8282
dy.concat_collection_members([c1, c2])
83+
84+
85+
# ------------------------------------------------------------------------------------ #
86+
# ATTRIBUTE ACCESS #
87+
# ------------------------------------------------------------------------------------ #
88+
89+
90+
def test_non_existent_column_access() -> None:
91+
Schema.non_existing_col # type: ignore[attr-defined]
92+
93+
94+
def test_valid_column_access() -> None:
95+
Schema.a # Should pass type checking

0 commit comments

Comments
 (0)