Skip to content

Commit 34a1a3f

Browse files
Keep only cocomap-related changes
i.e. ObjectDetectionMap and its dependencies
1 parent f61364f commit 34a1a3f

File tree

11 files changed

+2347
-13
lines changed

11 files changed

+2347
-13
lines changed

docs/source/metrics.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,11 +329,13 @@ Complete list of metrics
329329
Frequency
330330
Loss
331331
MeanAbsoluteError
332+
MeanAveragePrecision
332333
MeanPairwiseDistance
333334
MeanSquaredError
334335
metric.Metric
335336
metrics_lambda.MetricsLambda
336337
MultiLabelConfusionMatrix
338+
ObjectDetectionMAP
337339
precision.Precision
338340
PSNR
339341
recall.Recall

ignite/distributed/utils.py

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
import itertools
12
import socket
23
from contextlib import contextmanager
34
from functools import wraps
4-
from typing import Any, Callable, List, Mapping, Optional, Tuple, Union
5+
from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union
56

67
import torch
78

@@ -350,29 +351,80 @@ def all_reduce(
350351
return _model.all_reduce(tensor, op, group=group)
351352

352353

354+
def _all_gather_tensors_with_shapes(
355+
tensor: torch.Tensor, shapes: Sequence[Sequence[int]], group: Optional[Union[Any, List[int]]] = None
356+
) -> List[torch.Tensor]:
357+
if _need_to_sync and isinstance(_model, _SerialModel):
358+
sync(temporary=True)
359+
360+
if isinstance(group, list) and all(isinstance(item, int) for item in group):
361+
group = _model.new_group(group)
362+
363+
if isinstance(_model, _SerialModel) or (group is not None and _model.get_rank() not in group):
364+
return [tensor]
365+
366+
max_shape = torch.tensor(shapes).amax(dim=1)
367+
padding_sizes = (max_shape - torch.tensor(tensor.shape)).tolist()
368+
padded_tensor = torch.nn.functional.pad(
369+
tensor, tuple(itertools.chain.from_iterable(map(lambda dim_size: (0, dim_size), reversed(padding_sizes))))
370+
)
371+
all_padded_tensors: torch.Tensor = _model.all_gather(padded_tensor, group=group) # .split(max_shape[0], dim=0)
372+
return [
373+
all_padded_tensors[
374+
[
375+
slice(rank * max_shape[0] if dim == 0 else 0, rank * max_shape[0] + dim_size if dim == 0 else dim_size)
376+
for dim, dim_size in enumerate(shape)
377+
]
378+
]
379+
for rank, shape in enumerate(shapes)
380+
if group is None or rank in group
381+
]
382+
383+
353384
def all_gather(
354-
tensor: Union[torch.Tensor, float, str], group: Optional[Union[Any, List[int]]] = None
355-
) -> Union[torch.Tensor, float, List[float], List[str]]:
385+
tensor: Union[torch.Tensor, float, str],
386+
group: Optional[Union[Any, List[int]]] = None,
387+
tensor_different_shape: bool = False,
388+
) -> Union[torch.Tensor, float, List[float], List[str], List[torch.Tensor]]:
356389
"""Helper method to perform all gather operation.
357390
358391
Args:
359-
tensor: tensor or number or str to collect across participating processes.
392+
tensor: tensor or number or str to collect across participating processes. If tensor, it should have
393+
the same number of dimensions across processes.
360394
group: list of integer or the process group for each backend. If None, the default process group will be used.
395+
tensor_different_shape: If True, it accounts for difference in input shape across processes. In this case, it
396+
induces more collective operations. If False, `tensor` should have the same shape across processes.
397+
Ignored when `tensor` is not a tensor. Default False.
398+
361399
362400
Returns:
363-
torch.Tensor of shape ``(world_size * tensor.shape[0], tensor.shape[1], ...)`` if input is a tensor or
364-
torch.Tensor of shape ``(world_size, )`` if input is a number or
365-
List of strings if input is a string
401+
If input is a tensor, returns a torch.Tensor of shape ``(world_size * tensor.shape[0], tensor.shape[1], ...)``
402+
if ``tensor_different_shape = False``, otherwise a list of tensors with length ``world_size``(if ``group``
403+
is `None`) or `len(group)`. If current process does not belong to `group`, a list with `tensor` as its only
404+
item is retured.
405+
If input is a number, a torch.Tensor of shape ``(world_size, )`` is returned and finally a list of strings
406+
is returned if input is a string.
366407
367408
.. versionchanged:: 0.4.11
368409
added ``group``
410+
411+
.. versionchanged:: 0.5.1
412+
added ``tensor_different_shape``
369413
"""
370414
if _need_to_sync and isinstance(_model, _SerialModel):
371415
sync(temporary=True)
372416

373417
if isinstance(group, list) and all(isinstance(item, int) for item in group):
374418
group = _model.new_group(group)
375419

420+
if isinstance(tensor, torch.Tensor) and tensor_different_shape:
421+
if isinstance(_model, _SerialModel) or (group is not None and _model.get_rank() not in group):
422+
return [tensor]
423+
all_shapes: torch.Tensor = _model.all_gather(torch.tensor(tensor.shape), group=group).view(
424+
-1, len(tensor.shape)
425+
)
426+
return _all_gather_tensors_with_shapes(tensor, all_shapes.tolist(), group=group)
427+
376428
return _model.all_gather(tensor, group=group)
377429

378430

ignite/metrics/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from ignite.metrics.gan.inception_score import InceptionScore
1010
from ignite.metrics.loss import Loss
1111
from ignite.metrics.mean_absolute_error import MeanAbsoluteError
12+
from ignite.metrics.mean_average_precision import MeanAveragePrecision
1213
from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance
1314
from ignite.metrics.mean_squared_error import MeanSquaredError
1415
from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, MetricUsage
@@ -23,6 +24,7 @@
2324
from ignite.metrics.running_average import RunningAverage
2425
from ignite.metrics.ssim import SSIM
2526
from ignite.metrics.top_k_categorical_accuracy import TopKCategoricalAccuracy
27+
from ignite.metrics.vision.object_detection_map import ObjectDetectionMAP
2628

2729
__all__ = [
2830
"Metric",
@@ -58,4 +60,6 @@
5860
"Rouge",
5961
"RougeN",
6062
"RougeL",
63+
"MeanAveragePrecision",
64+
"ObjectDetectionMAP",
6165
]

0 commit comments

Comments
 (0)