Skip to content
This repository was archived by the owner on Nov 1, 2024. It is now read-only.

Commit a700ae1

Browse files
waitingkuowenleix
authored andcommitted
Enable single column str for drop/drop_duplicates/groupby
Also update _check_columns to ensure the input columns is always a sequence of str
1 parent 41eba1b commit a700ae1

File tree

4 files changed

+58
-4
lines changed

4 files changed

+58
-4
lines changed

torcharrow/idataframe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ def any(self):
425425
# column alnternating
426426
@trace
427427
@expression
428-
def drop(self, columns: List[str]):
428+
def drop(self, columns: Union[str, List[str]]):
429429
"""
430430
Returns DataFrame without the removed columns.
431431
"""

torcharrow/test/test_dataframe.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -833,6 +833,24 @@ def base_test_describe_dataframe(self):
833833
],
834834
)
835835

836+
def base_test_drop_by_str_as_columns(self):
837+
df = ta.dataframe(device=self.device)
838+
df["aa"] = [1, 2, 3]
839+
df["ab"] = [11, 22, 33]
840+
df["ac"] = [111, 222, 333]
841+
self.assertEqual(list(df.drop("aa")), [(11, 111), (22, 222), (33, 333)])
842+
self.assertEqual(list(df.drop("ab")), [(1, 111), (2, 222), (3, 333)])
843+
self.assertEqual(list(df.drop("ac")), [(1, 11), (2, 22), (3, 33)])
844+
845+
def base_test_drop_by_list_of_str_as_columns(self):
846+
df = ta.dataframe(device=self.device)
847+
df["aa"] = [1, 2, 3]
848+
df["ab"] = [11, 22, 33]
849+
df["ac"] = [111, 222, 333]
850+
self.assertEqual(list(df.drop(["aa", "ab"])), [(111,), (222,), (333,)])
851+
self.assertEqual(list(df.drop(["aa", "ac"])), [(11,), (22,), (33,)])
852+
self.assertEqual(list(df.drop(["ab", "ac"])), [(1,), (2,), (3,)])
853+
836854
def base_test_drop_keep_rename_reorder_pipe(self):
837855
df = ta.dataframe(device=self.device)
838856
df["a"] = [1, 2, 3]
@@ -895,6 +913,21 @@ def base_test_locals_and_me_equivalence(self):
895913
)
896914
self.assertEqual(list(df.select("*", d=me["a"] + me["b"])), list(gf))
897915

916+
917+
def base_test_groupby_str(self):
918+
df = ta.dataframe(
919+
{"a": [1, 1, 2], "b": [1, 2, 3], "c": [2, 2, 1]}, device=self.device
920+
)
921+
self.assertEqual(list(df.groupby("a").size), [(1, 2), (2, 1)])
922+
923+
924+
def base_test_groupby_list_of_str(self):
925+
df = ta.dataframe(
926+
{"a": [1, 1, 2], "b": [1, 2, 3], "c": [2, 2, 1]}, device=self.device
927+
)
928+
self.assertEqual(list(df.groupby(["a"]).size), [(1, 2), (2, 1)])
929+
930+
898931
def base_test_groupby_size_pipe(self):
899932
df = ta.dataframe(
900933
{"a": [1, 1, 2], "b": [1, 2, 3], "c": [2, 2, 1]}, device=self.device

torcharrow/test/test_dataframe_cpu.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,12 @@ def test_isin2(self):
6464
def test_describe_dataframe(self):
6565
return self.base_test_describe_dataframe()
6666

67+
def test_drop_by_str_as_columns(self):
68+
return self.base_test_drop_by_str_as_columns()
69+
70+
def test_drop_by_list_of_str_as_columns(self):
71+
return self.base_test_drop_by_list_of_str_as_columns()
72+
6773
def test_drop_keep_rename_reorder_pipe(self):
6874
return self.base_test_drop_keep_rename_reorder_pipe()
6975

@@ -73,6 +79,12 @@ def test_me_on_str(self):
7379
def test_locals_and_me_equivalence(self):
7480
return self.base_test_locals_and_me_equivalence()
7581

82+
def test_groupby_str(self):
83+
return self.base_test_groupby_str()
84+
85+
def test_groupby_list_of_str(self):
86+
return self.base_test_groupby_list_of_str()
87+
7688
def test_groupby_size_pipe(self):
7789
return self.base_test_groupby_size_pipe()
7890

torcharrow/velox_rt/dataframe_cpu.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,9 @@ def append(self, values: Iterable[Union[None, dict, tuple]]):
308308
return self
309309

310310
def _check_columns(self, columns: Iterable[str]):
311+
if isinstance(columns, str):
312+
raise TypeError(f"columns should be Iterable of str but not str")
313+
311314
valid_names = {f.name for f in self.dtype.fields}
312315
for n in columns:
313316
if n not in valid_names:
@@ -1597,11 +1600,13 @@ def drop_null(self, how="any"):
15971600
@expression
15981601
def drop_duplicates(
15991602
self,
1600-
subset: Optional[List[str]] = None,
1603+
subset: Optional[Union[str, List[str]]] = None,
16011604
keep="first",
16021605
):
16031606
self._prototype_support_warning("drop_duplicates")
16041607

1608+
if isinstance(subset, str):
1609+
subset = [subset]
16051610
columns = subset if subset is not None else self.columns
16061611
self._check_columns(columns)
16071612

@@ -1857,7 +1862,9 @@ def describe(
18571862

18581863
@trace
18591864
@expression
1860-
def drop(self, columns: List[str]):
1865+
def drop(self, columns: Union[str, List[str]]):
1866+
if isinstance(columns, str):
1867+
columns = [columns]
18611868
self._check_columns(columns)
18621869
return self._fromdata(
18631870
{
@@ -2105,7 +2112,7 @@ def pipe(self, func, *args, **kwargs):
21052112
@expression
21062113
def groupby(
21072114
self,
2108-
by: List[str],
2115+
by: Union[str, List[str]],
21092116
sort=False,
21102117
drop_null=True,
21112118
):
@@ -2181,6 +2188,8 @@ def groupby(
21812188
# TODO implement
21822189
assert not sort
21832190
assert drop_null
2191+
if isinstance(by, str):
2192+
by = [by]
21842193
self._check_columns(by)
21852194

21862195
key_columns = by

0 commit comments

Comments
 (0)