Skip to content

Commit b9b5165

Browse files
committed
refactor: add type hints in test_requirements.py
1 parent 28eecc2 commit b9b5165

File tree

1 file changed

+17
-14
lines changed

1 file changed

+17
-14
lines changed

test/test_requirements.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@
2929
HTML_DEP_CSV = str(HTML_DEP_PATH_PATHLIB / "dependencies.csv")
3030

3131

32-
def get_out_of_sync_packages(csv_name, txt_name):
33-
new_packages = set()
34-
removed_packages = set()
35-
csv_package_names = set()
36-
txt_package_names = set()
32+
def get_out_of_sync_packages(csv_name: str, txt_name: str) -> tuple[set[str], set[str]]:
33+
new_packages: set[str] = set()
34+
removed_packages: set[str] = set()
35+
csv_package_names: set[str] = set()
36+
txt_package_names: set[str] = set()
3737

3838
with open(csv_name) as csv_file, open(txt_name) as txt_file:
3939
csv_reader = csv.reader(csv_file)
@@ -46,12 +46,12 @@ def get_out_of_sync_packages(csv_name, txt_name):
4646
new_packages = txt_package_names - csv_package_names
4747
removed_packages = csv_package_names - txt_package_names
4848

49-
return (new_packages, removed_packages)
49+
return new_packages, removed_packages
5050

5151

5252
# Test to check if the requirements.csv files are in sync with requirements.txt files
53-
def test_txt_csv_sync():
54-
errors = set()
53+
def test_txt_csv_sync() -> None:
54+
errors: set[str] = set()
5555

5656
(
5757
req_new_packages,
@@ -82,8 +82,8 @@ def test_txt_csv_sync():
8282
assert errors == set(), f"The error(s) are:\n {''.join(errors)}"
8383

8484

85-
def get_cache_csv_data(file):
86-
data = []
85+
def get_cache_csv_data(file: str) -> list[tuple[str, str, str]]:
86+
data: list[tuple[str, str, str]] = []
8787

8888
with open(file) as f:
8989
r = csv.reader(f)
@@ -95,9 +95,12 @@ def get_cache_csv_data(file):
9595
file_name += ".js"
9696
with open(file_name) as f:
9797
file_content = f.read()
98-
html_dep_version = re.search(
99-
r"v([0-9]+\.[0-9]+\.[0-9]+)", file_content
100-
).group(1)
98+
match = re.search(r"v([0-9]+\.[0-9]+\.[0-9]+)", file_content)
99+
if not match:
100+
raise ValueError(
101+
f"Could not find version for {product} in {file_name}"
102+
)
103+
html_dep_version = match.group(1)
101104
data.append((vendor, product, html_dep_version))
102105
else:
103106
if "_not_in_db" not in vendor:
@@ -107,7 +110,7 @@ def get_cache_csv_data(file):
107110

108111

109112
# Test to check for CVEs in cve-bin-tool requirements/dependencies
110-
def test_requirements():
113+
def test_requirements() -> None:
111114
cache_csv_data = (
112115
get_cache_csv_data(REQ_CSV)
113116
+ get_cache_csv_data(DOC_CSV)

0 commit comments

Comments
 (0)