@@ -384,7 +384,16 @@ def get_item(dataframe, index):
384384 + '' .join ([random .choice (string .ascii_lowercase ) for _ in range (8 )])
385385 )
386386
387- new_dataframe = pd .concat ([dataset .dataframe for dataset in datasets ])
387+ dataframes = [dataset .dataframe for dataset in datasets ]
388+ for dataframe in dataframes :
389+ for col in dataframe .columns :
390+ if (
391+ dataframe [col ].dtype == int
392+ and any ([col not in other .columns for other in dataframes ])
393+ ):
394+ dataframe [col ] = dataframe [col ].astype (object )
395+
396+ new_dataframe = pd .concat (dataframes )
388397 new_dataframe [dataset_column ] = [
389398 from_concat_mapping (index )[0 ]
390399 for index in range (len (new_dataframe ))
@@ -860,3 +869,19 @@ def test_update_stratified_split():
860869 )
861870
862871 filepath .unlink ()
872+
873+
def test_concat_missing_columns():
    """Check value types after concatenating datasets with disjoint columns.

    The first dataset only has columns ``a`` (ints) and ``b`` (strings); the
    second only has ``c`` (bools) and ``d`` (lists). After ``Dataset.concat``
    the test reads individual cells back and verifies their exact Python
    types, including that ``a`` read from the last row is a ``float``
    (presumably the NaN fill value for the missing column -- confirm against
    ``Dataset.concat``).
    """
    first = Dataset.from_dataframe(
        pd.DataFrame({'a': [1, 2, 3], 'b': ['a', 'b', 'c']})
    )
    second = Dataset.from_dataframe(
        pd.DataFrame({'c': [True, False], 'd': [[1, 2], [3, 4]]})
    )
    concatenated = Dataset.concat([first, second])

    # (row index, column name, exact expected type). Exact type identity is
    # required rather than isinstance, since bool is a subclass of int.
    expectations = (
        (0, 'a', int),
        (-1, 'a', float),
        (0, 'b', str),
        (-1, 'c', bool),
        (-1, 'd', list),
    )
    for row, column, expected in expectations:
        assert type(concatenated[row][column]) is expected
0 commit comments