1717
1818from __future__ import annotations
1919
20- from collections .abc import Sequence
20+ from collections .abc import Iterable , Iterator , Mapping , Sequence
2121import dataclasses
2222import difflib
23- import json
2423import posixpath
2524import re
2625import textwrap
2726import typing
28- from typing import Any , Callable , Dict , Iterable , Iterator , List , Mapping , Optional , Type
27+ from typing import Any , Callable , Optional , Type
2928
3029from absl import logging
3130from etils import epath
3231from tensorflow_datasets .core import community
33- from tensorflow_datasets .core import constants
3432from tensorflow_datasets .core import dataset_builder
3533from tensorflow_datasets .core import dataset_collection_builder
3634from tensorflow_datasets .core import decode
4038from tensorflow_datasets .core import read_only_builder
4139from tensorflow_datasets .core import registered
4240from tensorflow_datasets .core import splits as splits_lib
43- from tensorflow_datasets .core import utils
4441from tensorflow_datasets .core import visibility
4542from tensorflow_datasets .core .dataset_builders import huggingface_dataset_builder # pylint:disable=unused-import
4643from tensorflow_datasets .core .download import util
4946from tensorflow_datasets .core .utils import py_utils
5047from tensorflow_datasets .core .utils import read_config as read_config_lib
5148from tensorflow_datasets .core .utils import type_utils
52- from tensorflow_datasets .core .utils import version
49+ from tensorflow_datasets .core .utils import version as version_lib
5350from tensorflow_datasets .core .utils .lazy_imports_utils import tensorflow as tf
5451
5552# pylint: disable=logging-format-interpolation
7471def list_builders (
7572 * ,
7673 with_community_datasets : bool = True ,
77- ) -> List [str ]:
74+ ) -> list [str ]:
7875 """Returns the string names of all `tfds.core.DatasetBuilder`s."""
7976 datasets = registered .list_imported_builders ()
8077 if with_community_datasets :
@@ -83,7 +80,7 @@ def list_builders(
8380 return datasets
8481
8582
def list_dataset_collections() -> list[str]:
  """Returns the string names of all `tfds.core.DatasetCollectionBuilder`s.

  Returns:
    The value produced by `registered.list_imported_dataset_collections()`,
    passed through unchanged (no filtering or re-ordering happens here).
  """
  return registered.list_imported_dataset_collections()
@@ -124,7 +121,7 @@ def builder_cls(name: str) -> Type[dataset_builder.DatasetBuilder]:
124121 cls = typing .cast (Type [dataset_builder .DatasetBuilder ], cls )
125122 return cls
126123 except registered .DatasetNotFoundError :
127- _add_list_builders_context (name = ds_name ) # pytype: disable=bad-return-type
124+ _add_list_builders_context (name = ds_name )
128125 raise
129126
130127
@@ -173,6 +170,9 @@ def builder(
173170 name , builder_kwargs = naming .parse_builder_name_kwargs (
174171 name , ** builder_kwargs
175172 )
173+ # Make sure that `data_dir` is not set to an empty string or None.
174+ if 'data_dir' in builder_kwargs and not builder_kwargs ['data_dir' ]:
175+ builder_kwargs .pop ('data_dir' )
176176
177177 def get_dataset_repr () -> str :
178178 return f'dataset "{ name } ", builder_kwargs "{ builder_kwargs } "'
@@ -263,7 +263,7 @@ class DatasetCollectionLoader:
263263
264264 collection : dataset_collection_builder .DatasetCollection
265265 requested_version : Optional [str ] = None
266- loader_kwargs : Optional [ Dict [ str , Any ]] = None
266+ loader_kwargs : dict [ str , Any ] | None = None
267267
268268 def __post_init__ (self ):
269269 self .datasets = self .collection .get_collection (self .requested_version )
@@ -298,14 +298,14 @@ def get_dataset_info(self, dataset_name: str):
298298 )
299299 return info
300300
301- def set_loader_kwargs (self , loader_kwargs : Dict [str , Any ]):
301+ def set_loader_kwargs (self , loader_kwargs : dict [str , Any ]):
302302 self .loader_kwargs = loader_kwargs
303303
304304 def load_dataset (
305305 self ,
306306 dataset : str ,
307307 split : Optional [Tree [splits_lib .SplitArg ]] = None ,
308- loader_kwargs : Optional [ Dict [ str , Any ]] = None ,
308+ loader_kwargs : dict [ str , Any ] | None = None ,
309309 ) -> Mapping [str , tf .data .Dataset ]:
310310 """Loads the named dataset from a dataset collection by calling `tfds.load`.
311311
@@ -388,7 +388,7 @@ def load_datasets(
388388 self ,
389389 datasets : Iterable [str ],
390390 split : Optional [Tree [splits_lib .SplitArg ]] = None ,
391- loader_kwargs : Optional [ Dict [ str , Any ]] = None ,
391+ loader_kwargs : dict [ str , Any ] | None = None ,
392392 ) -> Mapping [str , Mapping [str , tf .data .Dataset ]]:
393393 """Loads a number of datasets from the dataset collection.
394394
@@ -418,7 +418,7 @@ def load_datasets(
418418 def load_all_datasets (
419419 self ,
420420 split : Optional [Tree [splits_lib .SplitArg ]] = None ,
421- loader_kwargs : Optional [ Dict [ str , Any ]] = None ,
421+ loader_kwargs : dict [ str , Any ] | None = None ,
422422 ) -> Mapping [str , Mapping [str , tf .data .Dataset ]]:
423423 """Loads all datasets of a collection.
424424
@@ -440,7 +440,7 @@ def load_all_datasets(
440440@tfds_logging .dataset_collection ()
441441def dataset_collection (
442442 name : str ,
443- loader_kwargs : Optional [Dict [str , Any ]] = None ,
443+ loader_kwargs : Optional [dict [str , Any ]] = None ,
444444) -> DatasetCollectionLoader :
445445 """Instantiates a DatasetCollectionLoader.
446446
@@ -500,7 +500,7 @@ def _fetch_builder(
500500def _download_and_prepare_builder (
501501 dbuilder : dataset_builder .DatasetBuilder ,
502502 download : bool ,
503- download_and_prepare_kwargs : Optional [Dict [str , Any ]],
503+ download_and_prepare_kwargs : Optional [dict [str , Any ]],
504504) -> None :
505505 """Downloads and prepares the dataset builder if necessary."""
506506 if dbuilder .is_prepared ():
@@ -594,7 +594,7 @@ def load(
594594 split: Which split of the data to load (e.g. `'train'`, `'test'`, `['train',
595595 'test']`, `'train[80%:]'`,...). See our [split API
596596 guide](https://www.tensorflow.org/datasets/splits). If `None`, will return
597- all splits in a `Dict [Split, tf.data.Dataset]`
597+ all splits in a `dict [Split, tf.data.Dataset]`
598598 data_dir: directory to read/write data. Defaults to the value of the
599599 environment variable TFDS_DATA_DIR, if set, otherwise falls back to
600600 '~/tensorflow_datasets'.
@@ -776,7 +776,7 @@ def data_source(
776776 split: Which split of the data to load (e.g. `'train'`, `'test'`, `['train',
777777 'test']`, `'train[80%:]'`,...). See our [split API
778778 guide](https://www.tensorflow.org/datasets/splits). If `None`, will return
779- all splits in a `Dict [Split, Sequence]`
779+ all splits in a `dict [Split, Sequence]`
780780 data_dir: directory to read/write data. Defaults to the value of the
781781 environment variable TFDS_DATA_DIR, if set, otherwise falls back to
782782 '~/tensorflow_datasets'.
@@ -832,11 +832,11 @@ def data_source(
832832
833833
834834def _get_all_versions (
835- current_version : version .Version | None ,
836- extra_versions : Iterable [version .Version ],
835+ current_version : version_lib .Version | None ,
836+ extra_versions : Iterable [version_lib .Version ],
837837 current_version_only : bool ,
838- ) -> Iterable [str ]:
839- """Returns the list of all current versions."""
838+ ) -> set [str ]:
839+ """Returns the set of all current versions."""
840840 # Merge current version with all extra versions
841841 version_list = [current_version ] if current_version else []
842842 if not current_version_only :
@@ -881,7 +881,7 @@ def _iter_full_names(current_version_only: bool) -> Iterator[str]:
881881 yield full_name
882882
883883
884- def list_full_names (current_version_only : bool = False ) -> List [str ]:
884+ def list_full_names (current_version_only : bool = False ) -> list [str ]:
885885 """Lists all registered datasets full_names.
886886
887887 Args:
@@ -896,7 +896,7 @@ def list_full_names(current_version_only: bool = False) -> List[str]:
896896def single_full_names (
897897 builder_name : str ,
898898 current_version_only : bool = True ,
899- ) -> List [str ]:
899+ ) -> list [str ]:
900900 """Returns the list `['ds/c0/v0',...]` or `['ds/v']` for a single builder."""
901901 return sorted (
902902 _iter_single_full_names (
0 commit comments