Skip to content
This repository was archived by the owner on Nov 1, 2024. It is now read-only.

Commit deb3a33

Browse files
wenleix authored and facebook-github-bot committed
Remove DType.fields (#462)
Summary: Pull Request resolved: #462. `DType.fields` is only available for the Struct type. The original motivation for defining it on `DType` was to avoid some pyre warnings; it looks like we can get the same outcome by adding an explicit `cast` or a more accurate type annotation. Reviewed By: dracifer. Differential Revision: D38266758. fbshipit-source-id: db1587d69921746d8a32fda9bd98f639bc257e2f
1 parent bf43812 commit deb3a33

File tree

4 files changed

+20
-9
lines changed

4 files changed

+20
-9
lines changed

torcharrow/_interop.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -229,7 +229,7 @@ def _dtype_to_arrowtype(t: dt.DType) -> pa.DataType:
229229
pa.field(
230230
f.name, _dtype_to_arrowtype(f.dtype), nullable=f.dtype.nullable
231231
)
232-
for f in t.fields
232+
for f in cast(dt.Struct, t).fields
233233
]
234234
)
235235
raise NotImplementedError(f"Unsupported DType to Arrow type: {str(t)}")

torcharrow/dtypes.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@ def __str__(self):
5454

5555
@dataclass(frozen=True) # type: ignore
5656
class DType(ABC):
57-
fields: ty.ClassVar[ty.Any] = NotImplementedError
58-
5957
typecode: ty.ClassVar[str] = "__TO_BE_DEFINED_IN_SUBCLASS__"
6058
arraycode: ty.ClassVar[str] = "__TO_BE_DEFINED_IN_SUBCLASS__"
6159

@@ -573,7 +571,7 @@ def contains_tuple(t: DType):
573571
# pyre-fixme[16]: `DType` has no attribute `key_dtype`.
574572
return contains_tuple(t.key_dtype) or contains_tuple(t.item_dtype)
575573
if is_struct(t):
576-
return any(contains_tuple(f.dtype) for f in t.fields)
574+
return any(contains_tuple(f.dtype) for f in ty.cast(Struct, t).fields)
577575

578576
return False
579577

@@ -708,9 +706,13 @@ def common_dtype(l: DType, r: DType) -> ty.Optional[DType]:
708706
return String(l.nullable or r.nullable)
709707
if is_boolean_or_numerical(l) and is_boolean_or_numerical(r):
710708
return promote(l, r)
711-
if is_tuple(l) and is_tuple(r) and len(l.fields) == len(r.fields):
709+
if (
710+
is_tuple(l)
711+
and is_tuple(r)
712+
and len(ty.cast(Struct, l).fields) == len(ty.cast(Struct, r).fields)
713+
):
712714
res = []
713-
for i, j in zip(l.fields, r.fields):
715+
for i, j in zip(ty.cast(Struct, l).fields, ty.cast(Struct, r).fields):
714716
m = common_dtype(i, j)
715717
if m is None:
716718
return None

torcharrow/idataframe.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from typing import (
1212
Any,
1313
Callable,
14+
cast,
1415
Dict,
1516
get_type_hints,
1617
Iterable,
@@ -171,6 +172,11 @@ def columns(self):
171172
"""The column labels of the DataFrame."""
172173
return [f.name for f in self.dtype.fields]
173174

175+
@property
176+
@traceproperty
177+
def dtype(self) -> dt.Struct:
178+
return cast(dt.Struct, self._dtype)
179+
174180
def __contains__(self, key: str) -> bool:
175181
for f in self.dtype.fields:
176182
if key == f.name:

torcharrow/velox_rt/dataframe_cpu.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,6 @@ def __setitem__(self, name: str, value: Any) -> None:
329329
empty_df = len(self.dtype.fields) == 0
330330

331331
# Update dtype
332-
# pyre-fixme[16]: `DType` has no attribute `get_index`.
333332
idx = self.dtype.get_index(name)
334333
if idx is None:
335334
# append column
@@ -354,7 +353,7 @@ def _set_field_data(self, name: str, col: Column, empty_df: bool):
354353
new_delegate.set_length(len(col._data))
355354

356355
# Set columns for new_delegate
357-
for idx in range(len(self._dtype.fields)):
356+
for idx in range(len(self.dtype.fields)):
358357
if idx != column_idx:
359358
new_delegate.set_child(idx, self._data.child_at(idx))
360359
else:
@@ -1477,6 +1476,7 @@ def __pos__(self):
14771476

14781477
def log(self) -> DataFrameCpu:
14791478
return self._fromdata(
1479+
# pyre-fixme[6]: Incompatible parameter type [6]: In call `DataFrameCpu._fromdata`, for 1st positional only parameter expected `OrderedDict[str, Column]` but got `Dict[str, typing.Any]`
14801480
{
14811481
self.dtype.fields[i]
14821482
# pyre-fixme[16]: `Column` has no attribute `log`.
@@ -1498,6 +1498,7 @@ def log(self) -> DataFrameCpu:
14981498
def isin(self, values: Union[list, dict, Column]):
14991499
if isinstance(values, list):
15001500
return self._fromdata(
1501+
# pyre-fixme[6]: Incompatible parameter type [6]: In call `DataFrameCpu._fromdata`, for 1st positional only parameter expected `OrderedDict[str, Column]` but got `Dict[str, typing.Any]`
15011502
{
15021503
self.dtype.fields[i]
15031504
.name: ColumnCpuMixin._from_velox(
@@ -1538,6 +1539,7 @@ def fill_null(self, fill_value: Optional[Union[dt.ScalarTypes, Dict]]):
15381539
return self
15391540
if isinstance(fill_value, Column._scalar_types):
15401541
return self._fromdata(
1542+
# pyre-fixme[6]: Incompatible parameter type [6]: In call `DataFrameCpu._fromdata`, for 1st positional only parameter expected `OrderedDict[str, Column]` but got `Dict[str, typing.Any]`
15411543
{
15421544
self.dtype.fields[i]
15431545
.name: ColumnCpuMixin._from_velox(
@@ -1844,6 +1846,7 @@ def drop(self, columns: Union[str, List[str]]):
18441846
columns = [columns]
18451847
self._check_columns(columns)
18461848
return self._fromdata(
1849+
# pyre-fixme[6]: Incompatible parameter type [6]: In call `DataFrameCpu._fromdata`, for 1st positional only parameter expected `OrderedDict[str, Column]` but got `Dict[str, typing.Any]`
18471850
{
18481851
self.dtype.fields[i].name: ColumnCpuMixin._from_velox(
18491852
self.device,
@@ -1865,6 +1868,7 @@ def _keep(self, columns: List[str]):
18651868
"""
18661869
self._check_columns(columns)
18671870
return self._fromdata(
1871+
# pyre-fixme[6]: Incompatible parameter type [6]: In call `DataFrameCpu._fromdata`, for 1st positional only parameter expected `OrderedDict[str, Column]` but got `Dict[str, typing.Any]`
18681872
{
18691873
self.dtype.fields[i].name: ColumnCpuMixin._from_velox(
18701874
self.device,
@@ -2173,7 +2177,6 @@ def groupby(
21732177
key_fields = []
21742178
item_fields = []
21752179
for k in key_columns:
2176-
# pyre-fixme[16]: `DType` has no attribute `get`.
21772180
key_fields.append(dt.Field(k, self.dtype.get(k)))
21782181
for f in self.dtype.fields:
21792182
if f.name not in key_columns:

0 commit comments

Comments
 (0)