@@ -743,29 +743,40 @@ def test_get_file_spec(self):
743743 "dummy_dataset_with_configs/plus1/0.0.1/dummy_dataset_with_configs-test.tfrecord@1" ,
744744 )
745745
746- def test_load_as_data_source (self ):
746+ @parameterized .parameters (
747+ (
748+ file_adapters .FileFormat .ARRAY_RECORD ,
749+ array_record .ArrayRecordDataSource ,
750+ ),
751+ )
752+ def test_load_as_data_source (self , file_format , data_source_type ):
747753 data_dir = self .get_temp_dir ()
748754 builder = DummyDatasetWithConfigs (
749755 data_dir = data_dir ,
750756 config = "plus1" ,
751- file_format = file_adapters . FileFormat . ARRAY_RECORD ,
757+ file_format = file_format ,
752758 )
753759 builder .download_and_prepare ()
754760
755761 data_source = builder .as_data_source ()
756762 assert isinstance (data_source , dict )
757- assert isinstance (data_source ["train" ], array_record . ArrayRecordDataSource )
758- assert isinstance (data_source ["test" ], array_record . ArrayRecordDataSource )
763+ assert isinstance (data_source ["train" ], data_source_type )
764+ assert isinstance (data_source ["test" ], data_source_type )
759765 assert len (data_source ["test" ]) == 10
760766 assert data_source ["test" ][0 ]["x" ] == 28
761767 assert len (data_source ["train" ]) == 20
762768 assert data_source ["train" ][0 ]["x" ] == 7
763769
764770 data_source = builder .as_data_source (split = "test" )
765- assert isinstance (data_source , array_record . ArrayRecordDataSource )
771+ assert isinstance (data_source , data_source_type )
766772 assert len (data_source ) == 10
767773 assert data_source [0 ]["x" ] == 28
768774
775+ data_source = builder .as_data_source (split = "all" )
776+ assert isinstance (data_source , data_source_type )
777+ assert len (data_source ) == 30
778+ assert data_source [0 ]["x" ] == 7
779+
769780 def test_load_as_data_source_alternative_file_format (self ):
770781 data_dir = self .get_temp_dir ()
771782 builder = DummyDatasetWithConfigs (
0 commit comments