Skip to content

Commit a131498

Browse files
committed
Fix ssh stdin processing: do not add newline, support on proxy
(cherry picked from commit 0786fc9) Signed-off-by: Alexey Stepanov <penguinolog@gmail.com>
1 parent 9eefbf6 commit a131498

File tree

7 files changed

+67
-24
lines changed

7 files changed

+67
-24
lines changed

exec_helpers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
"ExecResult",
5252
)
5353

54-
__version__ = "1.9.2"
54+
__version__ = "1.9.3"
5555
__author__ = "Alexey Stepanov"
5656
__author_email__ = "penguinolog@gmail.com"
5757
__maintainers__ = {

exec_helpers/_ssh_client_base.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -598,14 +598,9 @@ def execute_async(
598598

599599
if stdin is not None:
600600
if not _stdin.channel.closed:
601-
if isinstance(stdin, bytes):
602-
stdin_str = stdin.decode("utf-8")
603-
elif isinstance(stdin, bytearray):
604-
stdin_str = bytes(stdin).decode("utf-8")
605-
else:
606-
stdin_str = stdin
607-
608-
_stdin.write("{stdin}\n".format(stdin=stdin_str).encode("utf-8"))
601+
stdin_str = self._string_bytes_bytearray_as_bytes(stdin)
602+
603+
_stdin.write(stdin_str)
609604
_stdin.flush()
610605
else:
611606
self.logger.warning("STDIN Send failed: closed channel")
@@ -752,12 +747,23 @@ def execute_through_host(
752747
)
753748

754749
# Make proxy objects for read
755-
stdout = channel.makefile("rb")
756-
stderr = channel.makefile_stderr("rb")
750+
_stdin = channel.makefile("wb") # type: paramiko.ChannelFile
751+
stdout = channel.makefile("rb") # type: paramiko.ChannelFile
752+
stderr = channel.makefile_stderr("rb") # type: paramiko.ChannelFile
757753

758754
channel.exec_command(command) # nosec # Sanitize on caller side
759755

760-
async_result = SshExecuteAsyncResult(interface=channel, stdin=None, stdout=stdout, stderr=stderr)
756+
stdin = kwargs.get("stdin", None)
757+
if stdin is not None:
758+
if not _stdin.channel.closed:
759+
stdin_str = self._string_bytes_bytearray_as_bytes(stdin)
760+
761+
_stdin.write(stdin_str)
762+
_stdin.flush()
763+
else:
764+
self.logger.warning("STDIN Send failed: closed channel")
765+
766+
async_result = SshExecuteAsyncResult(interface=channel, stdin=_stdin, stdout=stdout, stderr=stderr)
761767

762768
# noinspection PyDictCreation
763769
result = self._exec_command(
@@ -766,6 +772,7 @@ def execute_through_host(
766772
timeout=timeout,
767773
verbose=verbose,
768774
log_mask_re=kwargs.get("log_mask_re", None),
775+
stdin=stdin,
769776
)
770777

771778
intermediate_channel.close()

exec_helpers/api.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def __init__(self, logger, log_mask_re=None): # type: (logging.Logger, typing.O
6363
:type log_mask_re: typing.Optional[str]
6464
6565
.. versionchanged:: 1.2.0 log_mask_re regex rule for masking cmd
66-
.. versionchanged:: 1.3.5 make API public paramikoto use as interface
66+
.. versionchanged:: 1.3.5 make API public to use as interface
6767
"""
6868
self.__lock = threading.RLock()
6969
self.__logger = logger
@@ -281,18 +281,18 @@ def check_call(
281281
282282
.. versionchanged:: 1.2.0 default timeout 1 hour
283283
"""
284-
expected = proc_enums.exit_codes_to_enums(expected)
284+
expected_codes = proc_enums.exit_codes_to_enums(expected)
285285
ret = self.execute(command, verbose, timeout, **kwargs)
286-
if ret.exit_code not in expected:
286+
if ret.exit_code not in expected_codes:
287287
message = (
288288
"{append}Command {result.cmd!r} returned exit code "
289289
"{result.exit_code!s} while expected {expected!s}".format(
290-
append=error_info + "\n" if error_info else "", result=ret, expected=expected
290+
append=error_info + "\n" if error_info else "", result=ret, expected=expected_codes
291291
)
292292
)
293293
self.logger.error(msg=message)
294294
if raise_on_err:
295-
raise exceptions.CalledProcessError(result=ret, expected=expected)
295+
raise exceptions.CalledProcessError(result=ret, expected=expected_codes)
296296
return ret
297297

298298
def check_stderr(
@@ -337,3 +337,21 @@ def check_stderr(
337337
if raise_on_err:
338338
raise exceptions.CalledProcessError(result=ret, expected=kwargs.get("expected"))
339339
return ret
340+
341+
@staticmethod
342+
def _string_bytes_bytearray_as_bytes(src): # type: (typing.Union[six.text_type, bytes, bytearray]) -> bytes
343+
"""Get bytes string from string/bytes/bytearray union.
344+
345+
:return: Byte string
346+
:rtype: bytes
347+
:raises TypeError: unexpected source type.
348+
"""
349+
if isinstance(src, bytes):
350+
return src
351+
if isinstance(src, bytearray):
352+
return bytes(src)
353+
if isinstance(src, six.text_type):
354+
return src.encode("utf-8")
355+
raise TypeError( # pragma: no cover
356+
"{!r} has unexpected type: not conform to Union[str, bytes, bytearray]".format(src)
357+
)

exec_helpers/exec_result.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,7 @@ def __deserialize(self, fmt): # type: (str) -> typing.Any
385385
try:
386386
if fmt == "json":
387387
return json.loads(self.stdout_str, encoding="utf-8")
388-
elif fmt == "yaml":
388+
if fmt == "yaml":
389389
return yaml.safe_load(self.stdout_str)
390390
except Exception:
391391
tmpl = " stdout is not valid {fmt}:\n" "{{stdout!r}}\n".format(fmt=fmt)

exec_helpers/subprocess_runner.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -225,12 +225,9 @@ def execute_async(
225225
if stdin is None:
226226
process_stdin = process.stdin
227227
else:
228-
if isinstance(stdin, six.text_type):
229-
stdin = stdin.encode(encoding="utf-8")
230-
elif isinstance(stdin, bytearray):
231-
stdin = bytes(stdin)
228+
stdin_str = self._string_bytes_bytearray_as_bytes(stdin)
232229
try:
233-
process.stdin.write(stdin)
230+
process.stdin.write(stdin_str)
234231
except OSError as exc:
235232
if exc.errno == errno.EINVAL:
236233
# bpo-19612, bpo-30418: On Windows, stdin.write() fails

test/test_ssh_client_execute.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def test_001_execute_async(ssh, paramiko_ssh_client, ssh_transport_channel, chan
342342
assert res.stdin.channel == res.interface
343343

344344
if stdin:
345-
res.stdin.write.assert_called_with("{stdin}\n".format(stdin=stdin).encode("utf-8"))
345+
res.stdin.write.assert_called_with(stdin.encode("utf-8"))
346346
res.stdin.flush.assert_called_once()
347347
log = get_logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port))
348348
log.log.assert_called_once_with(level=logging.DEBUG, msg=command_log)

test/test_ssh_client_execute_throw_host.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,24 @@ def test_03_execute_get_pty(ssh, ssh_transport_channel):
186186
target = "127.0.0.2"
187187
ssh.execute_through_host(target, command, get_pty=True)
188188
ssh_transport_channel.get_pty.assert_called_with(term="vt100", width=80, height=24, width_pixels=0, height_pixels=0)
189+
190+
191+
def test_04_execute_use_stdin(ssh, chan_makefile):
192+
target = "127.0.0.2"
193+
cmd = 'read line; echo "$line"'
194+
stdin = "test"
195+
res = ssh.execute_through_host(target, cmd, stdin=stdin, get_pty=True)
196+
assert res.stdin == stdin
197+
chan_makefile.stdin.write.assert_called_once_with(stdin.encode("utf-8"))
198+
chan_makefile.stdin.flush.assert_called_once()
199+
200+
201+
def test_05_execute_closed_stdin(ssh, ssh_transport_channel, get_logger):
202+
target = "127.0.0.2"
203+
cmd = 'read line; echo "$line"'
204+
stdin = "test"
205+
ssh_transport_channel.closed = True
206+
207+
ssh.execute_through_host(target, cmd, stdin=stdin, get_pty=True)
208+
log = get_logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port))
209+
log.warning.assert_called_once_with("STDIN Send failed: closed channel")

0 commit comments

Comments
 (0)