20 changes: 17 additions & 3 deletions src/datasets/arrow_dataset.py
@@ -6692,9 +6692,15 @@ def iter_random_indices():
return concatenated_datasets.select(indices, **kwargs)


def _split_by_node_map_style_dataset(dataset: Dataset, rank: int, world_size: int) -> Dataset:
def _split_by_node_map_style_dataset(
dataset: Dataset,
rank: int,
world_size: int,
stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
) -> Dataset:
"""
Split a dataset for the node at rank `rank` in a pool of nodes of size `world_size`.
Split a dataset for the node at rank `rank` in a pool of nodes of size `world_size`, with each
rank having the same number of examples thanks to the `stopping_strategy`.
Each node is assigned a chunk of data, e.g. rank 0 is given the first chunk of the dataset.
To maximize data loading throughput, chunks are made of contiguous data on disk if possible.

@@ -6709,7 +6715,15 @@ def _split_by_node_map_style_dataset(dataset: Dataset, rank: int, world_size: in
Returns:
[`Dataset`]: The dataset to be used on the node at rank `rank`.
"""
return dataset.shard(num_shards=world_size, index=rank, contiguous=True)
shard = dataset.shard(num_shards=world_size, index=rank, contiguous=True)
# Make sure all the shards have the same number of examples:
# - first_exhausted: len() = len(dataset) // world_size
# - all_exhausted: len() = len(dataset) // world_size + 1
if len(shard) == len(dataset) // world_size + 1 and stopping_strategy == "first_exhausted":
shard = shard.select(range(len(dataset) // world_size))
if len(shard) == len(dataset) // world_size and stopping_strategy == "all_exhausted":
shard = _concatenate_map_style_datasets([shard, shard.select([0])])
return shard


# This is outside Dataset.filter as it needs to be picklable for multiprocessing
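The trimming/padding above enforces a simple length invariant per rank. Below is a minimal sketch (not part of the diff, plain Python only) of the target shard length for each strategy, with `dataset_len` and `world_size` as hypothetical inputs:

```python
# Sketch of the per-rank length that _split_by_node_map_style_dataset converges to.
# first_exhausted trims longer contiguous shards; all_exhausted pads shorter ones by
# reusing an example, so every rank ends up with the same count.
def expected_shard_len(dataset_len: int, world_size: int, stopping_strategy: str) -> int:
    if stopping_strategy == "first_exhausted":
        return dataset_len // world_size        # up to one example per rank is discarded
    return dataset_len // world_size + 1        # up to one example per rank is reused

assert expected_shard_len(10, 3, "first_exhausted") == 3
assert expected_shard_len(10, 3, "all_exhausted") == 4
```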
35 changes: 27 additions & 8 deletions src/datasets/distributed.py
@@ -1,4 +1,4 @@
from typing import TypeVar
from typing import Literal, TypeVar

from .arrow_dataset import Dataset, _split_by_node_map_style_dataset
from .iterable_dataset import IterableDataset, _split_by_node_iterable_dataset
@@ -7,20 +7,35 @@
DatasetType = TypeVar("DatasetType", Dataset, IterableDataset)


def split_dataset_by_node(dataset: DatasetType, rank: int, world_size: int) -> DatasetType:
def split_dataset_by_node(
dataset: DatasetType,
rank: int,
world_size: int,
stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
) -> DatasetType:
"""
Split a dataset for the node at rank `rank` in a pool of nodes of size `world_size`.
Split a dataset for the node at rank `rank` in a pool of nodes of size `world_size`, with each
rank having the same number of examples thanks to the `stopping_strategy`.

The stopping strategy allows each node to have the same number of examples:

- "first_exhausted": stop when the first node runs of of data, and discard the extra data in the other nodes
- "all_exhausted": stop when the last node runs out of data, and other nodes may reuse their data to compensate

For map-style datasets:

Each node is assigned a chunk of data, e.g. rank 0 is given the first chunk of the dataset.
To maximize data loading throughput, chunks are made of contiguous data on disk if possible.
This doesn't need communication between nodes, since each node knows how many examples
are available and can discard or reuse up to one example accordingly.

For iterable datasets:

If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.num_shards % world_size == 0`),
then the shards are evenly assigned across the nodes, which is the most optimized.
Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.
The shards are evenly assigned across the nodes.
To maximize data loading throughput, each node has its own data and there is no overlap between nodes.
The stopping strategy has less impact at the end of training if the dataset has a number of shards that is
a multiple of `world_size` (i.e. if `dataset.num_shards % world_size == 0`), since each node has roughly
the same amount of data available. Nodes communicate using torch distributed to decide when to stop.

Args:
dataset ([`Dataset`] or [`IterableDataset`]):
@@ -34,6 +49,10 @@ def split_dataset_by_node(dataset: DatasetType, rank: int, world_size: int) -> D
[`Dataset`] or [`IterableDataset`]: The dataset to be used on the node at rank `rank`.
"""
if isinstance(dataset, Dataset):
return _split_by_node_map_style_dataset(dataset, rank=rank, world_size=world_size)
return _split_by_node_map_style_dataset(
dataset, rank=rank, world_size=world_size, stopping_strategy=stopping_strategy
)
else:
return _split_by_node_iterable_dataset(dataset, rank=rank, world_size=world_size)
return _split_by_node_iterable_dataset(
dataset, rank=rank, world_size=world_size, stopping_strategy=stopping_strategy
)
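A hedged usage sketch of the extended `split_dataset_by_node` signature on a map-style dataset; the toy dataset and hard-coded ranks are illustrative (in practice `rank` comes from the launcher, e.g. torch.distributed):

```python
from datasets import Dataset
from datasets.distributed import split_dataset_by_node

# Toy dataset of 10 examples split over 4 ranks: the contiguous shards have
# 3, 3, 2, 2 examples, so the stopping strategy equalizes them.
ds = Dataset.from_dict({"x": list(range(10))})
shards = [
    split_dataset_by_node(ds, rank=r, world_size=4, stopping_strategy="all_exhausted")
    for r in range(4)
]
print([len(s) for s in shards])  # [3, 3, 3, 3]  (10 // 4 + 1 each)
# With stopping_strategy="first_exhausted" each rank would instead keep 10 // 4 == 2 examples.
```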
134 changes: 113 additions & 21 deletions src/datasets/iterable_dataset.py
@@ -2110,6 +2110,95 @@ def num_shards(self) -> int:
return self.ex_iterable.num_shards


class SyncedDistributedExamplesIterable(_BaseExamplesIterable):
def __init__(
self,
ex_iterable: _BaseExamplesIterable,
rank: int,
world_size: int,
stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
):
super().__init__()
self.ex_iterable = ex_iterable
self.rank = rank
self.world_size = world_size
self.stopping_strategy = stopping_strategy
# if undersampling ("first_exhausted"), we stop as soon as one node is exhausted
# if oversampling ("all_exhausted"), we stop as soon as every node is exhausted, i.e. as soon as every sample of every node has been visited at least once
import torch

self.bool_strategy_func = torch.all if stopping_strategy == "all_exhausted" else torch.any

@property
def iter_arrow(self):
if self.ex_iterable.iter_arrow:
return self._iter_arrow

@property
def is_typed(self):
return self.ex_iterable.is_typed

@property
def features(self):
return self.ex_iterable.features

def _init_state_dict(self) -> dict:
self._state_dict = self.ex_iterable._init_state_dict()
return self._state_dict

def __iter__(self):
import torch
import torch.distributed as dist

is_exhausted = torch.zeros(self.world_size, dtype=torch.bool)
while True:
for key, example in self.ex_iterable:
yield key, example
dist.all_reduce(is_exhausted)
if self.bool_strategy_func(is_exhausted):
return
is_exhausted[self.rank] = True
Comment on lines +2156 to +2160

Reviewer: I'm wondering if this doesn't yield one sample more if the dataset is exhausted.

Member Author: good point, let me see

if self._state_dict is not None:
self._state_dict = self.ex_iterable._init_state_dict()

def _iter_arrow(self) -> Iterator[tuple[Key, pa.Table]]:
import torch
import torch.distributed as dist

is_exhausted = torch.zeros(self.world_size, dtype=torch.bool)
while True:
for key, pa_table in self.ex_iterable._iter_arrow():
yield key, pa_table
dist.all_reduce(is_exhausted)
if self.bool_strategy_func(is_exhausted):
return
is_exhausted[self.rank] = True
if self._state_dict is not None:
self._state_dict = self.ex_iterable._init_state_dict()

def shuffle_data_sources(self, generator: np.random.Generator) -> "SyncedDistributedExamplesIterable":
"""Shuffle the wrapped examples iterable."""
return SyncedDistributedExamplesIterable(
self.ex_iterable.shuffle_data_sources(generator),
rank=self.rank,
world_size=self.world_size,
stopping_strategy=self.stopping_strategy,
)

def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "SyncedDistributedExamplesIterable":
"""Keep only the requested shard."""
return SyncedDistributedExamplesIterable(
self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous),
rank=self.rank,
world_size=self.world_size,
stopping_strategy=self.stopping_strategy,
)

@property
def num_shards(self) -> int:
return self.ex_iterable.num_shards


@dataclass
class ShufflingConfig:
generator: np.random.Generator
@@ -2120,6 +2209,7 @@ class ShufflingConfig:
class DistributedConfig:
rank: int
world_size: int
stopping_strategy: Literal["first_exhausted", "all_exhausted"]


def _maybe_add_torch_iterable_dataset_parent_class(cls):
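The `SyncedDistributedExamplesIterable` class above coordinates the stopping decision with `torch.distributed`. Below is a hedged sketch of just that communication primitive, not the full iterator loop; it assumes a process group is already initialized (e.g. via `torchrun`) and that the backend accepts a boolean all-reduce, as the code in the diff does:

```python
import torch
import torch.distributed as dist

def should_stop(locally_exhausted: bool, rank: int, world_size: int,
                stopping_strategy: str = "first_exhausted") -> bool:
    # Each rank contributes its own "ran out of data" flag at its index.
    flags = torch.zeros(world_size, dtype=torch.bool)
    flags[rank] = locally_exhausted
    # Summing boolean flags across ranks leaves True wherever any rank set its flag.
    dist.all_reduce(flags)
    # first_exhausted -> stop if any rank is exhausted; all_exhausted -> stop only when all are.
    reduce_fn = torch.all if stopping_strategy == "all_exhausted" else torch.any
    return bool(reduce_fn(flags))
```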
@@ -2466,6 +2556,7 @@ def _prepare_ex_iterable_for_iteration(
self._formatting
and (ex_iterable.iter_arrow or self._formatting.is_table)
or (self.features and ex_iterable.features != self.features)
or self._distributed
):
ex_iterable = RebatchedArrowExamplesIterable(
ex_iterable, batch_size=batch_size, drop_last_batch=drop_last_batch
@@ -2478,25 +2569,21 @@ def _prepare_ex_iterable_for_iteration(
if self._distributed:
rank = self._distributed.rank
world_size = self._distributed.world_size
if ex_iterable.num_shards % world_size == 0:
if self._is_main_process():
num_shards_per_node = ex_iterable.num_shards // world_size
plural = "s" if num_shards_per_node > 1 else ""
logger.info(
f"Assigning {num_shards_per_node} shard{plural} (or data source{plural}) of the dataset to each node."
)
ex_iterable = ex_iterable.shard_data_sources(num_shards=world_size, index=rank, contiguous=False)
else:
if self._is_main_process():
logger.info(
f"Assigning 1 out of {world_size} examples of the dataset to each node. The others are skipped during the iteration."
)
logger.info(
f"It is more optimized to distribute the dataset shards (or data sources) across nodes. "
f"You can do that by using a dataset with number of shards that is a factor of world_size={world_size}. "
f"The current dataset has {ex_iterable.num_shards} which is not a factor of {world_size}"
)
ex_iterable = StepExamplesIterable(ex_iterable, step=world_size, offset=rank)
if self._is_main_process():
num_shards_per_node = ex_iterable.num_shards // world_size
if ex_iterable.num_shards % world_size != 0:
num_shards_per_node = f"{num_shards_per_node}-{num_shards_per_node + 1}"
plural = "s" if str(num_shards_per_node) != "1" else ""
logger.info(
f"Assigning {num_shards_per_node} shard{plural} (or data source{plural}) of the dataset to each node."
)
ex_iterable = ex_iterable.shard_data_sources(num_shards=world_size, index=rank, contiguous=False)
ex_iterable = SyncedDistributedExamplesIterable(
ex_iterable,
rank=self._distributed.rank,
world_size=self._distributed.world_size,
stopping_strategy=self._distributed.stopping_strategy,
)

if self._formatting or (self.features and ex_iterable.features != self.features):
ex_iterable = FormattedExamplesIterable(
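With the new code path, every distributed iterable dataset gets a per-node shard assignment plus the synced wrapper. As a rough illustration (assumption: `shard_data_sources(..., contiguous=False)` deals shards out round-robin; the exact order is an implementation detail), 10 shards over 4 nodes could be assigned as below, which is also why the log message above reports a `2-3` range when `num_shards % world_size != 0`:

```python
# Hypothetical round-robin assignment of 10 shards across 4 nodes.
num_shards, world_size = 10, 4
assignment = {rank: list(range(rank, num_shards, world_size)) for rank in range(world_size)}
print(assignment)  # {0: [0, 4, 8], 1: [1, 5, 9], 2: [2, 6], 3: [3, 7]}
# Each node holds 10 // 4 = 2 or 3 shards, and SyncedDistributedExamplesIterable
# handles the resulting imbalance at iteration time.
```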
@@ -4662,7 +4749,12 @@ def _interleave_iterable_datasets(
return IterableDataset(ex_iterable=ex_iterable, info=info, split=split, token_per_repo_id=token_per_repo_id)


def _split_by_node_iterable_dataset(dataset: IterableDataset, rank: int, world_size: int) -> IterableDataset:
def _split_by_node_iterable_dataset(
dataset: IterableDataset,
rank: int,
world_size: int,
stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
) -> IterableDataset:
"""
Split an iterable dataset for the node at rank `rank` in a pool of nodes of size `world_size`.

@@ -4684,7 +4776,7 @@ def _split_by_node_iterable_dataset(dataset: IterableDataset, rank: int, world_s
if dataset._distributed:
rank = world_size * dataset._distributed.rank + rank
world_size = world_size * dataset._distributed.world_size
distributed = DistributedConfig(rank=rank, world_size=world_size)
distributed = DistributedConfig(rank=rank, world_size=world_size, stopping_strategy=stopping_strategy)
return IterableDataset(
ex_iterable=dataset._ex_iterable,
info=dataset._info.copy(),
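The `rank`/`world_size` arithmetic a few lines above composes nested splits into a single flat split. A small worked sketch of that arithmetic (values are illustrative):

```python
# A dataset already split with (rank=1, world_size=2) is split again with
# (rank=0, world_size=3); the combined config behaves like one 6-way split.
outer_rank, outer_world_size = 1, 2   # existing DistributedConfig on the dataset
inner_rank, inner_world_size = 0, 3   # arguments of the new split_dataset_by_node call
combined_rank = inner_world_size * outer_rank + inner_rank   # 3 * 1 + 0 == 3
combined_world_size = inner_world_size * outer_world_size    # 3 * 2 == 6
assert (combined_rank, combined_world_size) == (3, 6)
```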