Skip to content

Commit 9eefbf6

Browse files
committed
Backport fixes from master branch
Fix ExecResult compare: allow subclassing. Re-raise auth errors ASAP. (#99) * Make tests faster. * Bump to 3.1.1 (cherry picked from commit 2a64bbf) Fix SSHClient: stdin processing, missed command in result for parallel (#100) Port part of tests to pytest: better use-cases coverage (cherry picked from commit dec689b)
1 parent 93dd7eb commit 9eefbf6

19 files changed

+2249
-2733
lines changed

exec_helpers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
"ExecResult",
5252
)
5353

54-
__version__ = "1.9.1"
54+
__version__ = "1.9.2"
5555
__author__ = "Alexey Stepanov"
5656
__author_email__ = "penguinolog@gmail.com"
5757
__maintainers__ = {

exec_helpers/_ssh_client_base.py

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
import platform
3131
import stat
3232
import sys
33-
import threading
3433
import time
3534
import typing # noqa: F401 # pylint: disable=unused-import
3635
import warnings
@@ -55,6 +54,18 @@
5554
logging.getLogger("iso8601").setLevel(logging.WARNING)
5655

5756

57+
class RetryOnExceptions(tenacity.retry_if_exception): # type: ignore
58+
"""Advanced retry on exceptions."""
59+
60+
def __init__(
61+
self,
62+
retry_on, # type: typing.Union[typing.Type[BaseException], typing.Tuple[typing.Type[BaseException], ...]]
63+
reraise, # type: typing.Union[typing.Type[BaseException], typing.Tuple[typing.Type[BaseException], ...]]
64+
): # type: (...) -> None
65+
"""Retry on exceptions, except several types."""
66+
super(RetryOnExceptions, self).__init__(lambda e: isinstance(e, retry_on) and not isinstance(e, reraise))
67+
68+
5869
class SshExecuteAsyncResult(api.ExecuteAsyncResult):
5970
"""Override original NamedTuple with proper typing."""
6071

@@ -356,7 +367,7 @@ def _ssh(self): # type: () -> paramiko.SSHClient
356367
return self.__ssh
357368

358369
@tenacity.retry( # type: ignore
359-
retry=tenacity.retry_if_exception_type(paramiko.SSHException),
370+
retry=RetryOnExceptions(retry_on=paramiko.SSHException, reraise=paramiko.AuthenticationException),
360371
stop=tenacity.stop_after_attempt(3),
361372
wait=tenacity.wait_fixed(3),
362373
reraise=True,
@@ -587,7 +598,14 @@ def execute_async(
587598

588599
if stdin is not None:
589600
if not _stdin.channel.closed:
590-
_stdin.write("{stdin}\n".format(stdin=stdin))
601+
if isinstance(stdin, bytes):
602+
stdin_str = stdin.decode("utf-8")
603+
elif isinstance(stdin, bytearray):
604+
stdin_str = bytes(stdin).decode("utf-8")
605+
else:
606+
stdin_str = stdin
607+
608+
_stdin.write("{stdin}\n".format(stdin=stdin_str).encode("utf-8"))
591609
_stdin.flush()
592610
else:
593611
self.logger.warning("STDIN Send failed: closed channel")
@@ -632,45 +650,37 @@ def poll_streams(): # type: () -> None
632650
result.read_stderr(src=async_result.stderr, log=self.logger, verbose=verbose)
633651

634652
@threaded.threadpooled
635-
def poll_pipes(stop,): # type: (threading.Event) -> None
636-
"""Polling task for FIFO buffers.
637-
638-
:type stop: Event
639-
"""
640-
while not stop.is_set():
653+
def poll_pipes(): # type: () -> None
654+
"""Polling task for FIFO buffers."""
655+
while not async_result.interface.status_event.is_set():
641656
time.sleep(0.1)
642657
if async_result.stdout or async_result.stderr:
643658
poll_streams()
644659

645-
if async_result.interface.status_event.is_set():
646-
result.read_stdout(src=async_result.stdout, log=self.logger, verbose=verbose)
647-
result.read_stderr(src=async_result.stderr, log=self.logger, verbose=verbose)
648-
result.exit_code = async_result.interface.exit_status
649-
650-
stop.set()
660+
result.read_stdout(src=async_result.stdout, log=self.logger, verbose=verbose)
661+
result.read_stderr(src=async_result.stderr, log=self.logger, verbose=verbose)
662+
result.exit_code = async_result.interface.exit_status
651663

652664
# channel.status_event.wait(timeout)
653665
cmd_for_log = self._mask_command(cmd=command, log_mask_re=log_mask_re)
654666

655667
# Store command with hidden data
656668
result = exec_result.ExecResult(cmd=cmd_for_log, stdin=kwargs.get("stdin"))
657669

658-
stop_event = threading.Event()
659-
660670
# pylint: disable=assignment-from-no-return
661671
# noinspection PyNoneFunctionAssignment
662-
future = poll_pipes(stop=stop_event) # type: concurrent.futures.Future
672+
future = poll_pipes() # type: concurrent.futures.Future
663673
# pylint: enable=assignment-from-no-return
664674

665675
concurrent.futures.wait([future], timeout)
666676

667677
# Process closed?
668-
if stop_event.is_set():
678+
if async_result.interface.status_event.is_set():
669679
async_result.interface.close()
670680
return result
671681

672-
stop_event.set()
673682
async_result.interface.close()
683+
async_result.interface.status_event.set()
674684
future.cancel()
675685

676686
wait_err_msg = _log_templates.CMD_WAIT_ERROR.format(result=result, timeout=timeout)
@@ -806,7 +816,7 @@ def get_result(remote): # type: (SSHClientBase) -> exec_result.ExecResult
806816
cmd_for_log = remote._mask_command(cmd=command, log_mask_re=kwargs.get("log_mask_re", None))
807817
# pylint: enable=protected-access
808818

809-
res = exec_result.ExecResult(cmd=cmd_for_log)
819+
res = exec_result.ExecResult(cmd=cmd_for_log, stdin=kwargs.get("stdin", None))
810820
res.read_stdout(src=async_result.stdout)
811821
res.read_stderr(src=async_result.stderr)
812822
res.exit_code = exit_code

exec_helpers/exec_result.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,17 @@ def __str__(self): # type: () -> str
468468

469469
def __eq__(self, other): # type: (typing.Any) -> bool
470470
"""Comparision."""
471-
return hash(self) == hash(other)
471+
return (
472+
self.__class__ is other.__class__
473+
or issubclass(self.__class__, other.__class__)
474+
or issubclass(other.__class__, self.__class__)
475+
) and (
476+
self.cmd == other.cmd
477+
and self.stdin == other.stdin
478+
and self.stdout == other.stdout
479+
and self.stderr == other.stderr
480+
and self.exit_code == other.exit_code
481+
)
472482

473483
def __ne__(self, other): # type: (typing.Any) -> bool
474484
"""Comparision."""

pytest.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
[pytest]
22
addopts = -vvv -s -p no:django -p no:ipdb
33
testpaths = test
4+
mock_use_standalone_module = true

setup.cfg

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ exclude =
3939
__init__.py,
4040
docs
4141
ignore =
42+
# line break before binary operator
43+
W503
4244
show-pep8 = True
4345
show-source = True
4446
count = True

0 commit comments

Comments
 (0)