@@ -108,6 +108,29 @@ def complex_dataframe():
108108 'feat2' : [1 , 2 , 3 , 2 , 3 , 4 ]})
109109
110110
111+ @pytest .fixture
112+ def multiindex_dataframe ():
113+ """Example MultiIndex DataFrame, taken from pandas documentation
114+ """
115+ iterables = [['bar' , 'baz' , 'foo' , 'qux' ], ['one' , 'two' ]]
116+ index = pd .MultiIndex .from_product (iterables , names = ['first' , 'second' ])
117+ df = pd .DataFrame (np .random .randn (10 , 8 ), columns = index )
118+ return df
119+
120+
121+ @pytest .fixture
122+ def multiindex_dataframe_incomplete (multiindex_dataframe ):
123+ """Example MultiIndex DataFrame with missing entries
124+ """
125+ df = multiindex_dataframe
126+ mask_array = np .zeros (df .size )
127+ mask_array [:20 ] = 1
128+ np .random .shuffle (mask_array )
129+ mask = mask_array .reshape (df .shape ).astype (bool )
130+ df .mask (mask , inplace = True )
131+ return df
132+
133+
111134def test_transformed_names_simple (simple_dataframe ):
112135 """
113136 Get transformed names of features in `transformed_names` attribute
@@ -234,6 +257,33 @@ def test_complex_df(complex_dataframe):
234257 assert len (transformed [c ]) == len (df [c ])
235258
236259
260+ def test_numeric_column_names (complex_dataframe ):
261+ """
262+ Get a dataframe from a complex mapped dataframe with numeric column names
263+ """
264+ df = complex_dataframe
265+ df .columns = [0 , 1 , 2 ]
266+ mapper = DataFrameMapper (
267+ [(0 , None ), (1 , None ), (2 , None )], df_out = True )
268+ transformed = mapper .fit_transform (df )
269+ assert len (transformed ) == len (complex_dataframe )
270+ for c in df .columns :
271+ assert len (transformed [c ]) == len (df [c ])
272+
273+
274+ def test_multiindex_df (multiindex_dataframe_incomplete ):
275+ """
276+ Get a dataframe from a multiindex dataframe with missing data
277+ """
278+ df = multiindex_dataframe_incomplete
279+ mapper = DataFrameMapper ([([c ], Imputer ()) for c in df .columns ],
280+ df_out = True )
281+ transformed = mapper .fit_transform (df )
282+ assert len (transformed ) == len (multiindex_dataframe_incomplete )
283+ for c in df .columns :
284+ assert len (transformed [str (c )]) == len (df [c ])
285+
286+
237287def test_binarizer_df ():
238288 """
239289 Check level names from LabelBinarizer
0 commit comments