|
30 | 30 | import platform |
31 | 31 | import stat |
32 | 32 | import sys |
33 | | -import threading |
34 | 33 | import time |
35 | 34 | import typing # noqa: F401 # pylint: disable=unused-import |
36 | 35 | import warnings |
|
55 | 54 | logging.getLogger("iso8601").setLevel(logging.WARNING) |
56 | 55 |
|
57 | 56 |
|
| 57 | +class RetryOnExceptions(tenacity.retry_if_exception): # type: ignore |
| 58 | + """Advanced retry on exceptions.""" |
| 59 | + |
| 60 | + def __init__( |
| 61 | + self, |
| 62 | + retry_on, # type: typing.Union[typing.Type[BaseException], typing.Tuple[typing.Type[BaseException], ...]] |
| 63 | + reraise, # type: typing.Union[typing.Type[BaseException], typing.Tuple[typing.Type[BaseException], ...]] |
| 64 | + ): # type: (...) -> None |
| 65 | + """Retry on exceptions, except several types.""" |
| 66 | + super(RetryOnExceptions, self).__init__(lambda e: isinstance(e, retry_on) and not isinstance(e, reraise)) |
| 67 | + |
| 68 | + |
58 | 69 | class SshExecuteAsyncResult(api.ExecuteAsyncResult): |
59 | 70 | """Override original NamedTuple with proper typing.""" |
60 | 71 |
|
@@ -356,7 +367,7 @@ def _ssh(self): # type: () -> paramiko.SSHClient |
356 | 367 | return self.__ssh |
357 | 368 |
|
358 | 369 | @tenacity.retry( # type: ignore |
359 | | - retry=tenacity.retry_if_exception_type(paramiko.SSHException), |
| 370 | + retry=RetryOnExceptions(retry_on=paramiko.SSHException, reraise=paramiko.AuthenticationException), |
360 | 371 | stop=tenacity.stop_after_attempt(3), |
361 | 372 | wait=tenacity.wait_fixed(3), |
362 | 373 | reraise=True, |
@@ -587,7 +598,14 @@ def execute_async( |
587 | 598 |
|
588 | 599 | if stdin is not None: |
589 | 600 | if not _stdin.channel.closed: |
590 | | - _stdin.write("{stdin}\n".format(stdin=stdin)) |
| 601 | + if isinstance(stdin, bytes): |
| 602 | + stdin_str = stdin.decode("utf-8") |
| 603 | + elif isinstance(stdin, bytearray): |
| 604 | + stdin_str = bytes(stdin).decode("utf-8") |
| 605 | + else: |
| 606 | + stdin_str = stdin |
| 607 | + |
| 608 | + _stdin.write("{stdin}\n".format(stdin=stdin_str).encode("utf-8")) |
591 | 609 | _stdin.flush() |
592 | 610 | else: |
593 | 611 | self.logger.warning("STDIN Send failed: closed channel") |
@@ -632,45 +650,37 @@ def poll_streams(): # type: () -> None |
632 | 650 | result.read_stderr(src=async_result.stderr, log=self.logger, verbose=verbose) |
633 | 651 |
|
634 | 652 | @threaded.threadpooled |
635 | | - def poll_pipes(stop,): # type: (threading.Event) -> None |
636 | | - """Polling task for FIFO buffers. |
637 | | -
|
638 | | - :type stop: Event |
639 | | - """ |
640 | | - while not stop.is_set(): |
| 653 | + def poll_pipes(): # type: () -> None |
| 654 | + """Polling task for FIFO buffers.""" |
| 655 | + while not async_result.interface.status_event.is_set(): |
641 | 656 | time.sleep(0.1) |
642 | 657 | if async_result.stdout or async_result.stderr: |
643 | 658 | poll_streams() |
644 | 659 |
|
645 | | - if async_result.interface.status_event.is_set(): |
646 | | - result.read_stdout(src=async_result.stdout, log=self.logger, verbose=verbose) |
647 | | - result.read_stderr(src=async_result.stderr, log=self.logger, verbose=verbose) |
648 | | - result.exit_code = async_result.interface.exit_status |
649 | | - |
650 | | - stop.set() |
| 660 | + result.read_stdout(src=async_result.stdout, log=self.logger, verbose=verbose) |
| 661 | + result.read_stderr(src=async_result.stderr, log=self.logger, verbose=verbose) |
| 662 | + result.exit_code = async_result.interface.exit_status |
651 | 663 |
|
652 | 664 | # channel.status_event.wait(timeout) |
653 | 665 | cmd_for_log = self._mask_command(cmd=command, log_mask_re=log_mask_re) |
654 | 666 |
|
655 | 667 | # Store command with hidden data |
656 | 668 | result = exec_result.ExecResult(cmd=cmd_for_log, stdin=kwargs.get("stdin")) |
657 | 669 |
|
658 | | - stop_event = threading.Event() |
659 | | - |
660 | 670 | # pylint: disable=assignment-from-no-return |
661 | 671 | # noinspection PyNoneFunctionAssignment |
662 | | - future = poll_pipes(stop=stop_event) # type: concurrent.futures.Future |
| 672 | + future = poll_pipes() # type: concurrent.futures.Future |
663 | 673 | # pylint: enable=assignment-from-no-return |
664 | 674 |
|
665 | 675 | concurrent.futures.wait([future], timeout) |
666 | 676 |
|
667 | 677 | # Process closed? |
668 | | - if stop_event.is_set(): |
| 678 | + if async_result.interface.status_event.is_set(): |
669 | 679 | async_result.interface.close() |
670 | 680 | return result |
671 | 681 |
|
672 | | - stop_event.set() |
673 | 682 | async_result.interface.close() |
| 683 | + async_result.interface.status_event.set() |
674 | 684 | future.cancel() |
675 | 685 |
|
676 | 686 | wait_err_msg = _log_templates.CMD_WAIT_ERROR.format(result=result, timeout=timeout) |
@@ -806,7 +816,7 @@ def get_result(remote): # type: (SSHClientBase) -> exec_result.ExecResult |
806 | 816 | cmd_for_log = remote._mask_command(cmd=command, log_mask_re=kwargs.get("log_mask_re", None)) |
807 | 817 | # pylint: enable=protected-access |
808 | 818 |
|
809 | | - res = exec_result.ExecResult(cmd=cmd_for_log) |
| 819 | + res = exec_result.ExecResult(cmd=cmd_for_log, stdin=kwargs.get("stdin", None)) |
810 | 820 | res.read_stdout(src=async_result.stdout) |
811 | 821 | res.read_stderr(src=async_result.stderr) |
812 | 822 | res.exit_code = exit_code |
|
0 commit comments