|
| 1 | +import itertools |
1 | 2 | import socket |
2 | 3 | from contextlib import contextmanager |
3 | 4 | from functools import wraps |
4 | | -from typing import Any, Callable, List, Mapping, Optional, Tuple, Union |
| 5 | +from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union |
5 | 6 |
|
6 | 7 | import torch |
7 | 8 |
|
@@ -350,29 +351,80 @@ def all_reduce( |
350 | 351 |     return _model.all_reduce(tensor, op, group=group) |
351 | 352 |
|
352 | 353 |
|
| 354 | +def _all_gather_tensors_with_shapes( |
| 355 | +    tensor: torch.Tensor, shapes: Sequence[Sequence[int]], group: Optional[Union[Any, List[int]]] = None |
| 356 | +) -> List[torch.Tensor]: |
| 357 | +    if _need_to_sync and isinstance(_model, _SerialModel): |
| 358 | +        sync(temporary=True) |
| 359 | + |
| 360 | +    if isinstance(group, list) and all(isinstance(item, int) for item in group): |
| 361 | +        group = _model.new_group(group) |
| 362 | + |
| 363 | +    if isinstance(_model, _SerialModel) or (group is not None and _model.get_rank() not in group): |
| 364 | +        return [tensor] |
| 365 | + |
| 366 | +    max_shape = torch.tensor(shapes).amax(dim=0)  # per-dimension maximum over all ranks' shapes |
| 367 | +    padding_sizes = (max_shape - torch.tensor(tensor.shape)).tolist() |
| 368 | +    # F.pad takes (left, right) pairs starting from the last dimension, hence the reversal. |
| 369 | +    pad = tuple(itertools.chain.from_iterable((0, size) for size in reversed(padding_sizes))) |
| 370 | +    padded_tensor = torch.nn.functional.pad(tensor, pad) |
| 371 | +    all_padded_tensors: torch.Tensor = _model.all_gather(padded_tensor, group=group)  # stacked along dim 0 |
| 372 | +    return [  # slice each rank's original (unpadded) region back out of the gathered tensor |
| 373 | +        all_padded_tensors[ |
| 374 | +            [ |
| 375 | +                slice(rank * max_shape[0] if dim == 0 else 0, rank * max_shape[0] + dim_size if dim == 0 else dim_size) |
| 376 | +                for dim, dim_size in enumerate(shape) |
| 377 | +            ] |
| 378 | +        ] |
| 379 | +        for rank, shape in enumerate(shapes) |
| 380 | +        if group is None or rank in group |
| 381 | +    ] |
| 382 | + |
| 383 | + |
353 | 384 | def all_gather( |
354 | | -    tensor: Union[torch.Tensor, float, str], group: Optional[Union[Any, List[int]]] = None |
355 | | -) -> Union[torch.Tensor, float, List[float], List[str]]: |
| 385 | +    tensor: Union[torch.Tensor, float, str], |
| 386 | +    group: Optional[Union[Any, List[int]]] = None, |
| 387 | +    tensor_different_shape: bool = False, |
| 388 | +) -> Union[torch.Tensor, float, List[float], List[str], List[torch.Tensor]]: |
356 | 389 | """Helper method to perform all gather operation. |
357 | 390 |
|
358 | 391 |     Args: |
359 | | -        tensor: tensor or number or str to collect across participating processes. |
| 392 | +        tensor: tensor or number or str to collect across participating processes. If a tensor, it must have |
| 393 | +            the same number of dimensions across processes. |
360 | 394 |         group: list of integers or the process group for each backend. If None, the default process group will be used. |
| 395 | +        tensor_different_shape: If True, it accounts for differences in ``tensor``'s shape across processes. In this |
| 396 | +            case, it incurs one additional collective operation (to exchange shapes). If False, ``tensor`` must have |
| 397 | +            the same shape across processes. Ignored when ``tensor`` is not a tensor. Default, False. |
| 398 | +
|
361 | 399 |
|
362 | 400 |     Returns: |
363 | | -        torch.Tensor of shape ``(world_size * tensor.shape[0], tensor.shape[1], ...)`` if input is a tensor or |
364 | | -        torch.Tensor of shape ``(world_size, )`` if input is a number or |
365 | | -        List of strings if input is a string |
| 401 | +        If input is a tensor, a torch.Tensor of shape ``(world_size * tensor.shape[0], tensor.shape[1], ...)`` is |
| 402 | +            returned if ``tensor_different_shape=False``; otherwise, a list of tensors of length ``world_size`` |
| 403 | +            (or ``len(group)`` if ``group`` is not None) is returned. If the current process does not belong to |
| 404 | +            ``group``, a list with ``tensor`` as its only item is returned. |
| 405 | +        If input is a number, a torch.Tensor of shape ``(world_size, )`` is returned. |
| 406 | +        If input is a string, a list of strings is returned. |
366 | 407 |
|
367 | 408 |     .. versionchanged:: 0.4.11 |
368 | 409 |         added ``group`` |
| 410 | +
|
| 411 | +    .. versionchanged:: 0.5.1 |
| 412 | +        added ``tensor_different_shape`` |
369 | 413 |     """ |
370 | 414 |     if _need_to_sync and isinstance(_model, _SerialModel): |
371 | 415 |         sync(temporary=True) |
372 | 416 |
|
373 | 417 |     if isinstance(group, list) and all(isinstance(item, int) for item in group): |
374 | 418 |         group = _model.new_group(group) |
375 | 419 |
|
| 420 | +    if isinstance(tensor, torch.Tensor) and tensor_different_shape: |
| 421 | +        if isinstance(_model, _SerialModel) or (group is not None and _model.get_rank() not in group): |
| 422 | +            return [tensor] |
| 423 | +        # Gather every rank's shape first; ranks may differ in shape but not in number of dimensions. |
| 424 | +        all_shapes = _model.all_gather(torch.tensor(tensor.shape), group=group).view(-1, len(tensor.shape)) |
| 425 | +        return _all_gather_tensors_with_shapes(tensor, all_shapes.tolist(), group=group) |
| 426 | + |
| 427 | + |
376 | 428 |     return _model.all_gather(tensor, group=group) |
377 | 429 |
|
378 | 430 |
|
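For intuition, here is a minimal single-process sketch of the pad-then-slice scheme `_all_gather_tensors_with_shapes` implements, simulating two ranks and using `torch.cat` in place of the real `all_gather`; the tensors and shapes are made up for illustration:

```python
import itertools

import torch

# Tensors held by two simulated "ranks": same number of dimensions, different shapes.
tensors = [torch.ones(2, 3), 2 * torch.ones(4, 1)]
shapes = [list(t.shape) for t in tensors]      # [[2, 3], [4, 1]]
max_shape = torch.tensor(shapes).amax(dim=0)   # per-dimension maximum: tensor([4, 3])

# Pad each tensor up to max_shape. F.pad expects (left, right) pairs starting
# from the last dimension, hence the reversal of the per-dimension pad sizes.
padded = []
for t in tensors:
    sizes = (max_shape - torch.tensor(t.shape)).tolist()
    pad = tuple(itertools.chain.from_iterable((0, s) for s in reversed(sizes)))
    padded.append(torch.nn.functional.pad(t, pad))

# A real all_gather concatenates the now equally-shaped tensors along dim 0.
gathered = torch.cat(padded, dim=0)            # shape (2 * 4, 3)

# Slice each rank's original (unpadded) region back out, as the helper does.
stride = int(max_shape[0])
recovered = [
    gathered[rank * stride : rank * stride + shape[0], : shape[1]]
    for rank, shape in enumerate(shapes)
]
assert all(torch.equal(r, t) for r, t in zip(recovered, tensors))
```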
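And a usage sketch of the new flag from the caller's side, assuming the change lands as shown in the diff; the `gloo` backend, the two-process launch, and the per-rank shapes are illustrative assumptions:

```python
import torch

import ignite.distributed as idist


def compute(local_rank: int) -> None:
    # Each process holds a tensor whose first dimension differs, e.g. a
    # variable number of predictions per rank.
    t = torch.full((local_rank + 1, 2), float(local_rank))

    # With tensor_different_shape=True, all_gather exchanges shapes first,
    # then pads, gathers, and slices, returning one tensor per process.
    gathered = idist.all_gather(t, tensor_different_shape=True)
    assert isinstance(gathered, list) and len(gathered) == idist.get_world_size()
    assert [g.shape[0] for g in gathered] == [r + 1 for r in range(idist.get_world_size())]


if __name__ == "__main__":
    with idist.Parallel(backend="gloo", nproc_per_node=2) as parallel:
        parallel.run(compute)
```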