@@ -15,13 +15,15 @@
 
 """Tests for all data sources."""
 
+import pickle
 from unittest import mock
 
+import cloudpickle
 from etils import epath
 import pytest
 import tensorflow_datasets as tfds
 from tensorflow_datasets import testing
-from tensorflow_datasets.core import dataset_builder
+from tensorflow_datasets.core import dataset_builder as dataset_builder_lib
 from tensorflow_datasets.core import dataset_info as dataset_info_lib
 from tensorflow_datasets.core import decode
 from tensorflow_datasets.core import file_adapters
@@ -77,7 +79,7 @@ def mocked_parquet_dataset():
 )
 def test_read_write(
     tmp_path: epath.Path,
-    builder_cls: dataset_builder.DatasetBuilder,
+    builder_cls: dataset_builder_lib.DatasetBuilder,
     file_format: file_adapters.FileFormat,
 ):
   builder = builder_cls(data_dir=tmp_path, file_format=file_format)
@@ -106,28 +108,36 @@ def test_read_write(
 ]
 
 
-def create_dataset_info(file_format: file_adapters.FileFormat):
+def create_dataset_builder(
+    file_format: file_adapters.FileFormat,
+) -> dataset_builder_lib.DatasetBuilder:
   with mock.patch.object(splits_lib, 'SplitInfo') as split_mock:
     split_mock.return_value.name = 'train'
     split_mock.return_value.file_instructions = _FILE_INSTRUCTIONS
     dataset_info = mock.create_autospec(dataset_info_lib.DatasetInfo)
     dataset_info.file_format = file_format
     dataset_info.splits = {'train': split_mock()}
     dataset_info.name = 'dataset_name'
-  return dataset_info
+
+  dataset_builder = mock.create_autospec(dataset_builder_lib.DatasetBuilder)
+  dataset_builder.info = dataset_info
+
+  return dataset_builder
 
 
 @pytest.mark.parametrize(
     'data_source_cls',
     _DATA_SOURCE_CLS,
 )
 def test_missing_split_raises_error(data_source_cls):
-  dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
+  dataset_builder = create_dataset_builder(
+      file_adapters.FileFormat.ARRAY_RECORD
+  )
   with pytest.raises(
       ValueError,
       match="Unknown split 'doesnotexist'.",
   ):
-    data_source_cls(dataset_info, split='doesnotexist')
+    data_source_cls(dataset_builder, split='doesnotexist')
 
 
 @pytest.mark.usefixtures(*_FIXTURES)
@@ -136,8 +146,10 @@ def test_missing_split_raises_error(data_source_cls):
     _DATA_SOURCE_CLS,
 )
 def test_repr_returns_meaningful_string_without_decoders(data_source_cls):
-  dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
-  source = data_source_cls(dataset_info, split='train')
+  dataset_builder = create_dataset_builder(
+      file_adapters.FileFormat.ARRAY_RECORD
+  )
+  source = data_source_cls(dataset_builder, split='train')
   name = data_source_cls.__name__
   assert (
       repr(source) == f"{name}(name=dataset_name, split='train', decoders=None)"
@@ -150,9 +162,11 @@ def test_repr_returns_meaningful_string_without_decoders(data_source_cls):
     _DATA_SOURCE_CLS,
 )
 def test_repr_returns_meaningful_string_with_decoders(data_source_cls):
-  dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
+  dataset_builder = create_dataset_builder(
+      file_adapters.FileFormat.ARRAY_RECORD
+  )
   source = data_source_cls(
-      dataset_info,
+      dataset_builder,
       split='train',
       decoders={'my_feature': decode.SkipDecoding()},
   )
@@ -181,3 +195,18 @@ def test_data_source_is_sliceable():
   file_instructions = mock_array_record_data_source.call_args_list[1].args[0]
   assert file_instructions[0].skip == 0
   assert file_instructions[0].take == 30000
+
+
+# PyGrain requires that data sources are picklable.
+@pytest.mark.parametrize(
+    'file_format',
+    file_adapters.FileFormat.with_random_access(),
+)
+@pytest.mark.parametrize('pickle_module', [pickle, cloudpickle])
+def test_data_source_is_picklable_after_use(file_format, pickle_module):
+  with tfds.testing.tmp_dir() as data_dir:
+    builder = tfds.testing.DummyDataset(data_dir=data_dir)
+    builder.download_and_prepare(file_format=file_format)
+    data_source = builder.as_data_source(split='train')
+    assert data_source[0] == {'id': 0}
+    assert pickle_module.loads(pickle_module.dumps(data_source))[0] == {'id': 0}
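
For context on the new test: PyGrain's multiprocess loading ships the data source to worker processes over a pickle-based task queue, which is why the source must survive a pickle round-trip even after it has been used. Below is a minimal stdlib sketch of that same round-trip, assuming a prepared DummyDataset as in the test above; the helper name `_first_record` is illustrative and not part of TFDS.

import multiprocessing

import tensorflow_datasets as tfds


def _first_record(data_source):
  # Runs in the worker process; `data_source` arrives via pickle.
  return data_source[0]


if __name__ == '__main__':
  with tfds.testing.tmp_dir() as data_dir:
    builder = tfds.testing.DummyDataset(data_dir=data_dir)
    builder.download_and_prepare()
    source = builder.as_data_source(split='train')
    # Pool.apply pickles `source` to send it to the worker, the same
    # requirement PyGrain imposes when worker_count > 0.
    with multiprocessing.Pool(1) as pool:
      assert pool.apply(_first_record, (source,)) == {'id': 0}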