Skip to content
This repository was archived by the owner on Nov 1, 2024. It is now read-only.

Commit b93b2c8

Browse files
andrewaikens87facebook-github-bot
authored andcommitted
Improve icolumn.py cc 1/4 test_numerical_column (#407)
Summary: Pull Request resolved: #407 Improves test coverage for icolumn.py through test_numerical_column. This stack gets icolumn.py to 86% coverage. Reviewed By: OswinC Differential Revision: D37492648 fbshipit-source-id: a90267d2372677722a8cdcae55d43fd146c4b9b3
1 parent 4f8297a commit b93b2c8

File tree

3 files changed

+104
-7
lines changed

3 files changed

+104
-7
lines changed

torcharrow/icolumn.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -269,18 +269,21 @@ def __len__(self):
269269
# printing ----------------------------------------------------------------
270270

271271
def __str__(self):
272-
return f"Column([{', '.join(str(i) for i in self)}], id = {self.id})"
272+
item_padding = "'" if dt.is_string(self.dtype) else ""
273+
return f"Column([{', '.join(f'{item_padding}{i}{item_padding}' for i in self)}], id = {self.id})"
273274

274275
def __repr__(self):
275-
rows = [[l if l is not None else "None"] for l in self]
276+
item_padding = "'" if dt.is_string(self.dtype) else ""
277+
rows = [
278+
[f"{item_padding}{l}{item_padding}" if l is not None else "None"]
279+
for l in self
280+
]
276281
tab = tabulate(
277282
rows,
278283
tablefmt="plain",
279284
showindex=True,
280285
)
281-
typ = (
282-
f"dtype: {self._dtype}, length: {len(self)}, null_count: {self.null_count}"
283-
)
286+
typ = f"dtype: {self._dtype}, length: {len(self)}, null_count: {self.null_count}, device: {self.device}"
284287
return tab + dt.NL + typ
285288

286289
# selectors/getters -------------------------------------------------------

torcharrow/test/test_numerical_column.py

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,16 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import operator
78
import statistics
89
import typing as ty
910
import unittest
11+
from collections import defaultdict
1012
from math import ceil, floor, isnan, log
1113

1214
import numpy as np
1315
import numpy.testing
16+
import pandas as pd
1417
import torcharrow as ta
1518
import torcharrow.dtypes as dt
1619
from torcharrow.icolumn import Column
@@ -32,7 +35,7 @@ def base_test_empty(self):
3235
return empty_i64_column
3336

3437
def base_test_full(self):
35-
col = ta.column([i for i in range(4)], dtype=dt.int64, device=self.device)
38+
col = ta.column(list(range(4)), dtype=dt.int64, device=self.device)
3639

3740
# self.assertEqual(col._offset, 0)
3841
self.assertEqual(len(col), 4)
@@ -43,7 +46,7 @@ def base_test_full(self):
4346
return col
4447

4548
def base_test_is_immutable(self):
46-
col = ta.column([i for i in range(4)], dtype=dt.int64, device=self.device)
49+
col = ta.column(list(range(4)), dtype=dt.int64, device=self.device)
4750
with self.assertRaises(AttributeError):
4851
# AssertionError: can't append a finalized list
4952
col._append(None)
@@ -164,6 +167,13 @@ def base_test_map_where_filter(self):
164167
# Values that are not found in the dict are converted to None
165168
self.assertEqual(list(col.map({3: 33})), [None, None, None, 33, None, None])
166169

170+
# maps default dict
171+
d_dict = defaultdict(lambda: 1, {None: 2})
172+
self.assertEqual(
173+
list(col.map(arg=d_dict)),
174+
[2, 2, 2, 1, 1, 1],
175+
)
176+
167177
# maps None
168178
self.assertEqual(
169179
list(col.map({None: 1, 3: 33})),
@@ -196,6 +206,18 @@ def base_test_map_where_filter(self):
196206
# filter
197207
self.assertEqual(list(col.filter([True, False] * 3)), [None, None, 4])
198208

209+
with self.assertRaisesRegex(
210+
expected_exception=TypeError,
211+
expected_regex="columns parameter for flat columns not supported",
212+
):
213+
col.filter([True, False], columns=["test", "test2"])
214+
215+
with self.assertRaisesRegex(
216+
expected_exception=TypeError,
217+
expected_regex="predicate must be a unary boolean predicate or iterable of booleans",
218+
):
219+
col.filter(123)
220+
199221
@staticmethod
200222
def _accumulate(col, val):
201223
if len(col) == 0:
@@ -217,6 +239,26 @@ def base_test_reduce(self):
217239
)
218240
self.assertEqual(list(d), [1, 3, 6])
219241

242+
col_no_init = c.reduce(
243+
fun=operator.add,
244+
)
245+
self.assertEqual(sum(c), col_no_init)
246+
247+
c_empty = ta.column(dtype=dt.int64, device=self.device)
248+
result = c_empty.reduce(
249+
fun=TestNumericalColumn._accumulate,
250+
initializer=c,
251+
)
252+
self.assertTrue(all(c == result))
253+
254+
with self.assertRaisesRegex(
255+
expected_exception=TypeError,
256+
expected_regex="reduce of empty sequence with no initial value",
257+
):
258+
c_empty.reduce(
259+
fun=TestNumericalColumn._accumulate,
260+
)
261+
220262
def base_test_sort_stuff(self):
221263
col = ta.column([2, 1, 3], device=self.device)
222264

@@ -795,6 +837,46 @@ def base_test_batch_collate(self):
795837
it = c.batch(2)
796838
self.assertEqual(list(Column.unbatch(it)), [1, 2, 3, 4, 5, 6, 7])
797839

840+
def base_test_str(self):
841+
c = ta.column(list(range(5)), device=self.device)
842+
c.id = 123
843+
844+
expected = "Column([0, 1, 2, 3, 4], id = 123)"
845+
self.assertEqual(expected, str(c))
846+
847+
def base_test_repr(self):
848+
c = ta.column(list(range(5)), device=self.device)
849+
expected_repr = (
850+
"0 0\n"
851+
"1 1\n"
852+
"2 2\n"
853+
"3 3\n"
854+
"4 4\n"
855+
f"dtype: int64, length: 5, null_count: 0, device: {self.device}"
856+
)
857+
858+
self.assertEqual(expected_repr, repr(c))
859+
860+
def base_test_to_pandas(self):
861+
c_repr = list(range(10))
862+
c = ta.column(c_repr, device=self.device)
863+
expected = pd.Series(c_repr)
864+
self.assertTrue(all(expected == c.to_pandas()))
865+
866+
def base_test_transform(self):
867+
c_repr = list(range(10))
868+
c = ta.column(c_repr, device=self.device)
869+
870+
result = c.transform(lambda x: x * 10)
871+
872+
self.assertEqual([x * 10 for x in c_repr], list(result))
873+
874+
with self.assertRaisesRegex(
875+
expected_exception=TypeError,
876+
expected_regex="columns parameter for flat columns not supported",
877+
):
878+
c.transform(lambda x: x * 10, columns=["test"])
879+
798880

799881
if __name__ == "__main__":
800882

torcharrow/test/test_numerical_column_cpu.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,18 @@ def test_batch_collate(self):
8383
def test_cast(self):
8484
return self.base_test_cast()
8585

86+
def test_str(self):
87+
self.base_test_str()
88+
89+
def test_repr(self):
90+
self.base_test_repr()
91+
92+
def test_to_pandas(self):
93+
self.base_test_to_pandas()
94+
95+
def test_transform(self):
96+
self.base_test_transform()
97+
8698

8799
if __name__ == "__main__":
88100
unittest.main()

0 commit comments

Comments
 (0)