Skip to content

Commit 43673f3

Browse files
authored
Fixing lints and warning from utils file (#344)
1 parent dad4372 commit 43673f3

File tree

3 files changed

+73
-70
lines changed

3 files changed

+73
-70
lines changed

dbldatagen/spark_singleton.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class SparkSingleton:
1919
"""A singleton class which returns one Spark session instance"""
2020

2121
@classmethod
22-
def getInstance(cls) -> SparkSession:
22+
def getInstance(cls: type["SparkSingleton"]) -> SparkSession:
2323
"""Creates a `SparkSession` instance for Datalib.
2424
2525
:returns: A Spark instance
@@ -28,7 +28,7 @@ def getInstance(cls) -> SparkSession:
2828
return SparkSession.builder.getOrCreate()
2929

3030
@classmethod
31-
def getLocalInstance(cls, appName: str = "new Spark session", useAllCores: bool = True) -> SparkSession:
31+
def getLocalInstance(cls: type["SparkSingleton"], appName: str = "new Spark session", useAllCores: bool = True) -> SparkSession:
3232
"""Creates a machine local `SparkSession` instance for Datalib.
3333
By default, it uses `n-1` cores of the available cores for the spark session,
3434
where `n` is total cores available.

dbldatagen/utils.py

Lines changed: 68 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
import re
1414
import time
1515
import warnings
16+
from collections.abc import Callable
1617
from datetime import timedelta
18+
from typing import Any
1719

1820
import jmespath
1921

2022

21-
def deprecated(message=""):
23+
def deprecated(message: str = "") -> Callable[[Callable[..., Any]], Callable[..., Any]]:
2224
"""
2325
Define a deprecated decorator without dependencies on 3rd party libraries
2426
@@ -27,12 +29,12 @@ def deprecated(message=""):
2729
"""
2830

2931
# create closure around function that follows use of the decorator
30-
def deprecated_decorator(func):
32+
def deprecated_decorator(func: Callable[..., Any]) -> Callable[..., Any]:
3133
@functools.wraps(func)
32-
def deprecated_func(*args, **kwargs):
34+
def deprecated_func(*args: object, **kwargs: object) -> object:
3335
warnings.warn(f"`{func.__name__}` is a deprecated function or method. \n{message}",
3436
category=DeprecationWarning, stacklevel=1)
35-
warnings.simplefilter('default', DeprecationWarning)
37+
warnings.simplefilter("default", DeprecationWarning)
3638
return func(*args, **kwargs)
3739

3840
return deprecated_func
@@ -47,21 +49,21 @@ class DataGenError(Exception):
4749
:param baseException: underlying exception, if any that caused the issue
4850
"""
4951

50-
def __init__(self, msg, baseException=None):
52+
def __init__(self, msg: str, baseException: object | None = None) -> None:
5153
""" constructor
5254
"""
5355
super().__init__(msg)
54-
self._underlyingException = baseException
55-
self._msg = msg
56+
self._underlyingException: object | None = baseException
57+
self._msg: str = msg
5658

57-
def __repr__(self):
59+
def __repr__(self) -> str:
5860
return f"DataGenError(msg='{self._msg}', baseException={self._underlyingException})"
5961

60-
def __str__(self):
62+
def __str__(self) -> str:
6163
return f"DataGenError(msg='{self._msg}', baseException={self._underlyingException})"
6264

6365

64-
def coalesce_values(*args):
66+
def coalesce_values(*args: object) -> object | None:
6567
"""For a supplied list of arguments, returns the first argument that does not have the value `None`
6668
6769
:param args: variable list of arguments which are evaluated
@@ -73,7 +75,7 @@ def coalesce_values(*args):
7375
return None
7476

7577

76-
def ensure(cond, msg="condition does not hold true"):
78+
def ensure(cond: bool, msg: str = "condition does not hold true") -> None:
7779
"""ensure(cond, s) => throws Exception(s) if c is not true
7880
7981
:param cond: condition to test
@@ -82,34 +84,38 @@ def ensure(cond, msg="condition does not hold true"):
8284
:returns: Does not return anything but raises exception if condition does not hold
8385
"""
8486

85-
def strip_margin(text):
86-
return re.sub(r'\n[ \t]*\|', '\n', text)
87+
def strip_margin(text: str) -> str:
88+
return re.sub(r"\n[ \t]*\|", "\n", text)
8789

8890
if not cond:
8991
raise DataGenError(strip_margin(msg))
9092

9193

92-
def mkBoundsList(x, default):
94+
def mkBoundsList(x: int | list[int] | None, default: int | list[int]) -> tuple[bool, list[int]]:
9395
""" make a bounds list from supplied parameter - otherwise use default
9496
9597
:param x: integer or list of 2 values that define bounds list
9698
:param default: default value if X is `None`
9799
:returns: list of form [x,y]
98100
"""
99101
if x is None:
100-
retval = (True, [default, default]) if type(default) is int else (True, list(default))
102+
retval = (True, [default, default]) if isinstance(default, int) else (True, list(default))
101103
return retval
102-
elif type(x) is int:
103-
bounds_list = [x, x]
104+
elif isinstance(x, int):
105+
bounds_list: list[int] = [x, x]
104106
assert len(bounds_list) == 2, "bounds list must be of length 2"
105107
return False, bounds_list
106108
else:
107-
bounds_list = list(x)
109+
bounds_list: list[int] = list(x)
108110
assert len(bounds_list) == 2, "bounds list must be of length 2"
109111
return False, bounds_list
110112

111113

112-
def topologicalSort(sources, initial_columns=None, flatten=True):
114+
def topologicalSort(
115+
sources: list[tuple[str, set[str]]],
116+
initial_columns: list[str] | None = None,
117+
flatten: bool = True
118+
) -> list[str] | list[list[str]]:
113119
""" Perform a topological sort over sources
114120
115121
Used to compute the column test data generation order of the column generation dependencies.
@@ -129,16 +135,16 @@ def topologicalSort(sources, initial_columns=None, flatten=True):
129135
Overall the effect is that the input build order should be retained unless there are forward references
130136
"""
131137
# generate a copy so that we can modify in place
132-
pending = [(name, set(deps)) for name, deps in sources]
133-
provided = [] if initial_columns is None else initial_columns[:]
134-
build_orders = [] if initial_columns is None else [initial_columns]
138+
pending: list[tuple[str, set[str]]] = [(name, set(deps)) for name, deps in sources]
139+
provided: list[str] = [] if initial_columns is None else initial_columns[:]
140+
build_orders: list[list[str]] = [] if initial_columns is None else [initial_columns]
135141

136142
while pending:
137-
next_pending = []
138-
gen = []
139-
value_emitted = False
140-
defer_emitted = False
141-
gen_provided = []
143+
next_pending: list[tuple[str, set[str]]] = []
144+
gen: list[str] = []
145+
value_emitted: bool = False
146+
defer_emitted: bool = False
147+
gen_provided: list[str] = []
142148
for entry in pending:
143149
name, deps = entry
144150
deps.difference_update(provided)
@@ -165,7 +171,7 @@ def topologicalSort(sources, initial_columns=None, flatten=True):
165171
pending = next_pending
166172

167173
if flatten:
168-
flattened_list = [item for sublist in build_orders for item in sublist]
174+
flattened_list: list[str] = [item for sublist in build_orders for item in sublist]
169175
return flattened_list
170176
else:
171177
return build_orders
@@ -176,31 +182,31 @@ def topologicalSort(sources, initial_columns=None, flatten=True):
176182
_WEEKS_PER_YEAR = 52
177183

178184

179-
def parse_time_interval(spec):
185+
def parse_time_interval(spec: str) -> timedelta:
180186
"""parse time interval from string"""
181-
hours = 0
182-
minutes = 0
183-
weeks = 0
184-
microseconds = 0
185-
milliseconds = 0
186-
seconds = 0
187-
years = 0
188-
days = 0
187+
hours: int = 0
188+
minutes: int = 0
189+
weeks: int = 0
190+
microseconds: int = 0
191+
milliseconds: int = 0
192+
seconds: int = 0
193+
years: int = 0
194+
days: int = 0
189195

190196
assert spec is not None, "Must have valid time interval specification"
191197

192198
# get time specs such as 12 days, etc. Supported timespans are years, days, hours, minutes, seconds
193-
timespecs = [x.strip() for x in spec.strip().split(",")]
199+
timespecs: list[str] = [x.strip() for x in spec.strip().split(",")]
194200

195201
for ts in timespecs:
196202
# allow both 'days=1' and '1 day' syntax
197-
timespec_parts = re.findall(PATTERN_NAME_EQUALS_VALUE, ts)
203+
timespec_parts: list[tuple[str, str]] = re.findall(PATTERN_NAME_EQUALS_VALUE, ts)
198204
# findall returns list of tuples
199205
if timespec_parts is not None and len(timespec_parts) > 0:
200-
num_parts = len(timespec_parts[0])
206+
num_parts: int = len(timespec_parts[0])
201207
assert num_parts >= 1, "must have numeric specification and time element such as `12 hours` or `hours=12`"
202-
time_value = int(timespec_parts[0][num_parts - 1])
203-
time_type = timespec_parts[0][0].lower()
208+
time_value: int = int(timespec_parts[0][num_parts - 1])
209+
time_type: str = timespec_parts[0][0].lower()
204210
else:
205211
timespec_parts = re.findall(PATTERN_VALUE_SPACE_NAME, ts)
206212
num_parts = len(timespec_parts[0])
@@ -225,7 +231,7 @@ def parse_time_interval(spec):
225231
elif time_type in ["milliseconds", "millisecond"]:
226232
milliseconds = time_value
227233

228-
delta = timedelta(
234+
delta: timedelta = timedelta(
229235
days=days,
230236
seconds=seconds,
231237
microseconds=microseconds,
@@ -238,44 +244,40 @@ def parse_time_interval(spec):
238244
return delta
239245

240246

241-
def strip_margins(s, marginChar):
247+
def strip_margins(s: str, marginChar: str) -> str:
242248
"""
243249
Python equivalent of Scala stripMargins method
244-
245250
Takes a string (potentially multiline) and strips all chars up and including the first occurrence of `marginChar`.
246251
Used to control the formatting of generated text
247-
248252
`strip_margins("one\n |two\n |three", '|')`
249-
250-
will produce
251-
253+
will produce
252254
``
253-
one
255+
one
254256
two
255257
three
256258
``
257259
258260
:param s: string to strip margins from
259-
:param marginChar: character to strip
261+
:param marginChar: character to strip
260262
:return: modified string
261263
"""
262-
assert s is not None and type(s) is str
263-
assert marginChar is not None and type(marginChar) is str
264+
assert s is not None and isinstance(s, str)
265+
assert marginChar is not None and isinstance(marginChar, str)
264266

265-
lines = s.split('\n')
266-
revised_lines = []
267+
lines: list[str] = s.split("\n")
268+
revised_lines: list[str] = []
267269

268270
for line in lines:
269271
if marginChar in line:
270-
revised_line = line[line.index(marginChar) + 1:]
272+
revised_line: str = line[line.index(marginChar) + 1:]
271273
revised_lines.append(revised_line)
272274
else:
273275
revised_lines.append(line)
274276

275-
return '\n'.join(revised_lines)
277+
return "\n".join(revised_lines)
276278

277279

278-
def split_list_matching_condition(lst, cond):
280+
def split_list_matching_condition(lst: list[Any], cond: Callable[[Any], bool]) -> list[list[Any]]:
279281
"""
280282
Split a list on elements that match a condition
281283
@@ -297,9 +299,9 @@ def split_list_matching_condition(lst, cond):
297299
:arg cond: lambda function or function taking single argument and returning True or False
298300
:returns: list of sublists
299301
"""
300-
retval = []
302+
retval: list[list[Any]] = []
301303

302-
def match_condition(matchList, matchFn):
304+
def match_condition(matchList: list[Any], matchFn: Callable[[Any], bool]) -> int:
303305
"""Return first index of element of list matching condition"""
304306
if matchList is None or len(matchList) == 0:
305307
return -1
@@ -315,7 +317,7 @@ def match_condition(matchList, matchFn):
315317
elif len(lst) == 1:
316318
retval = [lst]
317319
else:
318-
ix = match_condition(lst, cond)
320+
ix: int = match_condition(lst, cond)
319321
if ix != -1:
320322
retval.extend(split_list_matching_condition(lst[0:ix], cond))
321323
retval.append(lst[ix:ix + 1])
@@ -327,7 +329,7 @@ def match_condition(matchList, matchFn):
327329
return [el for el in retval if el != []]
328330

329331

330-
def json_value_from_path(searchPath, jsonData, defaultValue):
332+
def json_value_from_path(searchPath: str, jsonData: str, defaultValue: object) -> object:
331333
""" Get JSON value from JSON data referenced by searchPath
332334
333335
searchPath should be a JSON path as supported by the `jmespath` package
@@ -341,20 +343,20 @@ def json_value_from_path(searchPath, jsonData, defaultValue):
341343
assert searchPath is not None and len(searchPath) > 0, "search path cannot be empty"
342344
assert jsonData is not None and len(jsonData) > 0, "JSON data cannot be empty"
343345

344-
jsonDict = json.loads(jsonData)
346+
jsonDict: dict = json.loads(jsonData)
345347

346-
jsonValue = jmespath.search(searchPath, jsonDict)
348+
jsonValue: Any = jmespath.search(searchPath, jsonDict)
347349

348350
if jsonValue is not None:
349351
return jsonValue
350352

351353
return defaultValue
352354

353355

354-
def system_time_millis():
356+
def system_time_millis() -> int:
355357
""" return system time as milliseconds since start of epoch
356358
357359
:return: system time millis as long
358360
"""
359-
curr_time = round(time.time() / 1000)
361+
curr_time: int = round(time.time() / 1000)
360362
return curr_time

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,7 @@ exclude = [
155155
"dbldatagen/schema_parser.py",
156156
"dbldatagen/serialization.py",
157157
"dbldatagen/text_generator_plugins.py",
158-
"dbldatagen/text_generators.py",
159-
"dbldatagen/utils.py"
158+
"dbldatagen/text_generators.py"
160159
]
161160

162161
[tool.ruff.lint]
@@ -173,6 +172,7 @@ select = [
173172
"Q", # flake8-quotes
174173
"PL", # pylint
175174
"RUF", # ruff-specific rules
175+
"ANN", # ruff-flake8-annotations
176176
]
177177
ignore = [
178178
"E501", # Line too long (let ruff formatter handle this)
@@ -188,6 +188,7 @@ ignore = [
188188
"SIM102", # Use a single if-statement
189189
"SIM108", # Use ternary operator
190190
"UP007", # Use X | Y for type annotations (keep Union for compatibility)
191+
"ANN101", # Missing type annotation for `self` in method
191192
]
192193

193194
[tool.ruff.lint.per-file-ignores]

0 commit comments

Comments
 (0)