Skip to content

Commit 9ef0785

Browse files
MGAMZHAOCHENYE
authored andcommitted
Correct logical error in package_utils and enhance test.
1 parent e1fc178 commit 9ef0785

File tree

2 files changed

+15
-11
lines changed

2 files changed

+15
-11
lines changed

mmengine/utils/package_utils.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import os.path as osp
33
import subprocess
4-
from typing import Any
54
from importlib.metadata import PackageNotFoundError, distribution
6-
5+
from typing import Any
76

87

98
def is_installed(package: str) -> bool:
@@ -18,7 +17,7 @@ def is_installed(package: str) -> bool:
1817
spec = importlib.util.find_spec(package)
1918
if spec is not None and spec.origin is not None:
2019
return True
21-
20+
2221
# If not found as module, check if it's a distribution package
2322
try:
2423
distribution(package)
@@ -43,7 +42,6 @@ def get_installed_path(package: str) -> str:
4342
# inferred. For example, mmcv-full is the package name, but mmcv is module
4443
# name. If we want to get the installed path of mmcv-full, we should concat
4544
# the pkg.location and module name
46-
4745
# Try to get location from distribution package metadata
4846
location = None
4947
try:
@@ -52,7 +50,7 @@ def get_installed_path(package: str) -> str:
5250
location = str(locate_result.parent)
5351
except PackageNotFoundError:
5452
pass
55-
53+
5654
# If distribution package not found, try to find via importlib
5755
if location is None:
5856
spec = importlib.util.find_spec(package)
@@ -88,11 +86,12 @@ def package2module(package: str) -> str:
8886
# In importlib.metadata,
8987
# top-level modules are in dist.read_text('top_level.txt')
9088
top_level_text = dist.read_text('top_level.txt')
91-
if top_level_text is None:
92-
module_name = top_level_text.split('\n')[0]
93-
return module_name
94-
else:
95-
raise ValueError(f'can not infer the module name of {package}')
89+
if top_level_text is not None:
90+
lines = top_level_text.strip().split('\n')
91+
if lines:
92+
module_name = lines[0].strip()
93+
return module_name
94+
raise ValueError(f'can not infer the module name of {package}')
9695

9796

9897
def call_command(cmd: list) -> None:

tests/test_utils/test_package_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import os.path as osp
33
import sys
4-
54
from importlib.metadata import PackageNotFoundError
65

76
import pytest
@@ -21,6 +20,12 @@ def test_is_installed():
2120
assert is_installed('optim')
2221
sys.path.pop()
2322

23+
assert is_installed('nonexistentpackage12345') is False
24+
assert is_installed('os') is True # 'os' is a module name
25+
assert is_installed('setuptools') is True
26+
# Should work on both distribution and module name
27+
assert is_installed('pillow') is True and is_installed('PIL') is True
28+
2429

2530
def test_get_install_path():
2631
# TODO: Windows CI may failed in unknown reason. Skip check the value

0 commit comments

Comments
 (0)