diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 250442a0ed8..61c5a44d5d4 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -6692,9 +6692,15 @@ def iter_random_indices():
     return concatenated_datasets.select(indices, **kwargs)
 
 
-def _split_by_node_map_style_dataset(dataset: Dataset, rank: int, world_size: int) -> Dataset:
+def _split_by_node_map_style_dataset(
+    dataset: Dataset,
+    rank: int,
+    world_size: int,
+    stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
+) -> Dataset:
     """
-    Split a dataset for the node at rank `rank` in a pool of nodes of size `world_size`.
+    Split a dataset for the node at rank `rank` in a pool of nodes of size `world_size`, with each
+    rank having the same number of examples thanks to the `stopping_strategy`.
 
     Each node is assigned a chunk of data, e.g. rank 0 is given the first chunk of the dataset.
     To maximize data loading throughput, chunks are made of contiguous data on disk if possible.
@@ -6709,7 +6715,15 @@ def _split_by_node_map_style_dataset(dataset: Dataset, rank: int, world_size: in
     Returns:
         [`Dataset`]: The dataset to be used on the node at rank `rank`.
     """
-    return dataset.shard(num_shards=world_size, index=rank, contiguous=True)
+    shard = dataset.shard(num_shards=world_size, index=rank, contiguous=True)
+    # Make sure all the shards have the same number of examples:
+    # - first_exhausted: len() = len(dataset) // world_size
+    # - all_exhausted: len() = len(dataset) // world_size + 1
+    if len(shard) == len(dataset) // world_size + 1 and stopping_strategy == "first_exhausted":
+        shard = shard.select(range(len(dataset) // world_size))
+    if len(shard) == len(dataset) // world_size and stopping_strategy == "all_exhausted":
+        shard = _concatenate_map_style_datasets([shard, shard.select([0])])
+    return shard
 
 
 # This is outside Dataset.filter as it needs to be picklable for multiprocessing
diff --git a/src/datasets/distributed.py b/src/datasets/distributed.py
index 4697948f342..e1172544850 100644
--- a/src/datasets/distributed.py
+++ b/src/datasets/distributed.py
@@ -1,4 +1,4 @@
-from typing import TypeVar
+from typing import Literal, TypeVar
 
 from .arrow_dataset import Dataset, _split_by_node_map_style_dataset
 from .iterable_dataset import IterableDataset, _split_by_node_iterable_dataset
@@ -7,20 +7,35 @@
 DatasetType = TypeVar("DatasetType", Dataset, IterableDataset)
 
 
-def split_dataset_by_node(dataset: DatasetType, rank: int, world_size: int) -> DatasetType:
+def split_dataset_by_node(
+    dataset: DatasetType,
+    rank: int,
+    world_size: int,
+    stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
+) -> DatasetType:
     """
-    Split a dataset for the node at rank `rank` in a pool of nodes of size `world_size`.
+    Split a dataset for the node at rank `rank` in a pool of nodes of size `world_size`, with each
+    rank having the same number of examples thanks to the `stopping_strategy`.
+
+    The stopping strategy allows each node to have the same number of examples:
+
+    - "first_exhausted": stop when the first node runs out of data, and discard the extra data in the other nodes
+    - "all_exhausted": stop when the last node runs out of data, and other nodes may reuse their data to compensate
 
     For map-style datasets:
 
     Each node is assigned a chunk of data, e.g. rank 0 is given the first chunk of the dataset.
     To maximize data loading throughput, chunks are made of contiguous data on disk if possible.
+    This doesn't need communication between nodes, since each node knows how many examples
+    are available and can discard or reuse up to one example accordingly.
 
     For iterable datasets:
 
-    If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.num_shards % world_size == 0`),
-    then the shards are evenly assigned across the nodes, which is the most optimized.
-    Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.
+    The shards are evenly assigned across the nodes.
+    To maximize data loading throughput, each node has its own data and there is no overlap between nodes.
+    The stopping strategy has less impact at the end of training if the dataset has a number of shards that is
+    a factor of `world_size` (i.e. if `dataset.num_shards % world_size == 0`), since each node has roughly
+    the same amount of data available. Nodes communicate using torch distributed to decide when to stop.
 
     Args:
         dataset ([`Dataset`] or [`IterableDataset`]):
@@ -34,6 +49,10 @@ def split_dataset_by_node(dataset: DatasetType, rank: int, world_size: int) -> D
         [`Dataset`] or [`IterableDataset`]: The dataset to be used on the node at rank `rank`.
     """
     if isinstance(dataset, Dataset):
-        return _split_by_node_map_style_dataset(dataset, rank=rank, world_size=world_size)
+        return _split_by_node_map_style_dataset(
+            dataset, rank=rank, world_size=world_size, stopping_strategy=stopping_strategy
+        )
     else:
-        return _split_by_node_iterable_dataset(dataset, rank=rank, world_size=world_size)
+        return _split_by_node_iterable_dataset(
+            dataset, rank=rank, world_size=world_size, stopping_strategy=stopping_strategy
+        )
diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py
index 9ac842d2c22..35a0df9a225 100644
--- a/src/datasets/iterable_dataset.py
+++ b/src/datasets/iterable_dataset.py
@@ -2110,6 +2110,95 @@ def num_shards(self) -> int:
         return self.ex_iterable.num_shards
 
 
+class SyncedDistributedExamplesIterable(_BaseExamplesIterable):
+    def __init__(
+        self,
+        ex_iterable: _BaseExamplesIterable,
+        rank: int,
+        world_size: int,
+        stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
+    ):
+        super().__init__()
+        self.ex_iterable = ex_iterable
+        self.rank = rank
+        self.world_size = world_size
+        self.stopping_strategy = stopping_strategy
+        # if undersampling ("first_exhausted"), we stop as soon as one dataset is exhausted
+        # if oversampling ("all_exhausted"), we stop as soon as every dataset is exhausted, i.e. as soon as every sample of every dataset has been visited at least once
+        import torch
+
+        self.bool_strategy_func = torch.all if stopping_strategy == "all_exhausted" else torch.any
+
+    @property
+    def iter_arrow(self):
+        if self.ex_iterable.iter_arrow:
+            return self._iter_arrow
+
+    @property
+    def is_typed(self):
+        return self.ex_iterable.is_typed
+
+    @property
+    def features(self):
+        return self.ex_iterable.features
+
+    def _init_state_dict(self) -> dict:
+        self._state_dict = self.ex_iterable._init_state_dict()
+        return self._state_dict
+
+    def __iter__(self):
+        import torch
+        import torch.distributed as dist
+
+        is_exhausted = torch.zeros(self.world_size, dtype=torch.bool)
+        while True:
+            for key, example in self.ex_iterable:
+                yield key, example
+                dist.all_reduce(is_exhausted)
+                if self.bool_strategy_func(is_exhausted):
+                    return
+            is_exhausted[self.rank] = True
+            if self._state_dict is not None:
+                self._state_dict = self.ex_iterable._init_state_dict()
+
+    def _iter_arrow(self) -> Iterator[tuple[Key, pa.Table]]:
+        import torch
+        import torch.distributed as dist
+
+        is_exhausted = torch.zeros(self.world_size, dtype=torch.bool)
+        while True:
+            for key, pa_table in self.ex_iterable._iter_arrow():
+                yield key, pa_table
+                dist.all_reduce(is_exhausted)
+                if self.bool_strategy_func(is_exhausted):
+                    return
+            is_exhausted[self.rank] = True
+            if self._state_dict is not None:
+                self._state_dict = self.ex_iterable._init_state_dict()
+
+    def shuffle_data_sources(self, generator: np.random.Generator) -> "SyncedDistributedExamplesIterable":
+        """Shuffle the wrapped examples iterable."""
+        return SyncedDistributedExamplesIterable(
+            self.ex_iterable.shuffle_data_sources(generator),
+            rank=self.rank,
+            world_size=self.world_size,
+            stopping_strategy=self.stopping_strategy,
+        )
+
+    def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "SyncedDistributedExamplesIterable":
+        """Keep only the requested shard."""
+        return SyncedDistributedExamplesIterable(
+            self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous),
+            rank=self.rank,
+            world_size=self.world_size,
+            stopping_strategy=self.stopping_strategy,
+        )
+
+    @property
+    def num_shards(self) -> int:
+        return self.ex_iterable.num_shards
+
+
 @dataclass
 class ShufflingConfig:
     generator: np.random.Generator
@@ -2120,6 +2209,7 @@ class ShufflingConfig:
 class DistributedConfig:
     rank: int
     world_size: int
+    stopping_strategy: Literal["first_exhausted", "all_exhausted"]
 
 
 def _maybe_add_torch_iterable_dataset_parent_class(cls):
@@ -2466,6 +2556,7 @@ def _prepare_ex_iterable_for_iteration(
             self._formatting
             and (ex_iterable.iter_arrow or self._formatting.is_table)
             or (self.features and ex_iterable.features != self.features)
+            or self._distributed
         ):
             ex_iterable = RebatchedArrowExamplesIterable(
                 ex_iterable, batch_size=batch_size, drop_last_batch=drop_last_batch
@@ -2478,25 +2569,21 @@
         if self._distributed:
             rank = self._distributed.rank
             world_size = self._distributed.world_size
-            if ex_iterable.num_shards % world_size == 0:
-                if self._is_main_process():
-                    num_shards_per_node = ex_iterable.num_shards // world_size
-                    plural = "s" if num_shards_per_node > 1 else ""
-                    logger.info(
-                        f"Assigning {num_shards_per_node} shard{plural} (or data source{plural}) of the dataset to each node."
-                    )
-                ex_iterable = ex_iterable.shard_data_sources(num_shards=world_size, index=rank, contiguous=False)
-            else:
-                if self._is_main_process():
-                    logger.info(
-                        f"Assigning 1 out of {world_size} examples of the dataset to each node. The others are skipped during the iteration."
-                    )
-                    logger.info(
-                        f"It is more optimized to distribute the dataset shards (or data sources) across nodes. "
-                        f"You can do that by using a dataset with number of shards that is a factor of world_size={world_size}. "
-                        f"The current dataset has {ex_iterable.num_shards} which is not a factor of {world_size}"
-                    )
-                ex_iterable = StepExamplesIterable(ex_iterable, step=world_size, offset=rank)
+            if self._is_main_process():
+                num_shards_per_node = ex_iterable.num_shards // world_size
+                if ex_iterable.num_shards % world_size != 0:
+                    num_shards_per_node = f"{num_shards_per_node}-{num_shards_per_node + 1}"
+                plural = "s" if str(num_shards_per_node) != "1" else ""
+                logger.info(
+                    f"Assigning {num_shards_per_node} shard{plural} (or data source{plural}) of the dataset to each node."
+                )
+            ex_iterable = ex_iterable.shard_data_sources(num_shards=world_size, index=rank, contiguous=False)
+            ex_iterable = SyncedDistributedExamplesIterable(
+                ex_iterable,
+                rank=self._distributed.rank,
+                world_size=self._distributed.world_size,
+                stopping_strategy=self._distributed.stopping_strategy,
+            )
 
         if self._formatting or (self.features and ex_iterable.features != self.features):
             ex_iterable = FormattedExamplesIterable(
@@ -4662,7 +4749,12 @@ def _interleave_iterable_datasets(
     return IterableDataset(ex_iterable=ex_iterable, info=info, split=split, token_per_repo_id=token_per_repo_id)
 
 
-def _split_by_node_iterable_dataset(dataset: IterableDataset, rank: int, world_size: int) -> IterableDataset:
+def _split_by_node_iterable_dataset(
+    dataset: IterableDataset,
+    rank: int,
+    world_size: int,
+    stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
+) -> IterableDataset:
     """
     Split an iterable dataset for the node at rank `rank` in a pool of nodes of size `world_size`.
 
@@ -4684,7 +4776,7 @@ def _split_by_node_iterable_dataset(dataset: IterableDataset, rank: int, world_s
     if dataset._distributed:
        rank = world_size * dataset._distributed.rank + rank
        world_size = world_size * dataset._distributed.world_size
-    distributed = DistributedConfig(rank=rank, world_size=world_size)
+    distributed = DistributedConfig(rank=rank, world_size=world_size, stopping_strategy=stopping_strategy)
     return IterableDataset(
         ex_iterable=dataset._ex_iterable,
         info=dataset._info.copy(),
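
Below is a minimal sketch of the resulting behavior for map-style datasets, assuming this patch is applied; the toy dataset and world size are illustrative, not part of the change:

# With 10 examples and world_size=4, contiguous shards have 3/3/2/2 examples;
# the stopping strategy makes every rank end up with the same length.
from datasets import Dataset
from datasets.distributed import split_dataset_by_node

ds = Dataset.from_dict({"x": list(range(10))})

for strategy in ("first_exhausted", "all_exhausted"):
    lengths = [
        len(split_dataset_by_node(ds, rank=rank, world_size=4, stopping_strategy=strategy))
        for rank in range(4)
    ]
    # first_exhausted -> [2, 2, 2, 2]: the longer shards drop their extra example
    # all_exhausted   -> [3, 3, 3, 3]: the shorter shards reuse one example
    print(strategy, lengths)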
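
For iterable datasets, the ranks synchronize their "exhausted" flags over torch distributed, so the new argument is meant to be used inside an initialized process group. A hedged usage sketch, assuming a torchrun launcher, the gloo backend, and placeholder data files (none of these are prescribed by the patch):

# Run with e.g. `torchrun --nproc_per_node=2 train.py`; the data_files glob is a placeholder.
import torch.distributed as dist
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

dist.init_process_group(backend="gloo")  # ranks all_reduce their exhaustion flags over this group

streamed = load_dataset("json", data_files="data/*.jsonl", split="train", streaming=True)
streamed = split_dataset_by_node(
    streamed,
    rank=dist.get_rank(),
    world_size=dist.get_world_size(),
    stopping_strategy="all_exhausted",  # ranks that finish early restart their shards until all are exhausted
)

for example in streamed:
    pass  # every rank stops after the same number of examples

dist.destroy_process_group()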