 from etils import epath
 from tensorflow_datasets.core import example_parser
 from tensorflow_datasets.core import example_serializer
+from tensorflow_datasets.core import features as features_lib
 from tensorflow_datasets.core import file_adapters
 from tensorflow_datasets.core import hashing
 from tensorflow_datasets.core import naming
@@ -264,27 +265,34 @@ class ShardWriter:
 
   def __init__(
       self,
+      features: features_lib.FeatureConnector,
       serializer: example_serializer.Serializer,
       example_writer: ExampleWriter,
   ):
     """Initializes Writer.
 
     Args:
+      features: the features of the dataset.
       serializer: class that can serialize examples.
       example_writer: class that writes examples to disk or elsewhere.
     """
+    self._features = features
     self._serializer = serializer
     self._example_writer = example_writer
 
+  def _serialize_example(self, example: Example) -> Any:
+    """Encodes and serializes an example."""
+    return self._serializer.serialize_example(
+        self._features.encode_example(example)
+    )
+
   def write(
       self,
       examples: Iterable[KeyExample],
       path: epath.Path,
   ) -> int:
     """Returns the number of examples written to the given path."""
-    serialized_examples = [
-        (k, self._serializer.serialize_example(v)) for k, v in examples
-    ]
+    serialized_examples = [(k, self._serialize_example(v)) for k, v in examples]
     self._example_writer.write(path=path, examples=serialized_examples)
 
     return len(serialized_examples)
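For context, a minimal usage sketch of the updated ShardWriter, not taken from this diff: the feature spec, the ExampleSerializer/ExampleWriter construction, the key format, and the output path are assumptions for illustration, and ShardWriter/ExampleWriter are assumed importable from tensorflow_datasets.core.writer.

# Hypothetical usage sketch; everything below except the ShardWriter signature
# shown in the diff is an assumption.
from etils import epath
from tensorflow_datasets.core import example_serializer
from tensorflow_datasets.core import features as features_lib
from tensorflow_datasets.core import file_adapters
from tensorflow_datasets.core import writer as writer_lib

# Illustrative feature spec for a toy text dataset.
features = features_lib.FeaturesDict({'text': features_lib.Text()})

shard_writer = writer_lib.ShardWriter(
    features=features,
    serializer=example_serializer.ExampleSerializer(
        features.get_serialized_info()
    ),
    # ExampleWriter constructor arguments are assumed here.
    example_writer=writer_lib.ExampleWriter(
        file_format=file_adapters.FileFormat.TFRECORD
    ),
)

# With this change, write() encodes each raw example via the features before
# serializing it, so the caller passes plain Python dicts. Key handling is
# simplified for the sketch.
num_examples = shard_writer.write(
    examples=[('key-0', {'text': 'hello world'})],
    path=epath.Path('/tmp/my_dataset-train.tfrecord-00000-of-00001'),
)

The design effect of the diff itself is that encoding (features.encode_example) now happens inside ShardWriter right next to serialization, so callers no longer need to pre-encode examples before handing them to write().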
|