Skip to content

Commit e168685

Browse files
committed
Extra typing information for usage in IDE/ipython
1 parent fb76879 commit e168685

File tree

6 files changed

+331
-60
lines changed

6 files changed

+331
-60
lines changed

exec_helpers/_ssh_client_base.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
_SSHConnChainT = typing.List[typing.Tuple[SSHConfig, ssh_auth.SSHAuth]]
6464
_OptionalTimeoutT = typing.Union[int, float, None]
6565
_OptionalStdinT = typing.Union[bytes, str, bytearray, None]
66+
_ExitCodeT = typing.Union[int, proc_enums.ExitCodes]
6667

6768

6869
class RetryOnExceptions(tenacity.retry_if_exception): # type: ignore
@@ -78,7 +79,7 @@ def __init__(
7879
:param retry_on: Exceptions to retry on
7980
:param reraise: Exceptions, which should be reraised, even if subclasses retry_on
8081
"""
81-
super(RetryOnExceptions, self).__init__(lambda e: isinstance(e, retry_on) and not isinstance(e, reraise))
82+
super().__init__(lambda e: isinstance(e, retry_on) and not isinstance(e, reraise))
8283

8384

8485
# noinspection PyTypeHints
@@ -88,25 +89,25 @@ class SshExecuteAsyncResult(api.ExecuteAsyncResult):
8889
@property
8990
def interface(self) -> paramiko.Channel:
9091
"""Override original NamedTuple with proper typing."""
91-
return super(SshExecuteAsyncResult, self).interface
92+
return super().interface
9293

9394
@property
9495
def stdin(self) -> paramiko.ChannelFile: # type: ignore
9596
"""Override original NamedTuple with proper typing."""
96-
return super(SshExecuteAsyncResult, self).stdin
97+
return super().stdin
9798

9899
@property
99100
def stderr(self) -> typing.Optional[paramiko.ChannelFile]: # type: ignore
100101
"""Override original NamedTuple with proper typing."""
101-
return super(SshExecuteAsyncResult, self).stderr
102+
return super().stderr
102103

103104
@property
104105
def stdout(self) -> typing.Optional[paramiko.ChannelFile]: # type: ignore
105106
"""Override original NamedTuple with proper typing."""
106-
return super(SshExecuteAsyncResult, self).stdout
107+
return super().stdout
107108

108109

109-
class _SudoContext:
110+
class _SudoContext(typing.ContextManager[None]):
110111
"""Context manager for call commands with sudo."""
111112

112113
__slots__ = ("__ssh", "__sudo_status", "__enforce")
@@ -132,7 +133,7 @@ def __exit__(self, exc_type: typing.Any, exc_val: typing.Any, exc_tb: typing.Any
132133
self.__ssh.sudo_mode = self.__sudo_status
133134

134135

135-
class _KeepAliveContext:
136+
class _KeepAliveContext(typing.ContextManager[None]):
136137
"""Context manager for keepalive management."""
137138

138139
__slots__ = ("__ssh", "__keepalive_status", "__enforce")
@@ -146,12 +147,12 @@ def __init__(self, ssh: "SSHClientBase", enforce: int) -> None:
146147
:type enforce: int
147148
"""
148149
self.__ssh: "SSHClientBase" = ssh
149-
self.__keepalive_status: int = ssh.keepalive_mode
150+
self.__keepalive_status: int = ssh.keepalive_period
150151
self.__enforce: int = enforce
151152

152153
def __enter__(self) -> None:
153154
self.__ssh.__enter__()
154-
self.__keepalive_status = self.__ssh.keepalive_mode
155+
self.__keepalive_status = self.__ssh.keepalive_period
155156
self.__ssh.keepalive_mode = self.__enforce
156157

157158
def __exit__(self, exc_type: typing.Any, exc_val: typing.Any, exc_tb: typing.Any) -> None:
@@ -524,7 +525,7 @@ def __exit__(self, exc_type: typing.Any, exc_val: typing.Any, exc_tb: typing.Any
524525
"""
525526
if not self.__keepalive_period:
526527
self.close()
527-
super(SSHClientBase, self).__exit__(exc_type, exc_val, exc_tb)
528+
super().__exit__(exc_type, exc_val, exc_tb)
528529

529530
@property
530531
def sudo_mode(self) -> bool:
@@ -930,7 +931,7 @@ def check_call( # pylint: disable=arguments-differ
930931
verbose: bool = False,
931932
timeout: _OptionalTimeoutT = constants.DEFAULT_TIMEOUT,
932933
error_info: typing.Optional[str] = None,
933-
expected: typing.Iterable[typing.Union[int, proc_enums.ExitCodes]] = (proc_enums.EXPECTED,),
934+
expected: typing.Iterable[_ExitCodeT] = (proc_enums.EXPECTED,),
934935
raise_on_err: bool = True,
935936
*,
936937
log_mask_re: typing.Optional[str] = None,
@@ -1011,7 +1012,7 @@ def check_stderr( # pylint: disable=arguments-differ
10111012
error_info: typing.Optional[str] = None,
10121013
raise_on_err: bool = True,
10131014
*,
1014-
expected: typing.Iterable[typing.Union[int, proc_enums.ExitCodes]] = (proc_enums.EXPECTED,),
1015+
expected: typing.Iterable[_ExitCodeT] = (proc_enums.EXPECTED,),
10151016
log_mask_re: typing.Optional[str] = None,
10161017
stdin: _OptionalStdinT = None,
10171018
open_stdout: bool = True,
@@ -1262,7 +1263,7 @@ def execute_together(
12621263
remotes: typing.Iterable["SSHClientBase"],
12631264
command: str,
12641265
timeout: _OptionalTimeoutT = constants.DEFAULT_TIMEOUT,
1265-
expected: typing.Iterable[typing.Union[int, proc_enums.ExitCodes]] = (proc_enums.EXPECTED,),
1266+
expected: typing.Iterable[_ExitCodeT] = (proc_enums.EXPECTED,),
12661267
raise_on_err: bool = True,
12671268
*,
12681269
stdin: _OptionalStdinT = None,
@@ -1349,9 +1350,7 @@ def get_result(remote: "SSHClientBase") -> exec_result.ExecResult:
13491350
async_result.interface.close()
13501351
return res
13511352

1352-
prep_expected: typing.Tuple[typing.Union[int, proc_enums.ExitCodes], ...] = proc_enums.exit_codes_to_enums(
1353-
expected
1354-
)
1353+
prep_expected: typing.Sequence[_ExitCodeT] = proc_enums.exit_codes_to_enums(expected)
13551354
log_level: int = logging.INFO if verbose else logging.DEBUG
13561355

13571356
futures: typing.Dict["SSHClientBase", "concurrent.futures.Future[exec_result.ExecResult]"] = {

exec_helpers/api.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,11 @@
4949
)
5050
_OptionalTimeoutT = typing.Union[int, float, None]
5151
_OptionalStdinT = typing.Union[bytes, str, bytearray, None]
52+
_ExitCodeT = typing.Union[int, proc_enums.ExitCodes]
5253

5354

5455
# noinspection PyProtectedMember
55-
class _ChRootContext:
56+
class _ChRootContext(typing.ContextManager[None]):
5657
"""Context manager for call commands with chroot.
5758
5859
.. versionadded:: 4.1.0
@@ -444,7 +445,7 @@ def check_call(
444445
verbose: bool = False,
445446
timeout: _OptionalTimeoutT = constants.DEFAULT_TIMEOUT,
446447
error_info: typing.Optional[str] = None,
447-
expected: typing.Iterable[typing.Union[int, proc_enums.ExitCodes]] = (proc_enums.EXPECTED,),
448+
expected: typing.Iterable[_ExitCodeT] = (proc_enums.EXPECTED,),
448449
raise_on_err: bool = True,
449450
*,
450451
log_mask_re: typing.Optional[str] = None,
@@ -490,9 +491,7 @@ def check_call(
490491
.. versionchanged:: 3.2.0 Exception class can be substituted
491492
.. versionchanged:: 3.4.0 Expected is not optional, defaults os dependent
492493
"""
493-
expected_codes: typing.Tuple[typing.Union[int, proc_enums.ExitCodes], ...] = proc_enums.exit_codes_to_enums(
494-
expected
495-
)
494+
expected_codes: typing.Sequence[_ExitCodeT] = proc_enums.exit_codes_to_enums(expected)
496495
result: exec_result.ExecResult = self.execute(
497496
command,
498497
verbose=verbose,
@@ -522,7 +521,7 @@ def check_stderr(
522521
error_info: typing.Optional[str] = None,
523522
raise_on_err: bool = True,
524523
*,
525-
expected: typing.Iterable[typing.Union[int, proc_enums.ExitCodes]] = (proc_enums.EXPECTED,),
524+
expected: typing.Iterable[_ExitCodeT] = (proc_enums.EXPECTED,),
526525
log_mask_re: typing.Optional[str] = None,
527526
stdin: _OptionalStdinT = None,
528527
open_stdout: bool = True,
@@ -593,7 +592,7 @@ def _handle_stderr(
593592
result: exec_result.ExecResult,
594593
error_info: typing.Optional[str],
595594
raise_on_err: bool,
596-
expected: typing.Iterable[typing.Union[int, proc_enums.ExitCodes]],
595+
expected: typing.Iterable[_ExitCodeT],
597596
exception_class: "typing.Type[exceptions.CalledProcessError]",
598597
) -> exec_result.ExecResult:
599598
"""Internal check_stderr logic (synchronous)."""

exec_helpers/async_api/api.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,13 @@
3434
from exec_helpers import proc_enums
3535

3636

37+
_OptionalTimeoutT = typing.Union[int, float, None]
38+
_OptionalStdinT = typing.Union[bytes, str, bytearray, None]
39+
_ExitCodeT = typing.Union[int, proc_enums.ExitCodes]
40+
41+
3742
# noinspection PyProtectedMember
38-
class _ChRootContext(api._ChRootContext): # pylint: disable=protected-access
43+
class _ChRootContext(api._ChRootContext, typing.AsyncContextManager[None]): # pylint: disable=protected-access
3944
"""Async extension for chroot."""
4045

4146
def __init__(self, conn: "ExecHelper", path: typing.Optional[typing.Union[str, pathlib.Path]] = None) -> None:
@@ -46,7 +51,7 @@ def __init__(self, conn: "ExecHelper", path: typing.Optional[typing.Union[str, p
4651
:param path: chroot path or None for no chroot
4752
:type path: typing.Optional[typing.Union[str, pathlib.Path]]
4853
"""
49-
super(_ChRootContext, self).__init__(conn=conn, path=path)
54+
super().__init__(conn=conn, path=path)
5055

5156
async def __aenter__(self) -> None:
5257
await self._conn.__aenter__()
@@ -72,7 +77,7 @@ def __init__(self, log_mask_re: typing.Optional[str] = None, *, logger: logging.
7277
all MATCHED groups will be replaced by '<*masked*>'
7378
:type log_mask_re: typing.Optional[str]
7479
"""
75-
super(ExecHelper, self).__init__(logger=logger, log_mask_re=log_mask_re)
80+
super().__init__(logger=logger, log_mask_re=log_mask_re)
7681
self.__alock: typing.Optional[asyncio.Lock] = None
7782

7883
def __enter__(self) -> "ExecHelper": # pylint: disable=useless-super-delegation
@@ -108,11 +113,11 @@ async def _exec_command( # type: ignore
108113
self,
109114
command: str,
110115
async_result: api.ExecuteAsyncResult,
111-
timeout: typing.Union[int, float, None],
116+
timeout: _OptionalTimeoutT,
112117
*,
113118
verbose: bool = False,
114119
log_mask_re: typing.Optional[str] = None,
115-
stdin: typing.Union[bytes, str, bytearray, None] = None,
120+
stdin: _OptionalStdinT = None,
116121
**kwargs: typing.Any,
117122
) -> exec_result.ExecResult:
118123
"""Get exit status from channel with timeout.
@@ -181,10 +186,10 @@ async def execute( # type: ignore
181186
self,
182187
command: str,
183188
verbose: bool = False,
184-
timeout: typing.Union[int, float, None] = constants.DEFAULT_TIMEOUT,
189+
timeout: _OptionalTimeoutT = constants.DEFAULT_TIMEOUT,
185190
*,
186191
log_mask_re: typing.Optional[str] = None,
187-
stdin: typing.Union[bytes, str, bytearray, None] = None,
192+
stdin: _OptionalStdinT = None,
188193
open_stdout: bool = True,
189194
open_stderr: bool = True,
190195
**kwargs: typing.Any,
@@ -247,10 +252,10 @@ async def __call__( # type: ignore
247252
self,
248253
command: str,
249254
verbose: bool = False,
250-
timeout: typing.Union[int, float, None] = constants.DEFAULT_TIMEOUT,
255+
timeout: _OptionalTimeoutT = constants.DEFAULT_TIMEOUT,
251256
*,
252257
log_mask_re: typing.Optional[str] = None,
253-
stdin: typing.Union[bytes, str, bytearray, None] = None,
258+
stdin: _OptionalStdinT = None,
254259
open_stdout: bool = True,
255260
open_stderr: bool = True,
256261
**kwargs: typing.Any,
@@ -295,13 +300,13 @@ async def check_call( # type: ignore
295300
self,
296301
command: str,
297302
verbose: bool = False,
298-
timeout: typing.Union[int, float, None] = constants.DEFAULT_TIMEOUT,
303+
timeout: _OptionalTimeoutT = constants.DEFAULT_TIMEOUT,
299304
error_info: typing.Optional[str] = None,
300-
expected: typing.Iterable[typing.Union[int, proc_enums.ExitCodes]] = (proc_enums.EXPECTED,),
305+
expected: typing.Iterable[_ExitCodeT] = (proc_enums.EXPECTED,),
301306
raise_on_err: bool = True,
302307
*,
303308
log_mask_re: typing.Optional[str] = None,
304-
stdin: typing.Union[bytes, str, bytearray, None] = None,
309+
stdin: _OptionalStdinT = None,
305310
open_stdout: bool = True,
306311
open_stderr: bool = True,
307312
exception_class: "typing.Type[exceptions.CalledProcessError]" = exceptions.CalledProcessError,
@@ -341,9 +346,7 @@ async def check_call( # type: ignore
341346
342347
.. versionchanged:: 3.4.0 Expected is not optional, defaults os dependent
343348
"""
344-
expected_codes: typing.Tuple[typing.Union[int, proc_enums.ExitCodes], ...] = proc_enums.exit_codes_to_enums(
345-
expected
346-
)
349+
expected_codes: typing.Sequence[_ExitCodeT] = proc_enums.exit_codes_to_enums(expected)
347350
result: exec_result.ExecResult = await self.execute(
348351
command,
349352
verbose=verbose,
@@ -369,13 +372,13 @@ async def check_stderr( # type: ignore
369372
self,
370373
command: str,
371374
verbose: bool = False,
372-
timeout: typing.Union[int, float, None] = constants.DEFAULT_TIMEOUT,
375+
timeout: _OptionalTimeoutT = constants.DEFAULT_TIMEOUT,
373376
error_info: typing.Optional[str] = None,
374377
raise_on_err: bool = True,
375378
*,
376-
expected: typing.Iterable[typing.Union[int, proc_enums.ExitCodes]] = (proc_enums.EXPECTED,),
379+
expected: typing.Iterable[_ExitCodeT] = (proc_enums.EXPECTED,),
377380
log_mask_re: typing.Optional[str] = None,
378-
stdin: typing.Union[bytes, str, bytearray, None] = None,
381+
stdin: _OptionalStdinT = None,
379382
open_stdout: bool = True,
380383
open_stderr: bool = True,
381384
exception_class: "typing.Type[exceptions.CalledProcessError]" = exceptions.CalledProcessError,

0 commit comments

Comments
 (0)