1717
1818import abc
1919import collections
20+ from collections .abc import Iterator
2021import contextlib
2122import functools
2223import importlib
2324import inspect
2425import os .path
25- from typing import ClassVar , Dict , Iterator , List , Type , Text , Tuple
26+ import time
27+ from typing import ClassVar , Type
2628
29+ from absl import logging
2730from etils import epath
2831from tensorflow_datasets .core import constants
2932from tensorflow_datasets .core import naming
3033from tensorflow_datasets .core import visibility
34+ import tensorflow_datasets .core .logging as _tfds_logging
35+ from tensorflow_datasets .core .logging import call_metadata as _call_metadata
3136from tensorflow_datasets .core .utils import py_utils
3237from tensorflow_datasets .core .utils import resource_utils
3338
3843# <str snake_cased_name, abstract DatasetBuilder subclass>
3944_ABSTRACT_DATASET_REGISTRY = {}
4045
41- # Keep track of Dict [str (module name), List [DatasetBuilder]]
46+ # Keep track of dict [str (module name), list [DatasetBuilder]]
4247# This is directly accessed by `tfds.community.builder_cls_from_module` when
4348# importing community packages.
4449_MODULE_TO_DATASETS = collections .defaultdict (list )
5156# <str snake_cased_name, abstract DatasetCollectionBuilder subclass>
5257_ABSTRACT_DATASET_COLLECTION_REGISTRY = {}
5358
54- # Keep track of Dict [str (module name), List [DatasetCollectionBuilder]]
59+ # Keep track of dict [str (module name), list [DatasetCollectionBuilder]]
5560_MODULE_TO_DATASET_COLLECTIONS = collections .defaultdict (list )
5661
5762# eg for dataset "foo": "tensorflow_datasets.datasets.foo.foo_dataset_builder".
@@ -80,6 +85,70 @@ def skip_registration() -> Iterator[None]:
8085 _skip_registration = False
8186
8287
88+ @functools .cache
89+ def _import_legacy_builders () -> None :
90+ """Imports legacy builders."""
91+ modules_to_import = [
92+ 'audio' ,
93+ 'graphs' ,
94+ 'image' ,
95+ 'image_classification' ,
96+ 'object_detection' ,
97+ 'nearest_neighbors' ,
98+ 'question_answering' ,
99+ 'd4rl' ,
100+ 'ranking' ,
101+ 'recommendation' ,
102+ 'rl_unplugged' ,
103+ 'rlds.datasets' ,
104+ 'robotics' ,
105+ 'robomimic' ,
106+ 'structured' ,
107+ 'summarization' ,
108+ 'text' ,
109+ 'text_simplification' ,
110+ 'time_series' ,
111+ 'translate' ,
112+ 'video' ,
113+ 'vision_language' ,
114+ ]
115+
116+ before_dataset_imports = time .time ()
117+ metadata = _call_metadata .CallMetadata ()
118+ metadata .start_time_micros = int (before_dataset_imports * 1e6 )
119+ try :
120+ # For builds that don't include all dataset builders, we don't want to fail
121+ # on import errors of dataset builders.
122+ try :
123+ for module in modules_to_import :
124+ importlib .import_module (f'tensorflow_datasets.{ module } ' )
125+ except (ImportError , ModuleNotFoundError ):
126+ pass
127+
128+ except Exception as exception : # pylint: disable=broad-except
129+ metadata .mark_error ()
130+ logging .exception (exception )
131+ finally :
132+ import_time_ms_dataset_builders = int (
133+ (time .time () - before_dataset_imports ) * 1000
134+ )
135+ metadata .mark_end ()
136+ _tfds_logging .tfds_import (
137+ metadata = metadata ,
138+ import_time_ms_tensorflow = 0 ,
139+ import_time_ms_dataset_builders = import_time_ms_dataset_builders ,
140+ )
141+
142+
143+ @functools .cache
144+ def _import_dataset_collections () -> None :
145+ """Imports dataset collections."""
146+ try :
147+ importlib .import_module ('tensorflow_datasets.dataset_collections' )
148+ except (ImportError , ModuleNotFoundError ):
149+ pass
150+
151+
83152# The implementation of this class follows closely RegisteredDataset.
84153class RegisteredDatasetCollection (abc .ABC ):
85154 """Subclasses will be registered and given a `name` property."""
@@ -129,23 +198,24 @@ def __init_subclass__(cls, skip_registration=False, **kwargs): # pylint: disabl
129198 _DATASET_COLLECTION_REGISTRY [cls .name ] = cls
130199
131200
132- def list_imported_dataset_collections () -> List [str ]:
201+ def list_imported_dataset_collections () -> list [str ]:
133202 """Returns the string names of all `tfds.core.DatasetCollection`s."""
134- all_dataset_collections = [
135- dataset_collection_name
136- for dataset_collection_name , dataset_collection_cls in _DATASET_COLLECTION_REGISTRY .items ()
137- ]
203+ _import_dataset_collections ()
204+ all_dataset_collections = list (_DATASET_COLLECTION_REGISTRY .keys ())
138205 return sorted (all_dataset_collections )
139206
140207
141208def is_dataset_collection (name : str ) -> bool :
209+ _import_dataset_collections ()
142210 return name in _DATASET_COLLECTION_REGISTRY
143211
144212
145213def imported_dataset_collection_cls (
146214 name : str ,
147215) -> Type [RegisteredDatasetCollection ]:
148216 """Returns the Registered dataset class."""
217+ _import_dataset_collections ()
218+
149219 if name in _ABSTRACT_DATASET_COLLECTION_REGISTRY :
150220 raise AssertionError (f'DatasetCollection { name } is an abstract class.' )
151221
@@ -224,8 +294,9 @@ def _is_builder_available(builder_cls: Type[RegisteredDataset]) -> bool:
224294 return visibility .DatasetType .TFDS_PUBLIC .is_available ()
225295
226296
227- def list_imported_builders () -> List [str ]:
297+ def list_imported_builders () -> list [str ]:
228298 """Returns the string names of all `tfds.core.DatasetBuilder`s."""
299+ _import_legacy_builders ()
229300 all_builders = [
230301 builder_name
231302 for builder_name , builder_cls in _DATASET_REGISTRY .items ()
@@ -236,8 +307,8 @@ def list_imported_builders() -> List[str]:
236307
237308@functools .lru_cache (maxsize = None )
238309def _get_existing_dataset_packages (
239- datasets_dir : Text ,
240- ) -> Dict [ Text , Tuple [epath .Path , Text ]]:
310+ datasets_dir : str ,
311+ ) -> dict [ str , tuple [epath .Path , str ]]:
241312 """Returns existing datasets.
242313
243314 Args:
@@ -293,7 +364,12 @@ def imported_builder_cls(name: str) -> Type[RegisteredDataset]:
293364 raise AssertionError (f'Dataset { name } is an abstract class.' )
294365
295366 if name not in _DATASET_REGISTRY :
296- raise DatasetNotFoundError (f'Dataset { name } not found.' )
367+ # Dataset not found in the registry, try to import legacy builders.
368+ # Dataset builders are imported lazily to avoid slowing down the startup
369+ # of the binary.
370+ _import_legacy_builders ()
371+ if name not in _DATASET_REGISTRY :
372+ raise DatasetNotFoundError (f'Dataset { name } not found.' )
297373
298374 builder_cls = _DATASET_REGISTRY [name ]
299375 if not _is_builder_available (builder_cls ):
0 commit comments