@@ -413,15 +413,28 @@ def concat(datasets: List[Dataset]) -> Dataset[R]:
413413 '''
414414 from_concat_mapping = Dataset .create_from_concat_mapping (datasets )
415415
416- def get_item (dataframe , index ):
417- dataset_index , inner_index = from_concat_mapping (index )
418- return datasets [dataset_index ][inner_index ]
416+ if any ([dataset .dataframe is None for dataset in datasets ]):
419417
420- return Dataset (
421- dataframe = None , # TODO: concat dataframes?
422- length = sum (map (len , datasets )),
423- get_item = get_item ,
424- )
418+ def get_item (dataframe , index ):
419+ dataset_index , inner_index = from_concat_mapping (index )
420+ return datasets [dataset_index ][inner_index ]
421+
422+ return Dataset (
423+ dataframe = None ,
424+ length = sum (map (len , datasets )),
425+ get_item = get_item ,
426+ )
427+ else :
428+
429+ def get_item (dataframe , index ):
430+ dataset_index , _ = from_concat_mapping (index )
431+ return datasets [dataset_index ].get_item (dataframe , index )
432+
433+ return Dataset (
434+ dataframe = pd .concat ([dataset .dataframe for dataset in datasets ]),
435+ length = sum (map (len , datasets )),
436+ get_item = get_item ,
437+ )
425438
426439 @staticmethod
427440 def create_from_combine_mapping (datasets ):
@@ -600,6 +613,28 @@ def test_concat_dataset():
600613 assert dataset [6 ] == 1
601614
602615
def test_concat_heterogenous_datasets():
    '''
    Concatenate two dataframe-backed datasets whose columns differ.

    Two single-row datasets are built from dataframes indexed by column
    'a'; only the second one has a column 'c'. The test checks both
    orders of composition:

    * mapping after concatenation, where one function reads column 'b'
      from rows of either source, and
    * mapping before concatenation, where each source carries its own
      function ('b' for the first, 'c' for the second).
    '''
    two_column_frame = pd.DataFrame({'a': [1], 'b': ['a']}).set_index('a')
    three_column_frame = pd.DataFrame(
        {'a': [1], 'b': [1], 'c': [2]}
    ).set_index('a')

    first = Dataset.from_dataframe(two_column_frame)
    second = Dataset.from_dataframe(three_column_frame)

    # Map after concat: the same function is applied to every row.
    concatenated = Dataset.concat([first, second])
    mapped_after_concat = concatenated.map(lambda row: row['b'])
    assert list(mapped_after_concat) == ['a', 1]

    # Map before concat: each source keeps its own function.
    first_b = first.map(lambda row: row['b'])
    second_c = second.map(lambda row: row['c'])
    mapped_before_concat = Dataset.concat([first_b, second_c])
    assert list(mapped_before_concat) == ['a', 2]
636+
637+
603638def test_zip_dataset ():
604639 dataset = Dataset .zip ([
605640 Dataset .from_subscriptable (list (range (5 ))),
0 commit comments