diff --git a/mmengine/config/config.py b/mmengine/config/config.py
index 183138eea9..873408b7d4 100644
--- a/mmengine/config/config.py
+++ b/mmengine/config/config.py
@@ -1394,6 +1394,11 @@ def _indent(s_, num_spaces):
         def _format_basic_types(k, v, use_mapping=False):
             if isinstance(v, str):
                 v_str = repr(v)
+            elif isinstance(v, type):
+                if v.__module__ == 'builtins':
+                    v_str = v.__name__
+                else:
+                    v_str = f'{v.__module__}.{v.__name__}'
             else:
                 v_str = str(v)
 
@@ -1425,6 +1430,12 @@ def _format_list_tuple(k, v, use_mapping=False):
                     v_str += f'{_indent(_format_list_tuple(None, item), indent)},\n'  # noqa: 501
                 elif isinstance(item, str):
                     v_str += f'{_indent(repr(item), indent)},\n'
+                elif isinstance(item, type):
+                    if item.__module__ == 'builtins':
+                        item_str = item.__name__
+                    else:
+                        item_str = f'{item.__module__}.{item.__name__}'
+                    v_str += f'{_indent(item_str, indent)},\n'
                 else:
                     v_str += str(item) + ',\n'
             if k is None:
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index 9b7b64e2c1..dc5f8b57f6 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -9,6 +9,7 @@
 import warnings
 from collections import OrderedDict
 from functools import partial
+from inspect import isclass
 from typing import Callable, Dict, List, Optional, Sequence, Union
 
 import torch
@@ -902,8 +903,13 @@ def wrap_model(
                 find_unused_parameters=find_unused_parameters)
         else:
             model_wrapper_cfg.setdefault('type', 'MMDistributedDataParallel')
-            model_wrapper_type = MODEL_WRAPPERS.get(
-                model_wrapper_cfg.get('type'))  # type: ignore
+            model_wrapper_type = model_wrapper_cfg.get('type')
+            if isinstance(model_wrapper_type, str):
+                model_wrapper_type = MODEL_WRAPPERS.get(model_wrapper_type)
+            else:
+                if not isclass(model_wrapper_type):
+                    raise TypeError('type should be a string or a class')
+
             default_args: dict = dict()
             if issubclass(
                     model_wrapper_type,  # type: ignore
diff --git a/tests/data/config/lazy_module_config/pure_python_style_toy_config.py b/tests/data/config/lazy_module_config/pure_python_style_toy_config.py
new file mode 100644
index 0000000000..660a8550e0
--- /dev/null
+++ b/tests/data/config/lazy_module_config/pure_python_style_toy_config.py
@@ -0,0 +1,68 @@
+# Copyright (c) VBTI. All rights reserved.
+from torch.optim import SGD
+from mmengine.dataset import DefaultSampler
+from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook,
+                            IterTimerHook, LoggerHook, ParamSchedulerHook,
+                            RuntimeInfoHook)
+from mmengine.model import MMDistributedDataParallel
+from mmengine.optim import MultiStepLR, OptimWrapper
+import importlib.util
+import os
+
+# Dynamically load the test module by file path to avoid relative import
+# issues when the config is parsed during tests.
+_mod_path = os.path.abspath(
+    os.path.join(os.path.dirname(__file__), '..', '..', '..', 'test_runner', 'test_runner.py'))
+spec = importlib.util.spec_from_file_location('test_runner_testmod', _mod_path)
+_mod = importlib.util.module_from_spec(spec)
+spec.loader.exec_module(_mod)
+ToyDataset = _mod.ToyDataset
+ToyModel = _mod.ToyModel
+ToyMetric1 = _mod.ToyMetric1
+
+# Clean up temporary variables so they do not leak into the parsed config.
+del importlib
+del os
+del spec
+del _mod
+del _mod_path
+
+model = dict(type=ToyModel)
+train_dataloader = dict(
+    dataset=dict(type=ToyDataset),
+    sampler=dict(type=DefaultSampler, shuffle=True),
+    batch_size=3,
+    num_workers=0)
+val_dataloader = dict(
+    dataset=dict(type=ToyDataset),
+    sampler=dict(type=DefaultSampler, shuffle=False),
+    batch_size=3,
+    num_workers=0)
+test_dataloader = dict(
+    dataset=dict(type=ToyDataset),
+    sampler=dict(type=DefaultSampler, shuffle=False),
+    batch_size=3,
+    num_workers=0)
+auto_scale_lr = dict(base_batch_size=16, enable=False)
+optim_wrapper = dict(
+    type=OptimWrapper, optimizer=dict(type=SGD, lr=0.01))
+model_wrapper_cfg = dict(type=MMDistributedDataParallel)
+param_scheduler = dict(type=MultiStepLR, milestones=[1, 2])
+val_evaluator = dict(type=ToyMetric1)
+test_evaluator = dict(type=ToyMetric1)
+train_cfg = dict(
+    by_epoch=True, max_epochs=3, val_interval=1, val_begin=1)
+val_cfg = dict()
+test_cfg = dict()
+custom_hooks = []
+default_hooks = dict(
+    runtime_info=dict(type=RuntimeInfoHook),
+    timer=dict(type=IterTimerHook),
+    logger=dict(type=LoggerHook),
+    param_scheduler=dict(type=ParamSchedulerHook),
+    checkpoint=dict(
+        type=CheckpointHook, interval=1, by_epoch=True),
+    sampler_seed=dict(type=DistSamplerSeedHook))
+data_preprocessor = None
+launcher = 'pytorch'
+env_cfg = dict(dist_cfg=dict(backend='nccl'))
diff --git a/tests/test_dist/test_dist.py b/tests/test_dist/test_dist.py
index a2ef07b713..f6ef3b6c7b 100644
--- a/tests/test_dist/test_dist.py
+++ b/tests/test_dist/test_dist.py
@@ -11,8 +11,10 @@
 import torch.distributed as torch_dist
 
 import mmengine.dist as dist
+from mmengine.config import Config
 from mmengine.device import is_musa_available
 from mmengine.dist.dist import sync_random_seed
+from mmengine.runner import Runner
 from mmengine.testing._internal import MultiProcessTestCase
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
@@ -362,7 +364,7 @@ def _init_dist_env(self, rank, world_size):
         """Initialize the distributed environment."""
         os.environ['MASTER_ADDR'] = '127.0.0.1'
         os.environ['MASTER_PORT'] = '29505'
-        os.environ['RANK'] = str(rank)
+        os.environ['LOCAL_RANK'] = os.environ['RANK'] = str(rank)
 
         num_gpus = torch.cuda.device_count()
         torch.cuda.set_device(rank % num_gpus)
@@ -656,3 +658,13 @@ def test_all_reduce_params(self):
 
         for item1, item2 in zip(data_gen, expected):
             self.assertTrue(torch.allclose(item1, item2))
+
+    def test_build_runner_pure_python_style(self):
+        self._init_dist_env(self.rank, self.world_size)
+        cfg = Config.fromfile(
+            osp.join(
+                osp.dirname(__file__), '..', 'data', 'config',
+                'lazy_module_config', 'pure_python_style_toy_config.py'))
+        cfg.work_dir = tempfile.mkdtemp()
+        cfg.experiment_name = 'test_build_runner_pure_python_style_config'
+        Runner.from_cfg(cfg)
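Note (not part of the patch): a minimal sketch of the behaviour the _format_basic_types and _format_list_tuple changes are meant to produce when a pure-Python-style config holding class objects is dumped. It assumes only that Config.pretty_text routes values through these helpers; the exact dotted path printed depends on the installed torch version.

    # Illustrative only: class-valued ``type`` fields should now dump as
    # importable dotted paths rather than repr() strings such as
    # "<class 'torch.optim.sgd.SGD'>".
    from torch.optim import SGD

    from mmengine.config import Config

    cfg = Config(dict(optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.01))))
    # Expected to contain something like "type=torch.optim.sgd.SGD" for the
    # optimizer entry (module path may vary across torch versions).
    print(cfg.pretty_text)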