Skip to content

Commit 50d7cfd

Browse files
committed
在dataset中添加concat函数,支持将两个dataset concat起来
1 parent d6072ba commit 50d7cfd

File tree

2 files changed

+93
-5
lines changed

2 files changed

+93
-5
lines changed

fastNLP/core/dataset.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -531,11 +531,11 @@ def print_field_meta(self):
531531
| pad_value | 0 | |
532532
+-------------+-------+-------+
533533
534-
:param field_names: DataSet中field的名称
535-
:param is_input: field是否为input
536-
:param is_target: field是否为target
537-
:param ignore_type: 是否忽略该field的type, 一般仅在该field至少为input或target时才有意义
538-
:param pad_value: 该field的pad的值,仅在该field为input或target时有意义
534+
str field_names: DataSet中field的名称
535+
bool is_input: field是否为input
536+
bool is_target: field是否为target
537+
bool ignore_type: 是否忽略该field的type, 一般仅在该field至少为input或target时才有意义
538+
int pad_value: 该field的pad的值,仅在该field为input或target时有意义
539539
:return:
540540
"""
541541
if len(self.field_arrays)>0:
@@ -1146,3 +1146,40 @@ def delete_collate_fn(self, name=None):
11461146

11471147
def _collate_batch(self, ins_list):
11481148
return self.collater.collate_batch(ins_list)
1149+
1150+
def concat(self, dataset, inplace=True, field_mapping=None):
    """Append the instances of ``dataset`` to this dataset, producing one larger dataset.

    Both datasets must contain the same fields; the input/target flags and the
    collate_fn of the result are taken from the current dataset.  Fields that
    exist only in ``dataset`` are ignored, while fields of the current dataset
    that are missing from ``dataset`` raise an error.

    :param DataSet dataset: the dataset to concatenate onto this one
    :param bool inplace: if True, extend this dataset in place; otherwise operate
        on (and return) a deep copy of it
    :param dict field_mapping: when field names differ between the two datasets,
        maps field names of ``dataset`` (keys) to the corresponding names in the
        current dataset (values)
    :return: DataSet
    """
    assert isinstance(dataset, DataSet), "Can only concat two datasets."

    own_fields = set(self.get_field_names())
    other_fields = dataset.get_field_names()
    # Maps a field name of *this* dataset back to its original name in `dataset`.
    reverse_field_mapping = {}
    if field_mapping is not None:
        other_fields = [field_mapping.get(name, name) for name in other_fields]
        reverse_field_mapping = {v: k for k, v in field_mapping.items()}

    missing = list(own_fields - set(other_fields))
    if missing:
        raise RuntimeError(f"The following fields are not provided in the dataset:{missing}")

    target = self if inplace else deepcopy(self)
    for name in own_fields:
        # Deep-copy the incoming content so later mutations of either dataset
        # do not leak into the other.
        source_name = reverse_field_mapping.get(name, name)
        target.get_field(name).content.extend(deepcopy(dataset.get_field(source_name).content))

    return target

tests/core/test_dataset.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,57 @@ def test_add_null(self):
268268
with self.assertRaises(RuntimeError) as RE:
269269
ds.add_field('test', [])
270270

271+
def test_concat(self):
    """
    Check that two datasets can be concatenated correctly.
    """
    # Basic concat: lengths add up and rows from both datasets are present.
    ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
    ds2 = DataSet({"x": [[4, 3, 2, 1] for i in range(10)], "y": [[6, 5] for i in range(10)]})
    ds3 = ds1.concat(ds2)

    self.assertEqual(len(ds3), 20)

    self.assertListEqual(ds1[9]['x'], [1, 2, 3, 4])
    self.assertListEqual(ds1[10]['x'], [4, 3, 2, 1])

    ds2[0]['x'][0] = 100
    self.assertEqual(ds3[10]['x'][0], 4)  # mutating the source must not affect the copied field

    ds3[10]['x'][0] = -100
    self.assertEqual(ds2[0]['x'][0], 100)  # mutating the result must not affect the source field

    # inplace=True: the result shares storage with ds1, but ds2's data is still copied.
    ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
    ds2 = DataSet({"x": [[4, 3, 2, 1] for i in range(10)], "y": [[6, 5] for i in range(10)]})
    ds3 = ds1.concat(ds2, inplace=True)

    ds2[0]['x'][0] = 100
    self.assertEqual(ds3[10]['x'][0], 4)  # mutating the source must not affect the copied field

    ds3[10]['x'][0] = -100
    self.assertEqual(ds2[0]['x'][0], 100)  # mutating the result must not affect the source field

    ds3[0]['x'][0] = 100
    self.assertEqual(ds1[0]['x'][0], 100)  # inplace: mutating the result does affect ds1

    # field_mapping renames the incoming fields onto this dataset's names.
    ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
    ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)], "Y": [[6, 5] for i in range(10)]})
    ds3 = ds1.concat(ds2, field_mapping={'X': 'x', 'Y': 'y'})
    self.assertEqual(len(ds3), 20)

    # Extra fields in the incoming dataset are silently ignored.
    ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
    ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)], "Y": [[6, 5] for i in range(10)], 'Z': [0] * 10})
    ds3 = ds1.concat(ds2, field_mapping={'X': 'x', 'Y': 'y'})
    self.assertEqual(len(ds3), 20)  # previously unasserted: concat must still succeed

    # Missing fields raise RuntimeError.
    ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
    ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)]})
    with self.assertRaises(RuntimeError):
        ds1.concat(ds2, field_mapping={'X': 'x'})
271322

272323
class TestDataSetIter(unittest.TestCase):
273324
def test__repr__(self):

0 commit comments

Comments
 (0)