@@ -268,6 +268,57 @@ def test_add_null(self):
268268 with self .assertRaises (RuntimeError ) as RE :
269269 ds .add_field ('test' , [])
270270
271+ def test_concat (self ):
272+ """
273+ 测试两个dataset能否正确concat
274+
275+ """
276+ ds1 = DataSet ({"x" : [[1 , 2 , 3 , 4 ] for i in range (10 )], "y" : [[5 , 6 ] for i in range (10 )]})
277+ ds2 = DataSet ({"x" : [[4 ,3 ,2 ,1 ] for i in range (10 )], "y" : [[6 ,5 ] for i in range (10 )]})
278+ ds3 = ds1 .concat (ds2 )
279+
280+ self .assertEqual (len (ds3 ), 20 )
281+
282+ self .assertListEqual (ds1 [9 ]['x' ], [1 , 2 , 3 , 4 ])
283+ self .assertListEqual (ds1 [10 ]['x' ], [4 ,3 ,2 ,1 ])
284+
285+ ds2 [0 ]['x' ][0 ] = 100
286+ self .assertEqual (ds3 [10 ]['x' ][0 ], 4 ) # 不改变copy后的field了
287+
288+ ds3 [10 ]['x' ][0 ] = - 100
289+ self .assertEqual (ds2 [0 ]['x' ][0 ], 100 ) # 不改变copy前的field了
290+
291+ # 测试inplace
292+ ds1 = DataSet ({"x" : [[1 , 2 , 3 , 4 ] for i in range (10 )], "y" : [[5 , 6 ] for i in range (10 )]})
293+ ds2 = DataSet ({"x" : [[4 , 3 , 2 , 1 ] for i in range (10 )], "y" : [[6 , 5 ] for i in range (10 )]})
294+ ds3 = ds1 .concat (ds2 , inplace = True )
295+
296+ ds2 [0 ]['x' ][0 ] = 100
297+ self .assertEqual (ds3 [10 ]['x' ][0 ], 4 ) # 不改变copy后的field了
298+
299+ ds3 [10 ]['x' ][0 ] = - 100
300+ self .assertEqual (ds2 [0 ]['x' ][0 ], 100 ) # 不改变copy前的field了
301+
302+ ds3 [0 ]['x' ][0 ] = 100
303+ self .assertEqual (ds1 [0 ]['x' ][0 ], 100 ) # 改变copy前的field了
304+
305+ # 测试mapping
306+ ds1 = DataSet ({"x" : [[1 , 2 , 3 , 4 ] for i in range (10 )], "y" : [[5 , 6 ] for i in range (10 )]})
307+ ds2 = DataSet ({"X" : [[4 , 3 , 2 , 1 ] for i in range (10 )], "Y" : [[6 , 5 ] for i in range (10 )]})
308+ ds3 = ds1 .concat (ds2 , field_mapping = {'X' :'x' , 'Y' :'y' })
309+ self .assertEqual (len (ds3 ), 20 )
310+
311+ # 测试忽略掉多余的
312+ ds1 = DataSet ({"x" : [[1 , 2 , 3 , 4 ] for i in range (10 )], "y" : [[5 , 6 ] for i in range (10 )]})
313+ ds2 = DataSet ({"X" : [[4 , 3 , 2 , 1 ] for i in range (10 )], "Y" : [[6 , 5 ] for i in range (10 )], 'Z' :[0 ]* 10 })
314+ ds3 = ds1 .concat (ds2 , field_mapping = {'X' :'x' , 'Y' :'y' })
315+
316+ # 测试报错
317+ ds1 = DataSet ({"x" : [[1 , 2 , 3 , 4 ] for i in range (10 )], "y" : [[5 , 6 ] for i in range (10 )]})
318+ ds2 = DataSet ({"X" : [[4 , 3 , 2 , 1 ] for i in range (10 )]})
319+ with self .assertRaises (RuntimeError ):
320+ ds3 = ds1 .concat (ds2 , field_mapping = {'X' :'x' })
321+
271322
272323class TestDataSetIter (unittest .TestCase ):
273324 def test__repr__ (self ):
0 commit comments