Skip to content

Commit 397aaf3

Browse files
add ddpallreducemean
1 parent be448ba commit 397aaf3

File tree

4 files changed

+12
-5
lines changed

4 files changed

+12
-5
lines changed

ssseg/modules/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
BuildScheduler, BuildParamsConstructor, BuildNormalization, BuildActivation, BuildDropout, EMASegmentor
1212
)
1313
from .utils import (
14-
initslurm, setrandomseed, touchdir, loadckpts, saveckpts, loadpretrainedweights, symlink, judgefileexist, postprocesspredgtpairs, ismainprocess, touchdirs,
15-
LoggerHandleBuilder, BuildLoggerHandle, TrainingLoggingManager, BaseModuleBuilder, ConfigParser, EnvironmentCollector, SSSegInputStructure,
16-
SSSegOutputStructure,
14+
initslurm, setrandomseed, touchdir, loadckpts, saveckpts, loadpretrainedweights, symlink, judgefileexist, postprocesspredgtpairs, ismainprocess, touchdirs, ddpallreducemean,
15+
LoggerHandleBuilder, TrainingLoggingManager, BaseModuleBuilder, ConfigParser, EnvironmentCollector, SSSegInputStructure, SSSegOutputStructure, BuildLoggerHandle,
1716
)

ssseg/modules/models/segmentors/ocrnet/ocrnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def forward(self, data_meta):
5656
# feed to bottleneck
5757
feats = self.bottleneck(backbone_outputs[-1])
5858
# feed to ocr module
59-
context = self.spatial_gather_module(feats, seg_logits_aux)
59+
context = self.spatial_gather_module(feats, F.interpolate(seg_logits_aux, size=feats.shape[-2:], mode='bilinear', align_corners=self.align_corners))
6060
feats = self.object_context_block(feats, context)
6161
# feed to decoder
6262
seg_logits = self.decoder(feats)

ssseg/modules/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@
44
from .configparser import ConfigParser
55
from .modulebuilder import BaseModuleBuilder
66
from .datastructure import SSSegInputStructure, SSSegOutputStructure
7-
from .misc import setrandomseed, postprocesspredgtpairs, ismainprocess
87
from .logger import LoggerHandleBuilder, BuildLoggerHandle, TrainingLoggingManager
8+
from .misc import setrandomseed, postprocesspredgtpairs, ismainprocess, ddpallreducemean
99
from .io import touchdir, loadckpts, saveckpts, loadpretrainedweights, symlink, judgefileexist, touchdirs

ssseg/modules/utils/misc.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,14 @@
1313
from .io import touchdirs
1414

1515

16+
'''ddpallreducemean'''
17+
def ddpallreducemean(tensor: torch.Tensor) -> torch.Tensor:
18+
world_size = dist.get_world_size()
19+
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
20+
tensor /= world_size
21+
return tensor
22+
23+
1624
'''setrandomseed'''
1725
def setrandomseed(seed):
1826
random.seed(seed)

0 commit comments

Comments
 (0)