@@ -120,26 +120,25 @@ def _getitems(
120120 _getitem (self , record_key , generator , serialized = serialized )
121121 for record_key in record_keys
122122 ]
123- if serialized :
124- return np .array (items )
125- return items
123+ return np .asarray (items )
126124
127125
128- def _deserialize_example_np (serialized_example , * , decoders = None ):
126+ def _deserialize_example_np (self , serialized_example , * , decoders = None ):
129127 """Function to overwrite dataset_info.features.deserialize_example_np.
130128
131129 Warning: this has to be defined in the outer scope in order for the function
132130 to be picklable .
133131
134132 Args:
133+ self: the dataset builder.
135134 serialized_example: the example to deserialize.
136135 decoders: optional decoders.
137136
138137 Returns:
139138 The serialized example, because deserialization is taken care of by
140139 RandomFakeGenerator.
141140 """
142- del decoders
141+ del self , decoders
143142 return serialized_example
144143
145144
@@ -173,6 +172,7 @@ def mock_data(
173172 as_data_source_fn : Optional [Callable [..., Sequence [Any ]]] = None ,
174173 data_dir : Optional [str ] = None ,
175174 mock_array_record_data_source : Optional [PickableDataSourceMock ] = None ,
175+ use_in_multiprocessing : bool = False ,
176176) -> Iterator [None ]:
177177 """Mock tfds to generate random data.
178178
@@ -262,6 +262,10 @@ def as_dataset(self, *args, **kwargs):
262262 mock_array_record_data_source: Overwrite a mock for the underlying
263263 ArrayRecord data source if it is used. Note: If used the same mock will be
264264 used for all data sources loaded within this context.
265+ use_in_multiprocessing: If True, the mock will use a multiprocessing-safe
266+ approach to generate the data. It's notably useful for PyGrain. The goal
267+ is to migrate the codebase to this mode by default. Find a more detailed
268+ explanation of this parameter in a comment in the code below.
265269
266270 Yields:
267271 None
@@ -361,9 +365,31 @@ def mock_as_data_source(self, split, decoders=None, **kwargs):
361365 if split is None :
362366 split = {s : s for s in self .info .splits }
363367
364- generator_cls , features , _ , _ = _get_fake_data_components (
365- decoders , self .info .features
366- )
368+ features = self .info .features
369+ if use_in_multiprocessing :
370+ # In multiprocessing, we generate serialized data. The data is then
371+ # re-deserialized by the feature as it would normally happen in TFDS. In
372+ # this approach, we don't need to monkey-patch workers to propagate the
373+ # information that deserialize_example_np should be a no-op. Indeed, doing
374+ # so is difficult as PyGrain uses the `spawn` multiprocessing mode. Users
375+ # of tfds.testing.mock_data in the codebase started relying on the
376+ # function not serializing (for example, they don't have TensorFlow in
377+ # their dependency), so we cannot have use_in_multiprocessing by default.
378+ # ┌─────────────┐
379+ # │ Main process│
380+ # └─┬──────┬────┘
381+ # ┌───────▼─┐ ┌─▼───────┐
382+ # │ worker1 │ │ worker2 │ ...
383+ # └───────┬─┘ └─┬───────┘
384+ # serialized data by the generator
385+ # ┌───────▼─┐ ┌─▼───────┐
386+ # │ tfds 1 │ │ tfds 2 │ ...
387+ # └───────┬─┘ └─┬───────┘
388+ # deserialized data
389+ generator_cls = SerializedRandomFakeGenerator
390+ else :
391+ # We generate already deserialized data with the generator.
392+ generator_cls , _ , _ , _ = _get_fake_data_components (decoders , features )
367393 generator = generator_cls (features , num_examples )
368394
369395 if actual_policy == MockPolicy .USE_CODE :
@@ -385,7 +411,6 @@ def mock_as_data_source(self, split, decoders=None, **kwargs):
385411 # Force ARRAY_RECORD as the default file_format.
386412 return_value = file_adapters .FileFormat .ARRAY_RECORD ,
387413 ):
388- self .info .features .deserialize_example_np = _deserialize_example_np
389414 mock_data_source .return_value .__len__ .return_value = num_examples
390415 mock_data_source .return_value ._generator = ( # pylint:disable=protected-access
391416 generator
@@ -399,7 +424,7 @@ def mock_as_data_source(self, split, decoders=None, **kwargs):
399424
400425 def build_single_data_source (split ):
401426 single_data_source = array_record .ArrayRecordDataSource (
402- dataset_info = self . info , split = split , decoders = decoders
427+ dataset_builder = self , split = split , decoders = decoders
403428 )
404429 return single_data_source
405430
@@ -463,6 +488,10 @@ def new_builder_from_files(*args, **kwargs):
463488 f'{ core } .dataset_builder.FileReaderBuilder._as_dataset' ,
464489 as_dataset_fn ,
465490 ),
491+ (
492+ f'{ core } .features.top_level_feature.TopLevelFeature.deserialize_example_np' ,
493+ _deserialize_example_np ,
494+ ),
466495 ]:
467496 stack .enter_context (mock .patch (path , mocked_fn ))
468497 yield
0 commit comments