44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ import operator
78import statistics
89import typing as ty
910import unittest
11+ from collections import defaultdict
1012from math import ceil , floor , isnan , log
1113
1214import numpy as np
1315import numpy .testing
16+ import pandas as pd
1417import torcharrow as ta
1518import torcharrow .dtypes as dt
1619from torcharrow .icolumn import Column
@@ -32,7 +35,7 @@ def base_test_empty(self):
3235 return empty_i64_column
3336
3437 def base_test_full (self ):
35- col = ta .column ([ i for i in range (4 )] , dtype = dt .int64 , device = self .device )
38+ col = ta .column (list ( range (4 )) , dtype = dt .int64 , device = self .device )
3639
3740 # self.assertEqual(col._offset, 0)
3841 self .assertEqual (len (col ), 4 )
@@ -43,7 +46,7 @@ def base_test_full(self):
4346 return col
4447
4548 def base_test_is_immutable (self ):
46- col = ta .column ([ i for i in range (4 )] , dtype = dt .int64 , device = self .device )
49+ col = ta .column (list ( range (4 )) , dtype = dt .int64 , device = self .device )
4750 with self .assertRaises (AttributeError ):
4851 # AssertionError: can't append a finalized list
4952 col ._append (None )
@@ -164,6 +167,13 @@ def base_test_map_where_filter(self):
164167 # Values that are not found in the dict are converted to None
165168 self .assertEqual (list (col .map ({3 : 33 })), [None , None , None , 33 , None , None ])
166169
170+ # maps default dict
171+ d_dict = defaultdict (lambda : 1 , {None : 2 })
172+ self .assertEqual (
173+ list (col .map (arg = d_dict )),
174+ [2 , 2 , 2 , 1 , 1 , 1 ],
175+ )
176+
167177 # maps None
168178 self .assertEqual (
169179 list (col .map ({None : 1 , 3 : 33 })),
@@ -196,6 +206,18 @@ def base_test_map_where_filter(self):
196206 # filter
197207 self .assertEqual (list (col .filter ([True , False ] * 3 )), [None , None , 4 ])
198208
209+ with self .assertRaisesRegex (
210+ expected_exception = TypeError ,
211+ expected_regex = "columns parameter for flat columns not supported" ,
212+ ):
213+ col .filter ([True , False ], columns = ["test" , "test2" ])
214+
215+ with self .assertRaisesRegex (
216+ expected_exception = TypeError ,
217+ expected_regex = "predicate must be a unary boolean predicate or iterable of booleans" ,
218+ ):
219+ col .filter (123 )
220+
199221 @staticmethod
200222 def _accumulate (col , val ):
201223 if len (col ) == 0 :
@@ -217,6 +239,26 @@ def base_test_reduce(self):
217239 )
218240 self .assertEqual (list (d ), [1 , 3 , 6 ])
219241
242+ col_no_init = c .reduce (
243+ fun = operator .add ,
244+ )
245+ self .assertEqual (sum (c ), col_no_init )
246+
247+ c_empty = ta .column (dtype = dt .int64 , device = self .device )
248+ result = c_empty .reduce (
249+ fun = TestNumericalColumn ._accumulate ,
250+ initializer = c ,
251+ )
252+ self .assertTrue (all (c == result ))
253+
254+ with self .assertRaisesRegex (
255+ expected_exception = TypeError ,
256+ expected_regex = "reduce of empty sequence with no initial value" ,
257+ ):
258+ c_empty .reduce (
259+ fun = TestNumericalColumn ._accumulate ,
260+ )
261+
220262 def base_test_sort_stuff (self ):
221263 col = ta .column ([2 , 1 , 3 ], device = self .device )
222264
@@ -795,6 +837,46 @@ def base_test_batch_collate(self):
795837 it = c .batch (2 )
796838 self .assertEqual (list (Column .unbatch (it )), [1 , 2 , 3 , 4 , 5 , 6 , 7 ])
797839
840+ def base_test_str (self ):
841+ c = ta .column (list (range (5 )), device = self .device )
842+ c .id = 123
843+
844+ expected = "Column([0, 1, 2, 3, 4], id = 123)"
845+ self .assertEqual (expected , str (c ))
846+
847+ def base_test_repr (self ):
848+ c = ta .column (list (range (5 )), device = self .device )
849+ expected_repr = (
850+ "0 0\n "
851+ "1 1\n "
852+ "2 2\n "
853+ "3 3\n "
854+ "4 4\n "
855+ f"dtype: int64, length: 5, null_count: 0, device: { self .device } "
856+ )
857+
858+ self .assertEqual (expected_repr , repr (c ))
859+
860+ def base_test_to_pandas (self ):
861+ c_repr = list (range (10 ))
862+ c = ta .column (c_repr , device = self .device )
863+ expected = pd .Series (c_repr )
864+ self .assertTrue (all (expected == c .to_pandas ()))
865+
866+ def base_test_transform (self ):
867+ c_repr = list (range (10 ))
868+ c = ta .column (c_repr , device = self .device )
869+
870+ result = c .transform (lambda x : x * 10 )
871+
872+ self .assertEqual ([x * 10 for x in c_repr ], list (result ))
873+
874+ with self .assertRaisesRegex (
875+ expected_exception = TypeError ,
876+ expected_regex = "columns parameter for flat columns not supported" ,
877+ ):
878+ c .transform (lambda x : x * 10 , columns = ["test" ])
879+
798880
799881if __name__ == "__main__" :
800882
0 commit comments