@@ -83,13 +83,25 @@ class PickableDataSourceMock(mock.MagicMock):
8383 """Makes MagicMock pickable in order to work with multiprocessing in Grain."""
8484
8585 def __getstate__ (self ):
86- return {'num_examples' : len (self ), 'generator' : self ._generator }
86+ return {
87+ 'num_examples' : len (self ),
88+ 'generator' : self ._generator ,
89+ 'serialize_example' : self ._serialize_example ,
90+ }
8791
8892 def __setstate__ (self , state ):
89- num_examples , generator = state ['num_examples' ], state ['generator' ]
93+ num_examples , generator , serialize_example = (
94+ state ['num_examples' ],
95+ state ['generator' ],
96+ state ['serialize_example' ],
97+ )
9098 self .__len__ .return_value = num_examples
91- self .__getitem__ = functools .partial (_getitem , generator = generator )
92- self .__getitems__ = functools .partial (_getitems , generator = generator )
99+ self .__getitem__ = functools .partial (
100+ _getitem , generator = generator , serialize_example = serialize_example
101+ )
102+ self .__getitems__ = functools .partial (
103+ _getitems , generator = generator , serialize_example = serialize_example
104+ )
93105
94106 def __reduce__ (self ):
95107 return (PickableDataSourceMock , (), self .__getstate__ ())
@@ -99,50 +111,33 @@ def _getitem(
99111 self ,
100112 record_key : int ,
101113 generator : RandomFakeGenerator ,
102- serialized : bool = False ,
114+ serialize_example = None ,
103115) -> Any :
104116 """Function to overwrite __getitem__ in data sources."""
117+ del self
105118 example = generator [record_key ]
106- if serialized :
119+ if serialize_example :
107120 # Return serialized raw bytes
108- return self . dataset_info . features . serialize_example (example )
121+ return serialize_example (example )
109122 return example
110123
111124
112125def _getitems (
113126 self ,
114127 record_keys : Sequence [int ],
115128 generator : RandomFakeGenerator ,
116- serialized : bool = False ,
129+ serialize_example = None ,
117130) -> Sequence [Any ]:
118131 """Function to overwrite __getitems__ in data sources."""
119132 items = [
120- _getitem (self , record_key , generator , serialized = serialized )
133+ _getitem (self , record_key , generator , serialize_example = serialize_example )
121134 for record_key in record_keys
122135 ]
123- if serialized :
136+ if serialize_example :
124137 return np .array (items )
125138 return items
126139
127140
128- def _deserialize_example_np (serialized_example , * , decoders = None ):
129- """Function to overwrite dataset_info.features.deserialize_example_np.
130-
131- Warning: this has to be defined in the outer scope in order for the function
132- to be pickable.
133-
134- Args:
135- serialized_example: the example to deserialize.
136- decoders: optional decoders.
137-
138- Returns:
139- The serialized example, because deserialization is taken care by
140- RandomFakeGenerator.
141- """
142- del decoders
143- return serialized_example
144-
145-
146141class MockPolicy (enum .Enum ):
147142 """Strategy to use with `tfds.testing.mock_data` to mock the dataset.
148143
@@ -385,21 +380,27 @@ def mock_as_data_source(self, split, decoders=None, **kwargs):
385380 # Force ARRAY_RECORD as the default file_format.
386381 return_value = file_adapters .FileFormat .ARRAY_RECORD ,
387382 ):
388- self . info . features . deserialize_example_np = _deserialize_example_np
383+ # Make mock_data_source pickable with a given len:
389384 mock_data_source .return_value .__len__ .return_value = num_examples
385+ # Make mock_data_source pickable with a given generator:
390386 mock_data_source .return_value ._generator = ( # pylint:disable=protected-access
391387 generator
392388 )
389+ # Make mock_data_source pickable with a given serialize_example:
390+ mock_data_source .return_value ._serialize_example = ( # pylint:disable=protected-access
391+ self .info .features .serialize_example
392+ )
393+ serialize_example = self .info .features .serialize_example
393394 mock_data_source .return_value .__getitem__ = functools .partial (
394- _getitem , generator = generator
395+ _getitem , generator = generator , serialize_example = serialize_example
395396 )
396397 mock_data_source .return_value .__getitems__ = functools .partial (
397- _getitems , generator = generator
398+ _getitems , generator = generator , serialize_example = serialize_example
398399 )
399400
400401 def build_single_data_source (split ):
401402 single_data_source = array_record .ArrayRecordDataSource (
402- dataset_info = self . info , split = split , decoders = decoders
403+ dataset_builder = self , split = split , decoders = decoders
403404 )
404405 return single_data_source
405406
0 commit comments