From 0ec825de766a35a142c2f0b6ba6c58f3ff291c9c Mon Sep 17 00:00:00 2001 From: mgam <312065559@qq.com> Date: Sun, 26 Oct 2025 16:50:18 +0800 Subject: [PATCH 1/7] Replace `pkg_resources` with `importlib` --- mmengine/utils/package_utils.py | 63 ++++++++++++++++---------- tests/test_utils/test_package_utils.py | 5 +- 2 files changed, 43 insertions(+), 25 deletions(-) diff --git a/mmengine/utils/package_utils.py b/mmengine/utils/package_utils.py index 452bbaddaa..c2b155844c 100644 --- a/mmengine/utils/package_utils.py +++ b/mmengine/utils/package_utils.py @@ -1,6 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp import subprocess +from typing import Any +from importlib.metadata import PackageNotFoundError, distribution + def is_installed(package: str) -> bool: @@ -9,21 +12,16 @@ def is_installed(package: str) -> bool: Args: package (str): Name of package to be checked. """ - # When executing `import mmengine.runner`, - # pkg_resources will be imported and it takes too much time. - # Therefore, import it in function scope to save time. 
+ # Use importlib.metadata instead of deprecated pkg_resources + # importlib.metadata is available in Python 3.8+ + # For Python 3.7, importlib_metadata backport can be used import importlib.util - import pkg_resources # type: ignore - from pkg_resources import get_distribution - - # refresh the pkg_resources - # more datails at https://github.com/pypa/setuptools/issues/373 - importlib.reload(pkg_resources) try: - get_distribution(package) + distribution(package) return True - except pkg_resources.DistributionNotFound: + except Exception: + # If distribution not found, check if module can be imported spec = importlib.util.find_spec(package) if spec is None: return False @@ -45,15 +43,31 @@ def get_installed_path(package: str) -> str: """ import importlib.util - from pkg_resources import DistributionNotFound, get_distribution - # if the package name is not the same as module name, module name should be # inferred. For example, mmcv-full is the package name, but mmcv is module # name. If we want to get the installed path of mmcv-full, we should concat # the pkg.location and module name try: - pkg = get_distribution(package) - except DistributionNotFound as e: + dist = distribution(package) + # In importlib.metadata, we use dist.locate_file() or files + if hasattr(dist, 'locate_file'): + # Python 3.9+ + # locate_file returns PathLike, need to access parent + locate_result: Any = dist.locate_file('') + location = str(locate_result.parent) + elif hasattr(dist, '_path'): + # Python 3.8 - _path is a pathlib.Path object + # We know _path exists because we checked with hasattr + dist_any: Any = dist + location = str(dist_any._path.parent) # type: ignore[attr-defined] + else: + # Fallback: try to find via importlib + spec = importlib.util.find_spec(package) + if spec is not None and spec.origin is not None: + return osp.dirname(spec.origin) + raise RuntimeError( + f'Cannot determine installation path for {package}') + except PackageNotFoundError as e: # if the package is not 
installed, package path set in PYTHONPATH # can be detected by `find_spec` spec = importlib.util.find_spec(package) @@ -69,23 +83,26 @@ def get_installed_path(package: str) -> str: else: raise e - possible_path = osp.join(pkg.location, package) # type: ignore + possible_path = osp.join(location, package) if osp.exists(possible_path): return possible_path else: - return osp.join(pkg.location, package2module(package)) # type: ignore + return osp.join(location, package2module(package)) -def package2module(package: str): +def package2module(package: str) -> str: """Infer module name from package. Args: package (str): Package to infer module name. """ - from pkg_resources import get_distribution - pkg = get_distribution(package) - if pkg.has_metadata('top_level.txt'): - module_name = pkg.get_metadata('top_level.txt').split('\n')[0] + dist = distribution(package) + + # In importlib.metadata, + # top-level modules are in dist.read_text('top_level.txt') + top_level_text = dist.read_text('top_level.txt') + if top_level_text: + module_name = top_level_text.split('\n')[0] return module_name else: raise ValueError(f'can not infer the module name of {package}') @@ -100,4 +117,4 @@ def call_command(cmd: list) -> None: def install_package(package: str): if not is_installed(package): - call_command(['python', '-m', 'pip', 'install', package]) + call_command(['python', '-m', 'pip', 'install', package]) \ No newline at end of file diff --git a/tests/test_utils/test_package_utils.py b/tests/test_utils/test_package_utils.py index 11ce294c29..de43597d16 100644 --- a/tests/test_utils/test_package_utils.py +++ b/tests/test_utils/test_package_utils.py @@ -2,7 +2,8 @@ import os.path as osp import sys -import pkg_resources # type: ignore +from importlib.metadata import PackageNotFoundError + import pytest from mmengine.utils import get_installed_path, is_installed @@ -33,5 +34,5 @@ def test_get_install_path(): assert get_installed_path('optim') == osp.join(PYTHONPATH, 'optim') 
sys.path.pop() - with pytest.raises(pkg_resources.DistributionNotFound): + with pytest.raises(PackageNotFoundError): get_installed_path('unknown') From 07d46ce84b645dc26f5596d4864cd04116cb91d3 Mon Sep 17 00:00:00 2001 From: ZhangYiqin <312065559@qq.com> Date: Sun, 26 Oct 2025 16:54:56 +0800 Subject: [PATCH 2/7] [Improve] Limit the exception caught field. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- mmengine/utils/package_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmengine/utils/package_utils.py b/mmengine/utils/package_utils.py index c2b155844c..36f72cc19b 100644 --- a/mmengine/utils/package_utils.py +++ b/mmengine/utils/package_utils.py @@ -20,7 +20,7 @@ def is_installed(package: str) -> bool: try: distribution(package) return True - except Exception: + except PackageNotFoundError: # If distribution not found, check if module can be imported spec = importlib.util.find_spec(package) if spec is None: From 01bf4cabd64e8ea0adcc69aa04527555c520fc66 Mon Sep 17 00:00:00 2001 From: mgam <312065559@qq.com> Date: Sun, 26 Oct 2025 20:26:25 +0800 Subject: [PATCH 3/7] 1. Accepting Google's guideline. 2. Only keep `locate_file` branch. 3. keep the logic of `distribution package` and `import package` independent. --- mmengine/utils/package_utils.py | 55 +++++++++++++-------------------- 1 file changed, 21 insertions(+), 34 deletions(-) diff --git a/mmengine/utils/package_utils.py b/mmengine/utils/package_utils.py index 36f72cc19b..fd7ad95e23 100644 --- a/mmengine/utils/package_utils.py +++ b/mmengine/utils/package_utils.py @@ -12,23 +12,19 @@ def is_installed(package: str) -> bool: Args: package (str): Name of package to be checked. 
""" - # Use importlib.metadata instead of deprecated pkg_resources - # importlib.metadata is available in Python 3.8+ - # For Python 3.7, importlib_metadata backport can be used import importlib.util + # First check if it's an importable module + spec = importlib.util.find_spec(package) + if spec is not None and spec.origin is not None: + return True + + # If not found as module, check if it's a distribution package try: distribution(package) return True except PackageNotFoundError: - # If distribution not found, check if module can be imported - spec = importlib.util.find_spec(package) - if spec is None: - return False - elif spec.origin is not None: - return True - else: - return False + return False def get_installed_path(package: str) -> str: @@ -47,29 +43,18 @@ def get_installed_path(package: str) -> str: # inferred. For example, mmcv-full is the package name, but mmcv is module # name. If we want to get the installed path of mmcv-full, we should concat # the pkg.location and module name + + # Try to get location from distribution package metadata + location = None try: dist = distribution(package) - # In importlib.metadata, we use dist.locate_file() or files - if hasattr(dist, 'locate_file'): - # Python 3.9+ - # locate_file returns PathLike, need to access parent - locate_result: Any = dist.locate_file('') - location = str(locate_result.parent) - elif hasattr(dist, '_path'): - # Python 3.8 - _path is a pathlib.Path object - # We know _path exists because we checked with hasattr - dist_any: Any = dist - location = str(dist_any._path.parent) # type: ignore[attr-defined] - else: - # Fallback: try to find via importlib - spec = importlib.util.find_spec(package) - if spec is not None and spec.origin is not None: - return osp.dirname(spec.origin) - raise RuntimeError( - f'Cannot determine installation path for {package}') - except PackageNotFoundError as e: - # if the package is not installed, package path set in PYTHONPATH - # can be detected by `find_spec` + 
locate_result: Any = dist.locate_file('') + location = str(locate_result.parent) + except PackageNotFoundError: + pass + + # If distribution package not found, try to find via importlib + if location is None: spec = importlib.util.find_spec(package) if spec is not None: if spec.origin is not None: @@ -81,8 +66,10 @@ def get_installed_path(package: str) -> str: f'{package} is a namespace package, which is invalid ' 'for `get_install_path`') else: - raise e + raise PackageNotFoundError( + f'Package {package} is not installed') + # Check if package directory exists in the location possible_path = osp.join(location, package) if osp.exists(possible_path): return possible_path @@ -101,7 +88,7 @@ def package2module(package: str) -> str: # In importlib.metadata, # top-level modules are in dist.read_text('top_level.txt') top_level_text = dist.read_text('top_level.txt') - if top_level_text: + if top_level_text is None: module_name = top_level_text.split('\n')[0] return module_name else: From 262e224a6d270bfc91919a857ac6c21154435dcc Mon Sep 17 00:00:00 2001 From: mgam <312065559@qq.com> Date: Sun, 26 Oct 2025 20:28:45 +0800 Subject: [PATCH 4/7] Add an empty line at the end. --- mmengine/utils/package_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmengine/utils/package_utils.py b/mmengine/utils/package_utils.py index fd7ad95e23..842a730e2b 100644 --- a/mmengine/utils/package_utils.py +++ b/mmengine/utils/package_utils.py @@ -104,4 +104,4 @@ def call_command(cmd: list) -> None: def install_package(package: str): if not is_installed(package): - call_command(['python', '-m', 'pip', 'install', package]) \ No newline at end of file + call_command(['python', '-m', 'pip', 'install', package]) From 7ab3e3dd01700881429256677f29c07c402a9515 Mon Sep 17 00:00:00 2001 From: mgam <312065559@qq.com> Date: Mon, 27 Oct 2025 09:16:01 +0800 Subject: [PATCH 5/7] Correct logical error in `package_utils` and enhance test. 
--- mmengine/utils/package_utils.py | 19 +++++++++---------- tests/test_utils/test_package_utils.py | 7 ++++++- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/mmengine/utils/package_utils.py b/mmengine/utils/package_utils.py index 842a730e2b..3d1cef004e 100644 --- a/mmengine/utils/package_utils.py +++ b/mmengine/utils/package_utils.py @@ -1,9 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp import subprocess -from typing import Any from importlib.metadata import PackageNotFoundError, distribution - +from typing import Any def is_installed(package: str) -> bool: @@ -18,7 +17,7 @@ def is_installed(package: str) -> bool: spec = importlib.util.find_spec(package) if spec is not None and spec.origin is not None: return True - + # If not found as module, check if it's a distribution package try: distribution(package) @@ -43,7 +42,6 @@ def get_installed_path(package: str) -> str: # inferred. For example, mmcv-full is the package name, but mmcv is module # name. 
If we want to get the installed path of mmcv-full, we should concat # the pkg.location and module name - # Try to get location from distribution package metadata location = None try: @@ -52,7 +50,7 @@ def get_installed_path(package: str) -> str: location = str(locate_result.parent) except PackageNotFoundError: pass - + # If distribution package not found, try to find via importlib if location is None: spec = importlib.util.find_spec(package) @@ -88,11 +86,12 @@ def package2module(package: str) -> str: # In importlib.metadata, # top-level modules are in dist.read_text('top_level.txt') top_level_text = dist.read_text('top_level.txt') - if top_level_text is None: - module_name = top_level_text.split('\n')[0] - return module_name - else: - raise ValueError(f'can not infer the module name of {package}') + if top_level_text is not None: + lines = top_level_text.strip().split('\n') + if lines: + module_name = lines[0].strip() + return module_name + raise ValueError(f'can not infer the module name of {package}') def call_command(cmd: list) -> None: diff --git a/tests/test_utils/test_package_utils.py b/tests/test_utils/test_package_utils.py index de43597d16..276d514dfc 100644 --- a/tests/test_utils/test_package_utils.py +++ b/tests/test_utils/test_package_utils.py @@ -1,7 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp import sys - from importlib.metadata import PackageNotFoundError import pytest @@ -21,6 +20,12 @@ def test_is_installed(): assert is_installed('optim') sys.path.pop() + assert is_installed('nonexistentpackage12345') is False + assert is_installed('os') is True # 'os' is a module name + assert is_installed('setuptools') is True + # Should work on both distribution and module name + assert is_installed('pillow') is True and is_installed('PIL') is True + def test_get_install_path(): # TODO: Windows CI may failed in unknown reason. 
Skip check the value From 6fdb2f91179c0bd660c729366f9197aed7cb1cc4 Mon Sep 17 00:00:00 2001 From: mgam <312065559@qq.com> Date: Mon, 27 Oct 2025 09:27:34 +0800 Subject: [PATCH 6/7] Lint using yapf. --- mmengine/_strategy/deepspeed.py | 8 +++---- mmengine/config/config.py | 21 +++++++++++-------- mmengine/dataset/utils.py | 3 ++- mmengine/fileio/backends/local_backend.py | 4 ++-- mmengine/fileio/file_client.py | 4 ++-- mmengine/hooks/checkpoint_hook.py | 8 +++---- mmengine/model/test_time_aug.py | 7 ++++--- mmengine/runner/checkpoint.py | 7 ++++--- mmengine/utils/dl_utils/torch_ops.py | 6 +++--- mmengine/utils/package_utils.py | 3 +-- mmengine/visualization/visualizer.py | 5 +++-- .../lazy_module_config/test_ast_transform.py | 1 + .../lazy_module_config/test_mix_builtin.py | 1 - .../config/lazy_module_config/toy_model.py | 1 - .../config/py_config/test_custom_class.py | 1 + .../py_config/test_dump_pickle_support.py | 2 +- .../py_config/test_get_external_cfg3.py | 9 ++------ tests/test_analysis/test_jit_analysis.py | 7 ++++--- tests/test_dataset/test_base_dataset.py | 8 +++---- .../test_optimizer/test_optimizer_wrapper.py | 12 +++++------ 20 files changed, 60 insertions(+), 58 deletions(-) diff --git a/mmengine/_strategy/deepspeed.py b/mmengine/_strategy/deepspeed.py index 3f89ff760d..3d945a6a54 100644 --- a/mmengine/_strategy/deepspeed.py +++ b/mmengine/_strategy/deepspeed.py @@ -310,10 +310,10 @@ def __init__( self.config.setdefault('gradient_accumulation_steps', 1) self.config['steps_per_print'] = steps_per_print self._inputs_to_half = inputs_to_half - assert (exclude_frozen_parameters is None or - digit_version(deepspeed.__version__) >= digit_version('0.13.2') - ), ('DeepSpeed >= 0.13.2 is required to enable ' - 'exclude_frozen_parameters') + assert (exclude_frozen_parameters is None or digit_version( + deepspeed.__version__) >= digit_version('0.13.2')), ( + 'DeepSpeed >= 0.13.2 is required to enable ' + 'exclude_frozen_parameters') self.exclude_frozen_parameters 
= exclude_frozen_parameters register_deepspeed_optimizers() diff --git a/mmengine/config/config.py b/mmengine/config/config.py index 801243c82d..c460ae8e3f 100644 --- a/mmengine/config/config.py +++ b/mmengine/config/config.py @@ -46,9 +46,10 @@ def _lazy2string(cfg_dict, dict_type=None): if isinstance(cfg_dict, dict): dict_type = dict_type or type(cfg_dict) - return dict_type( - {k: _lazy2string(v, dict_type) - for k, v in dict.items(cfg_dict)}) + return dict_type({ + k: _lazy2string(v, dict_type) + for k, v in dict.items(cfg_dict) + }) elif isinstance(cfg_dict, (tuple, list)): return type(cfg_dict)(_lazy2string(v, dict_type) for v in cfg_dict) elif isinstance(cfg_dict, (LazyAttr, LazyObject)): @@ -271,13 +272,15 @@ def __reduce_ex__(self, proto): # called by CPython interpreter during pickling. See more details in # https://github.com/python/cpython/blob/8d61a71f9c81619e34d4a30b625922ebc83c561b/Objects/typeobject.c#L6196 # noqa: E501 if digit_version(platform.python_version()) < digit_version('3.8'): - return (self.__class__, ({k: v - for k, v in super().items()}, ), None, - None, None) + return (self.__class__, ({ + k: v + for k, v in super().items() + }, ), None, None, None) else: - return (self.__class__, ({k: v - for k, v in super().items()}, ), None, - None, None, None) + return (self.__class__, ({ + k: v + for k, v in super().items() + }, ), None, None, None, None) def __eq__(self, other): if isinstance(other, ConfigDict): diff --git a/mmengine/dataset/utils.py b/mmengine/dataset/utils.py index 2c9cf96497..d140cc8dc4 100644 --- a/mmengine/dataset/utils.py +++ b/mmengine/dataset/utils.py @@ -158,7 +158,8 @@ def default_collate(data_batch: Sequence) -> Any: return [default_collate(samples) for samples in transposed] elif isinstance(data_item, Mapping): return data_item_type({ - key: default_collate([d[key] for d in data_batch]) + key: + default_collate([d[key] for d in data_batch]) for key in data_item }) else: diff --git 
a/mmengine/fileio/backends/local_backend.py b/mmengine/fileio/backends/local_backend.py index c7d5f04621..84ebe95514 100644 --- a/mmengine/fileio/backends/local_backend.py +++ b/mmengine/fileio/backends/local_backend.py @@ -156,8 +156,8 @@ def isfile(self, filepath: Union[str, Path]) -> bool: """ return osp.isfile(filepath) - def join_path(self, filepath: Union[str, Path], - *filepaths: Union[str, Path]) -> str: + def join_path(self, filepath: Union[str, Path], *filepaths: + Union[str, Path]) -> str: r"""Concatenate all file paths. Join one or more filepath components intelligently. The return value diff --git a/mmengine/fileio/file_client.py b/mmengine/fileio/file_client.py index 61551d3d1d..29730e7564 100644 --- a/mmengine/fileio/file_client.py +++ b/mmengine/fileio/file_client.py @@ -385,8 +385,8 @@ def isfile(self, filepath: Union[str, Path]) -> bool: """ return self.client.isfile(filepath) - def join_path(self, filepath: Union[str, Path], - *filepaths: Union[str, Path]) -> str: + def join_path(self, filepath: Union[str, Path], *filepaths: + Union[str, Path]) -> str: r"""Concatenate all file paths. Join one or more filepath components intelligently. 
The return value diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index 92a4867bb9..3adb78c7dc 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -196,10 +196,10 @@ def __init__(self, self.save_best = save_best # rule logic - assert (isinstance(rule, str) or is_list_of(rule, str) - or (rule is None)), ( - '"rule" should be a str or list of str or None, ' - f'but got {type(rule)}') + assert (isinstance(rule, str) or is_list_of(rule, str) or + (rule + is None)), ('"rule" should be a str or list of str or None, ' + f'but got {type(rule)}') if isinstance(rule, list): # check the length of rule list assert len(rule) in [ diff --git a/mmengine/model/test_time_aug.py b/mmengine/model/test_time_aug.py index c623eec8bc..2f19248c2c 100644 --- a/mmengine/model/test_time_aug.py +++ b/mmengine/model/test_time_aug.py @@ -124,9 +124,10 @@ def test_step(self, data): data_list: Union[List[dict], List[list]] if isinstance(data, dict): num_augs = len(data[next(iter(data))]) - data_list = [{key: value[idx] - for key, value in data.items()} - for idx in range(num_augs)] + data_list = [{ + key: value[idx] + for key, value in data.items() + } for idx in range(num_augs)] elif isinstance(data, (tuple, list)): num_augs = len(data[0]) data_list = [[_data[idx] for _data in data] diff --git a/mmengine/runner/checkpoint.py b/mmengine/runner/checkpoint.py index d55e6d6c3a..f061cf5cac 100644 --- a/mmengine/runner/checkpoint.py +++ b/mmengine/runner/checkpoint.py @@ -601,9 +601,10 @@ def _load_checkpoint_to_model(model, # strip prefix of state_dict metadata = getattr(state_dict, '_metadata', OrderedDict()) for p, r in revise_keys: - state_dict = OrderedDict( - {re.sub(p, r, k): v - for k, v in state_dict.items()}) + state_dict = OrderedDict({ + re.sub(p, r, k): v + for k, v in state_dict.items() + }) # Keep metadata in state_dict state_dict._metadata = metadata diff --git a/mmengine/utils/dl_utils/torch_ops.py 
b/mmengine/utils/dl_utils/torch_ops.py index 2550ae6986..85dc3100d2 100644 --- a/mmengine/utils/dl_utils/torch_ops.py +++ b/mmengine/utils/dl_utils/torch_ops.py @@ -4,9 +4,9 @@ from ..version_utils import digit_version from .parrots_wrapper import TORCH_VERSION -_torch_version_meshgrid_indexing = ( - 'parrots' not in TORCH_VERSION - and digit_version(TORCH_VERSION) >= digit_version('1.10.0a0')) +_torch_version_meshgrid_indexing = ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) + >= digit_version('1.10.0a0')) def torch_meshgrid(*tensors): diff --git a/mmengine/utils/package_utils.py b/mmengine/utils/package_utils.py index 3d1cef004e..606d3686c3 100644 --- a/mmengine/utils/package_utils.py +++ b/mmengine/utils/package_utils.py @@ -64,8 +64,7 @@ def get_installed_path(package: str) -> str: f'{package} is a namespace package, which is invalid ' 'for `get_install_path`') else: - raise PackageNotFoundError( - f'Package {package} is not installed') + raise PackageNotFoundError(f'Package {package} is not installed') # Check if package directory exists in the location possible_path = osp.join(location, package) diff --git a/mmengine/visualization/visualizer.py b/mmengine/visualization/visualizer.py index 6979395aca..6653497d6e 100644 --- a/mmengine/visualization/visualizer.py +++ b/mmengine/visualization/visualizer.py @@ -754,8 +754,9 @@ def draw_bboxes( assert bboxes.shape[-1] == 4, ( f'The shape of `bboxes` should be (N, 4), but got {bboxes.shape}') - assert (bboxes[:, 0] <= bboxes[:, 2]).all() and (bboxes[:, 1] <= - bboxes[:, 3]).all() + assert (bboxes[:, 0] <= bboxes[:, 2]).all() and (bboxes[:, 1] + <= bboxes[:, + 3]).all() if not self._is_posion_valid(bboxes.reshape((-1, 2, 2))): warnings.warn( 'Warning: The bbox is out of bounds,' diff --git a/tests/data/config/lazy_module_config/test_ast_transform.py b/tests/data/config/lazy_module_config/test_ast_transform.py index a8803dde24..141ec0304e 100644 --- 
a/tests/data/config/lazy_module_config/test_ast_transform.py +++ b/tests/data/config/lazy_module_config/test_ast_transform.py @@ -12,4 +12,5 @@ from ._base_.default_runtime import default_scope as scope from ._base_.scheduler import val_cfg from rich.progress import Progress + start = Progress.start diff --git a/tests/data/config/lazy_module_config/test_mix_builtin.py b/tests/data/config/lazy_module_config/test_mix_builtin.py index e36da58a3b..1698f1ea4b 100644 --- a/tests/data/config/lazy_module_config/test_mix_builtin.py +++ b/tests/data/config/lazy_module_config/test_mix_builtin.py @@ -13,4 +13,3 @@ chained = list(chain([1, 2], [3, 4])) existed = ex(__file__) cfgname = partial(basename, __file__)() - diff --git a/tests/data/config/lazy_module_config/toy_model.py b/tests/data/config/lazy_module_config/toy_model.py index a9d2a3f64a..99755b4525 100644 --- a/tests/data/config/lazy_module_config/toy_model.py +++ b/tests/data/config/lazy_module_config/toy_model.py @@ -13,7 +13,6 @@ param_scheduler.milestones = [2, 4] - train_dataloader = dict( dataset=dict(type=ToyDataset), sampler=dict(type=DefaultSampler, shuffle=True), diff --git a/tests/data/config/py_config/test_custom_class.py b/tests/data/config/py_config/test_custom_class.py index ad706b087e..ae6af19e25 100644 --- a/tests/data/config/py_config/test_custom_class.py +++ b/tests/data/config/py_config/test_custom_class.py @@ -2,4 +2,5 @@ class A: ... + item_a = dict(a=A) diff --git a/tests/data/config/py_config/test_dump_pickle_support.py b/tests/data/config/py_config/test_dump_pickle_support.py index 6050ce10b1..2f8dafa3aa 100644 --- a/tests/data/config/py_config/test_dump_pickle_support.py +++ b/tests/data/config/py_config/test_dump_pickle_support.py @@ -24,5 +24,5 @@ def func(): dict_item5 = {'x/x': {'a.0': 233}} dict_list_item6 = {'x/x': [{'a.0': 1., 'b.0': 2.}, {'c/3': 3.}]} # Test windows path and escape. 
-str_item_7 = osp.join(osp.expanduser('~'), 'folder') # with backslash in +str_item_7 = osp.join(osp.expanduser('~'), 'folder') # with backslash in str_item_8 = func() diff --git a/tests/data/config/py_config/test_get_external_cfg3.py b/tests/data/config/py_config/test_get_external_cfg3.py index 5ae261350a..2dded0da76 100644 --- a/tests/data/config/py_config/test_get_external_cfg3.py +++ b/tests/data/config/py_config/test_get_external_cfg3.py @@ -3,16 +3,11 @@ 'mmdet::_base_/models/faster-rcnn_r50_fpn.py', 'mmdet::_base_/datasets/coco_detection.py', 'mmdet::_base_/schedules/schedule_1x.py', - 'mmdet::_base_/default_runtime.py', - './test_get_external_cfg_base.py' + 'mmdet::_base_/default_runtime.py', './test_get_external_cfg_base.py' ] custom_hooks = [dict(type='mmdet.DetVisualizationHook')] model = dict( roi_head=dict( - bbox_head=dict( - loss_cls=dict(_delete_=True, type='test.ToyLoss') - ) - ) -) + bbox_head=dict(loss_cls=dict(_delete_=True, type='test.ToyLoss')))) diff --git a/tests/test_analysis/test_jit_analysis.py b/tests/test_analysis/test_jit_analysis.py index be10309d0f..4b1dfaf595 100644 --- a/tests/test_analysis/test_jit_analysis.py +++ b/tests/test_analysis/test_jit_analysis.py @@ -634,9 +634,10 @@ def dummy_ops_handle(inputs: List[Any], dummy_flops = {} for name, counts in model.flops.items(): - dummy_flops[name] = Counter( - {op: flop - for op, flop in counts.items() if op != self.lin_op}) + dummy_flops[name] = Counter({ + op: flop + for op, flop in counts.items() if op != self.lin_op + }) dummy_flops[''][dummy_name] = 2 * dummy_out dummy_flops['fc'][dummy_name] = dummy_out dummy_flops['submod'][dummy_name] = dummy_out diff --git a/tests/test_dataset/test_base_dataset.py b/tests/test_dataset/test_base_dataset.py index f4ec815ec2..48bba665fe 100644 --- a/tests/test_dataset/test_base_dataset.py +++ b/tests/test_dataset/test_base_dataset.py @@ -733,13 +733,13 @@ def test_length(self): def test_getitem(self): assert ( self.cat_datasets[0]['imgs'] == 
self.dataset_a[0]['imgs']).all() - assert (self.cat_datasets[0]['imgs'] != - self.dataset_b[0]['imgs']).all() + assert (self.cat_datasets[0]['imgs'] + != self.dataset_b[0]['imgs']).all() assert ( self.cat_datasets[-1]['imgs'] == self.dataset_b[-1]['imgs']).all() - assert (self.cat_datasets[-1]['imgs'] != - self.dataset_a[-1]['imgs']).all() + assert (self.cat_datasets[-1]['imgs'] + != self.dataset_a[-1]['imgs']).all() def test_get_data_info(self): assert self.cat_datasets.get_data_info( diff --git a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py index ef1db241dd..8a6e57d456 100644 --- a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py +++ b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py @@ -455,8 +455,8 @@ def test_init(self): not torch.cuda.is_available(), reason='`torch.cuda.amp` is only available when pytorch-gpu installed') def test_step(self, dtype): - if dtype is not None and (digit_version(TORCH_VERSION) < - digit_version('1.10.0')): + if dtype is not None and (digit_version(TORCH_VERSION) + < digit_version('1.10.0')): raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to ' 'support `dtype` argument in autocast') if dtype == 'bfloat16' and not bf16_supported(): @@ -478,8 +478,8 @@ def test_step(self, dtype): not torch.cuda.is_available(), reason='`torch.cuda.amp` is only available when pytorch-gpu installed') def test_backward(self, dtype): - if dtype is not None and (digit_version(TORCH_VERSION) < - digit_version('1.10.0')): + if dtype is not None and (digit_version(TORCH_VERSION) + < digit_version('1.10.0')): raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to ' 'support `dtype` argument in autocast') if dtype == 'bfloat16' and not bf16_supported(): @@ -539,8 +539,8 @@ def test_load_state_dict(self): not torch.cuda.is_available(), reason='`torch.cuda.amp` is only available when pytorch-gpu installed') def test_optim_context(self, dtype, 
target_dtype): - if dtype is not None and (digit_version(TORCH_VERSION) < - digit_version('1.10.0')): + if dtype is not None and (digit_version(TORCH_VERSION) + < digit_version('1.10.0')): raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to ' 'support `dtype` argument in autocast') if dtype == 'bfloat16' and not bf16_supported(): From 17d8c169118c3d89beb6a632eac09a6c3be2c4a5 Mon Sep 17 00:00:00 2001 From: mgam <312065559@qq.com> Date: Mon, 27 Oct 2025 09:34:53 +0800 Subject: [PATCH 7/7] Lint using yapf=0.32.0 on python 3.10.19 --- mmengine/_strategy/deepspeed.py | 8 +++---- mmengine/config/config.py | 21 ++++++++----------- mmengine/dataset/utils.py | 3 +-- mmengine/fileio/backends/local_backend.py | 4 ++-- mmengine/fileio/file_client.py | 4 ++-- mmengine/hooks/checkpoint_hook.py | 8 +++---- mmengine/model/test_time_aug.py | 7 +++---- mmengine/runner/checkpoint.py | 7 +++---- mmengine/utils/dl_utils/torch_ops.py | 6 +++--- mmengine/visualization/visualizer.py | 5 ++--- tests/test_analysis/test_jit_analysis.py | 7 +++---- tests/test_dataset/test_base_dataset.py | 8 +++---- .../test_optimizer/test_optimizer_wrapper.py | 12 +++++------ 13 files changed, 46 insertions(+), 54 deletions(-) diff --git a/mmengine/_strategy/deepspeed.py b/mmengine/_strategy/deepspeed.py index 3d945a6a54..3f89ff760d 100644 --- a/mmengine/_strategy/deepspeed.py +++ b/mmengine/_strategy/deepspeed.py @@ -310,10 +310,10 @@ def __init__( self.config.setdefault('gradient_accumulation_steps', 1) self.config['steps_per_print'] = steps_per_print self._inputs_to_half = inputs_to_half - assert (exclude_frozen_parameters is None or digit_version( - deepspeed.__version__) >= digit_version('0.13.2')), ( - 'DeepSpeed >= 0.13.2 is required to enable ' - 'exclude_frozen_parameters') + assert (exclude_frozen_parameters is None or + digit_version(deepspeed.__version__) >= digit_version('0.13.2') + ), ('DeepSpeed >= 0.13.2 is required to enable ' + 'exclude_frozen_parameters') 
self.exclude_frozen_parameters = exclude_frozen_parameters register_deepspeed_optimizers() diff --git a/mmengine/config/config.py b/mmengine/config/config.py index c460ae8e3f..801243c82d 100644 --- a/mmengine/config/config.py +++ b/mmengine/config/config.py @@ -46,10 +46,9 @@ def _lazy2string(cfg_dict, dict_type=None): if isinstance(cfg_dict, dict): dict_type = dict_type or type(cfg_dict) - return dict_type({ - k: _lazy2string(v, dict_type) - for k, v in dict.items(cfg_dict) - }) + return dict_type( + {k: _lazy2string(v, dict_type) + for k, v in dict.items(cfg_dict)}) elif isinstance(cfg_dict, (tuple, list)): return type(cfg_dict)(_lazy2string(v, dict_type) for v in cfg_dict) elif isinstance(cfg_dict, (LazyAttr, LazyObject)): @@ -272,15 +271,13 @@ def __reduce_ex__(self, proto): # called by CPython interpreter during pickling. See more details in # https://github.com/python/cpython/blob/8d61a71f9c81619e34d4a30b625922ebc83c561b/Objects/typeobject.c#L6196 # noqa: E501 if digit_version(platform.python_version()) < digit_version('3.8'): - return (self.__class__, ({ - k: v - for k, v in super().items() - }, ), None, None, None) + return (self.__class__, ({k: v + for k, v in super().items()}, ), None, + None, None) else: - return (self.__class__, ({ - k: v - for k, v in super().items() - }, ), None, None, None, None) + return (self.__class__, ({k: v + for k, v in super().items()}, ), None, + None, None, None) def __eq__(self, other): if isinstance(other, ConfigDict): diff --git a/mmengine/dataset/utils.py b/mmengine/dataset/utils.py index d140cc8dc4..2c9cf96497 100644 --- a/mmengine/dataset/utils.py +++ b/mmengine/dataset/utils.py @@ -158,8 +158,7 @@ def default_collate(data_batch: Sequence) -> Any: return [default_collate(samples) for samples in transposed] elif isinstance(data_item, Mapping): return data_item_type({ - key: - default_collate([d[key] for d in data_batch]) + key: default_collate([d[key] for d in data_batch]) for key in data_item }) else: diff --git 
a/mmengine/fileio/backends/local_backend.py b/mmengine/fileio/backends/local_backend.py index 84ebe95514..c7d5f04621 100644 --- a/mmengine/fileio/backends/local_backend.py +++ b/mmengine/fileio/backends/local_backend.py @@ -156,8 +156,8 @@ def isfile(self, filepath: Union[str, Path]) -> bool: """ return osp.isfile(filepath) - def join_path(self, filepath: Union[str, Path], *filepaths: - Union[str, Path]) -> str: + def join_path(self, filepath: Union[str, Path], + *filepaths: Union[str, Path]) -> str: r"""Concatenate all file paths. Join one or more filepath components intelligently. The return value diff --git a/mmengine/fileio/file_client.py b/mmengine/fileio/file_client.py index 29730e7564..61551d3d1d 100644 --- a/mmengine/fileio/file_client.py +++ b/mmengine/fileio/file_client.py @@ -385,8 +385,8 @@ def isfile(self, filepath: Union[str, Path]) -> bool: """ return self.client.isfile(filepath) - def join_path(self, filepath: Union[str, Path], *filepaths: - Union[str, Path]) -> str: + def join_path(self, filepath: Union[str, Path], + *filepaths: Union[str, Path]) -> str: r"""Concatenate all file paths. Join one or more filepath components intelligently. 
The return value diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index 3adb78c7dc..92a4867bb9 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -196,10 +196,10 @@ def __init__(self, self.save_best = save_best # rule logic - assert (isinstance(rule, str) or is_list_of(rule, str) or - (rule - is None)), ('"rule" should be a str or list of str or None, ' - f'but got {type(rule)}') + assert (isinstance(rule, str) or is_list_of(rule, str) + or (rule is None)), ( + '"rule" should be a str or list of str or None, ' + f'but got {type(rule)}') if isinstance(rule, list): # check the length of rule list assert len(rule) in [ diff --git a/mmengine/model/test_time_aug.py b/mmengine/model/test_time_aug.py index 2f19248c2c..c623eec8bc 100644 --- a/mmengine/model/test_time_aug.py +++ b/mmengine/model/test_time_aug.py @@ -124,10 +124,9 @@ def test_step(self, data): data_list: Union[List[dict], List[list]] if isinstance(data, dict): num_augs = len(data[next(iter(data))]) - data_list = [{ - key: value[idx] - for key, value in data.items() - } for idx in range(num_augs)] + data_list = [{key: value[idx] + for key, value in data.items()} + for idx in range(num_augs)] elif isinstance(data, (tuple, list)): num_augs = len(data[0]) data_list = [[_data[idx] for _data in data] diff --git a/mmengine/runner/checkpoint.py b/mmengine/runner/checkpoint.py index f061cf5cac..d55e6d6c3a 100644 --- a/mmengine/runner/checkpoint.py +++ b/mmengine/runner/checkpoint.py @@ -601,10 +601,9 @@ def _load_checkpoint_to_model(model, # strip prefix of state_dict metadata = getattr(state_dict, '_metadata', OrderedDict()) for p, r in revise_keys: - state_dict = OrderedDict({ - re.sub(p, r, k): v - for k, v in state_dict.items() - }) + state_dict = OrderedDict( + {re.sub(p, r, k): v + for k, v in state_dict.items()}) # Keep metadata in state_dict state_dict._metadata = metadata diff --git a/mmengine/utils/dl_utils/torch_ops.py 
b/mmengine/utils/dl_utils/torch_ops.py index 85dc3100d2..2550ae6986 100644 --- a/mmengine/utils/dl_utils/torch_ops.py +++ b/mmengine/utils/dl_utils/torch_ops.py @@ -4,9 +4,9 @@ from ..version_utils import digit_version from .parrots_wrapper import TORCH_VERSION -_torch_version_meshgrid_indexing = ('parrots' not in TORCH_VERSION - and digit_version(TORCH_VERSION) - >= digit_version('1.10.0a0')) +_torch_version_meshgrid_indexing = ( + 'parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) >= digit_version('1.10.0a0')) def torch_meshgrid(*tensors): diff --git a/mmengine/visualization/visualizer.py b/mmengine/visualization/visualizer.py index 6653497d6e..6979395aca 100644 --- a/mmengine/visualization/visualizer.py +++ b/mmengine/visualization/visualizer.py @@ -754,9 +754,8 @@ def draw_bboxes( assert bboxes.shape[-1] == 4, ( f'The shape of `bboxes` should be (N, 4), but got {bboxes.shape}') - assert (bboxes[:, 0] <= bboxes[:, 2]).all() and (bboxes[:, 1] - <= bboxes[:, - 3]).all() + assert (bboxes[:, 0] <= bboxes[:, 2]).all() and (bboxes[:, 1] <= + bboxes[:, 3]).all() if not self._is_posion_valid(bboxes.reshape((-1, 2, 2))): warnings.warn( 'Warning: The bbox is out of bounds,' diff --git a/tests/test_analysis/test_jit_analysis.py b/tests/test_analysis/test_jit_analysis.py index 4b1dfaf595..be10309d0f 100644 --- a/tests/test_analysis/test_jit_analysis.py +++ b/tests/test_analysis/test_jit_analysis.py @@ -634,10 +634,9 @@ def dummy_ops_handle(inputs: List[Any], dummy_flops = {} for name, counts in model.flops.items(): - dummy_flops[name] = Counter({ - op: flop - for op, flop in counts.items() if op != self.lin_op - }) + dummy_flops[name] = Counter( + {op: flop + for op, flop in counts.items() if op != self.lin_op}) dummy_flops[''][dummy_name] = 2 * dummy_out dummy_flops['fc'][dummy_name] = dummy_out dummy_flops['submod'][dummy_name] = dummy_out diff --git a/tests/test_dataset/test_base_dataset.py b/tests/test_dataset/test_base_dataset.py index 
48bba665fe..f4ec815ec2 100644 --- a/tests/test_dataset/test_base_dataset.py +++ b/tests/test_dataset/test_base_dataset.py @@ -733,13 +733,13 @@ def test_length(self): def test_getitem(self): assert ( self.cat_datasets[0]['imgs'] == self.dataset_a[0]['imgs']).all() - assert (self.cat_datasets[0]['imgs'] - != self.dataset_b[0]['imgs']).all() + assert (self.cat_datasets[0]['imgs'] != + self.dataset_b[0]['imgs']).all() assert ( self.cat_datasets[-1]['imgs'] == self.dataset_b[-1]['imgs']).all() - assert (self.cat_datasets[-1]['imgs'] - != self.dataset_a[-1]['imgs']).all() + assert (self.cat_datasets[-1]['imgs'] != + self.dataset_a[-1]['imgs']).all() def test_get_data_info(self): assert self.cat_datasets.get_data_info( diff --git a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py index 8a6e57d456..ef1db241dd 100644 --- a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py +++ b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py @@ -455,8 +455,8 @@ def test_init(self): not torch.cuda.is_available(), reason='`torch.cuda.amp` is only available when pytorch-gpu installed') def test_step(self, dtype): - if dtype is not None and (digit_version(TORCH_VERSION) - < digit_version('1.10.0')): + if dtype is not None and (digit_version(TORCH_VERSION) < + digit_version('1.10.0')): raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to ' 'support `dtype` argument in autocast') if dtype == 'bfloat16' and not bf16_supported(): @@ -478,8 +478,8 @@ def test_step(self, dtype): not torch.cuda.is_available(), reason='`torch.cuda.amp` is only available when pytorch-gpu installed') def test_backward(self, dtype): - if dtype is not None and (digit_version(TORCH_VERSION) - < digit_version('1.10.0')): + if dtype is not None and (digit_version(TORCH_VERSION) < + digit_version('1.10.0')): raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to ' 'support `dtype` argument in autocast') if dtype == 
'bfloat16' and not bf16_supported(): @@ -539,8 +539,8 @@ def test_load_state_dict(self): not torch.cuda.is_available(), reason='`torch.cuda.amp` is only available when pytorch-gpu installed') def test_optim_context(self, dtype, target_dtype): - if dtype is not None and (digit_version(TORCH_VERSION) - < digit_version('1.10.0')): + if dtype is not None and (digit_version(TORCH_VERSION) < + digit_version('1.10.0')): raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to ' 'support `dtype` argument in autocast') if dtype == 'bfloat16' and not bf16_supported():