@@ -83,25 +83,13 @@ class PickableDataSourceMock(mock.MagicMock):
8383 """Makes MagicMock pickable in order to work with multiprocessing in Grain."""
8484
8585 def __getstate__ (self ):
86- return {
87- 'num_examples' : len (self ),
88- 'generator' : self ._generator ,
89- 'serialize_example' : self ._serialize_example ,
90- }
86+ return {'num_examples' : len (self ), 'generator' : self ._generator }
9187
9288 def __setstate__ (self , state ):
93- num_examples , generator , serialize_example = (
94- state ['num_examples' ],
95- state ['generator' ],
96- state ['serialize_example' ],
97- )
89+ num_examples , generator = state ['num_examples' ], state ['generator' ]
9890 self .__len__ .return_value = num_examples
99- self .__getitem__ = functools .partial (
100- _getitem , generator = generator , serialize_example = serialize_example
101- )
102- self .__getitems__ = functools .partial (
103- _getitems , generator = generator , serialize_example = serialize_example
104- )
91+ self .__getitem__ = functools .partial (_getitem , generator = generator )
92+ self .__getitems__ = functools .partial (_getitems , generator = generator )
10593
10694 def __reduce__ (self ):
10795 return (PickableDataSourceMock , (), self .__getstate__ ())
@@ -111,33 +99,50 @@ def _getitem(
11199 self ,
112100 record_key : int ,
113101 generator : RandomFakeGenerator ,
114- serialize_example = None ,
102+ serialized : bool = False ,
115103) -> Any :
116104 """Function to overwrite __getitem__ in data sources."""
117- del self
118105 example = generator [record_key ]
119- if serialize_example :
106+ if serialized :
120107 # Return serialized raw bytes
121- return serialize_example (example )
108+ return self . dataset_info . features . serialize_example (example )
122109 return example
123110
124111
125112def _getitems (
126113 self ,
127114 record_keys : Sequence [int ],
128115 generator : RandomFakeGenerator ,
129- serialize_example = None ,
116+ serialized : bool = False ,
130117) -> Sequence [Any ]:
131118 """Function to overwrite __getitems__ in data sources."""
132119 items = [
133- _getitem (self , record_key , generator , serialize_example = serialize_example )
120+ _getitem (self , record_key , generator , serialized = serialized )
134121 for record_key in record_keys
135122 ]
136- if serialize_example :
123+ if serialized :
137124 return np .array (items )
138125 return items
139126
140127
128+ def _deserialize_example_np (serialized_example , * , decoders = None ):
129+ """Function to overwrite dataset_info.features.deserialize_example_np.
130+
131+ Warning: this has to be defined at module level in order for the function
132+ to be picklable.
133+
134+ Args:
135+ serialized_example: the example to deserialize.
136+ decoders: optional decoders.
137+
138+ Returns:
139+ The serialized example unchanged, because deserialization is taken care of
140+ by RandomFakeGenerator.
141+ """
142+ del decoders
143+ return serialized_example
144+
145+
141146class MockPolicy (enum .Enum ):
142147 """Strategy to use with `tfds.testing.mock_data` to mock the dataset.
143148
@@ -380,27 +385,21 @@ def mock_as_data_source(self, split, decoders=None, **kwargs):
380385 # Force ARRAY_RECORD as the default file_format.
381386 return_value = file_adapters .FileFormat .ARRAY_RECORD ,
382387 ):
383- # Make mock_data_source pickable with a given len:
388+ self . info . features . deserialize_example_np = _deserialize_example_np
384389 mock_data_source .return_value .__len__ .return_value = num_examples
385- # Make mock_data_source pickable with a given generator:
386390 mock_data_source .return_value ._generator = ( # pylint:disable=protected-access
387391 generator
388392 )
389- # Make mock_data_source pickable with a given serialize_example:
390- mock_data_source .return_value ._serialize_example = ( # pylint:disable=protected-access
391- self .info .features .serialize_example
392- )
393- serialize_example = self .info .features .serialize_example
394393 mock_data_source .return_value .__getitem__ = functools .partial (
395- _getitem , generator = generator , serialize_example = serialize_example
394+ _getitem , generator = generator
396395 )
397396 mock_data_source .return_value .__getitems__ = functools .partial (
398- _getitems , generator = generator , serialize_example = serialize_example
397+ _getitems , generator = generator
399398 )
400399
401400 def build_single_data_source (split ):
402401 single_data_source = array_record .ArrayRecordDataSource (
403- dataset_builder = self , split = split , decoders = decoders
402+ dataset_info = self . info , split = split , decoders = decoders
404403 )
405404 return single_data_source
406405
0 commit comments