Skip to content

Commit 0786fc9

Browse files
committed
Fix ssh stdin processing: do not add newline, support on proxy
Signed-off-by: Alexey Stepanov <penguinolog@gmail.com>
1 parent dec689b commit 0786fc9

File tree

9 files changed

+73
-30
lines changed

9 files changed

+73
-30
lines changed

exec_helpers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
"async_api",
5252
)
5353

54-
__version__ = "3.1.2"
54+
__version__ = "3.1.3"
5555
__author__ = "Alexey Stepanov"
5656
__author_email__ = "penguinolog@gmail.com"
5757
__maintainers__ = {

exec_helpers/_ssh_client_base.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -618,14 +618,9 @@ def execute_async(
618618

619619
if stdin is not None:
620620
if not _stdin.channel.closed:
621-
if isinstance(stdin, bytes):
622-
stdin_str = stdin.decode("utf-8")
623-
elif isinstance(stdin, bytearray):
624-
stdin_str = bytes(stdin).decode("utf-8")
625-
else:
626-
stdin_str = stdin
627-
628-
_stdin.write("{stdin}\n".format(stdin=stdin_str).encode("utf-8"))
621+
stdin_str = self._string_bytes_bytearray_as_bytes(stdin)
622+
623+
_stdin.write(stdin_str)
629624
_stdin.flush()
630625
else:
631626
self.logger.warning("STDIN Send failed: closed channel")
@@ -773,12 +768,23 @@ def execute_through_host(
773768
)
774769

775770
# Make proxy objects for read
776-
stdout = channel.makefile("rb")
777-
stderr = channel.makefile_stderr("rb")
771+
_stdin = channel.makefile("wb") # type: paramiko.ChannelFile
772+
stdout = channel.makefile("rb") # type: paramiko.ChannelFile
773+
stderr = channel.makefile_stderr("rb") # type: paramiko.ChannelFile
778774

779775
channel.exec_command(command) # nosec # Sanitize on caller side
780776

781-
async_result = SshExecuteAsyncResult(interface=channel, stdin=None, stdout=stdout, stderr=stderr)
777+
stdin = kwargs.get("stdin", None)
778+
if stdin is not None:
779+
if not _stdin.channel.closed:
780+
stdin_str = self._string_bytes_bytearray_as_bytes(stdin)
781+
782+
_stdin.write(stdin_str)
783+
_stdin.flush()
784+
else:
785+
self.logger.warning("STDIN Send failed: closed channel")
786+
787+
async_result = SshExecuteAsyncResult(interface=channel, stdin=_stdin, stdout=stdout, stderr=stderr)
782788

783789
# noinspection PyDictCreation
784790
result = self._exec_command(
@@ -787,6 +793,7 @@ def execute_through_host(
787793
timeout=timeout,
788794
verbose=verbose,
789795
log_mask_re=kwargs.get("log_mask_re", None),
796+
stdin=stdin,
790797
)
791798

792799
intermediate_channel.close()

exec_helpers/api.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def __init__(self, logger: logging.Logger, log_mask_re: typing.Optional[str] = N
5959
:type log_mask_re: typing.Optional[str]
6060
6161
.. versionchanged:: 1.2.0 log_mask_re regex rule for masking cmd
62-
.. versionchanged:: 1.3.5 make API public paramikoto use as interface
62+
.. versionchanged:: 1.3.5 make API public to use as interface
6363
"""
6464
self.__lock = threading.RLock()
6565
self.__logger = logger
@@ -279,18 +279,18 @@ def check_call(
279279
280280
.. versionchanged:: 1.2.0 default timeout 1 hour
281281
"""
282-
expected = proc_enums.exit_codes_to_enums(expected)
282+
expected_codes = proc_enums.exit_codes_to_enums(expected)
283283
ret = self.execute(command, verbose, timeout, **kwargs)
284-
if ret.exit_code not in expected:
284+
if ret.exit_code not in expected_codes:
285285
message = (
286286
"{append}Command {result.cmd!r} returned exit code "
287287
"{result.exit_code!s} while expected {expected!s}".format(
288-
append=error_info + "\n" if error_info else "", result=ret, expected=expected
288+
append=error_info + "\n" if error_info else "", result=ret, expected=expected_codes
289289
)
290290
)
291291
self.logger.error(msg=message)
292292
if raise_on_err:
293-
raise exceptions.CalledProcessError(result=ret, expected=expected)
293+
raise exceptions.CalledProcessError(result=ret, expected=expected_codes)
294294
return ret
295295

296296
def check_stderr(
@@ -335,3 +335,21 @@ def check_stderr(
335335
if raise_on_err:
336336
raise exceptions.CalledProcessError(result=ret, expected=kwargs.get("expected"))
337337
return ret
338+
339+
@staticmethod
340+
def _string_bytes_bytearray_as_bytes(src: typing.Union[str, bytes, bytearray]) -> bytes:
341+
"""Get bytes string from string/bytes/bytearray union.
342+
343+
:return: Byte string
344+
:rtype: bytes
345+
:raises TypeError: unexpected source type.
346+
"""
347+
if isinstance(src, bytes):
348+
return src
349+
if isinstance(src, bytearray):
350+
return bytes(src)
351+
if isinstance(src, str):
352+
return src.encode("utf-8")
353+
raise TypeError( # pragma: no cover
354+
"{!r} has unexpected type: not conform to Union[str, bytes, bytearray]".format(src)
355+
)

exec_helpers/async_api/api.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,18 +198,18 @@ async def check_call( # type: ignore
198198
:raises ExecHelperTimeoutError: Timeout exceeded
199199
:raises CalledProcessError: Unexpected exit code
200200
"""
201-
expected = proc_enums.exit_codes_to_enums(expected)
201+
expected_codes = proc_enums.exit_codes_to_enums(expected)
202202
ret = await self.execute(command, verbose, timeout, **kwargs)
203-
if ret.exit_code not in expected:
203+
if ret.exit_code not in expected_codes:
204204
message = (
205205
"{append}Command {result.cmd!r} returned exit code "
206206
"{result.exit_code!s} while expected {expected!s}".format(
207-
append=error_info + "\n" if error_info else "", result=ret, expected=expected
207+
append=error_info + "\n" if error_info else "", result=ret, expected=expected_codes
208208
)
209209
)
210210
self.logger.error(msg=message)
211211
if raise_on_err:
212-
raise exceptions.CalledProcessError(result=ret, expected=expected)
212+
raise exceptions.CalledProcessError(result=ret, expected=expected_codes)
213213
return ret
214214

215215
async def check_stderr( # type: ignore

exec_helpers/async_api/subprocess_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
logger = logging.getLogger(__name__) # type: logging.Logger
3636

3737

38-
# noinspection PyTypeHints
38+
# noinspection PyTypeHints,PyTypeChecker
3939
class SubprocessExecuteAsyncResult(subprocess_runner.SubprocessExecuteAsyncResult):
4040
"""Override original NamedTuple with proper typing."""
4141

exec_helpers/exec_result.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -376,9 +376,9 @@ def __deserialize(self, fmt: str) -> typing.Any:
376376
:raises DeserializeValueError: Not valid source format
377377
"""
378378
try:
379-
if fmt == "json": # pylint: disable=no-else-return
379+
if fmt == "json":
380380
return json.loads(self.stdout_str, encoding="utf-8")
381-
elif fmt == "yaml":
381+
if fmt == "yaml":
382382
return yaml.safe_load(self.stdout_str)
383383
except Exception as e:
384384
tmpl = "{{self.cmd}} stdout is not valid {fmt}:\n" "{{stdout!r}}\n".format(fmt=fmt)

exec_helpers/subprocess_runner.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -236,12 +236,9 @@ def execute_async(
236236
if stdin is None:
237237
process_stdin = process.stdin
238238
else:
239-
if isinstance(stdin, str):
240-
stdin = stdin.encode(encoding="utf-8")
241-
elif isinstance(stdin, bytearray):
242-
stdin = bytes(stdin)
239+
stdin_str = self._string_bytes_bytearray_as_bytes(stdin)
243240
try:
244-
process.stdin.write(stdin)
241+
process.stdin.write(stdin_str)
245242
except OSError as exc:
246243
if exc.errno == errno.EINVAL:
247244
# bpo-19612, bpo-30418: On Windows, stdin.write() fails

test/test_ssh_client_execute.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ def test_001_execute_async(ssh, paramiko_ssh_client, ssh_transport_channel, chan
339339
assert res.stdin.channel == res.interface
340340

341341
if stdin:
342-
res.stdin.write.assert_called_with("{stdin}\n".format(stdin=stdin).encode("utf-8"))
342+
res.stdin.write.assert_called_with(stdin.encode("utf-8"))
343343
res.stdin.flush.assert_called_once()
344344
log = get_logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port))
345345
log.log.assert_called_once_with(level=logging.DEBUG, msg=command_log)

test/test_ssh_client_execute_throw_host.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,24 @@ def test_03_execute_get_pty(ssh, ssh_transport_channel) -> None:
181181
target = "127.0.0.2"
182182
ssh.execute_through_host(target, command, get_pty=True)
183183
ssh_transport_channel.get_pty.assert_called_with(term="vt100", width=80, height=24, width_pixels=0, height_pixels=0)
184+
185+
186+
def test_04_execute_use_stdin(ssh, chan_makefile) -> None:
187+
target = "127.0.0.2"
188+
cmd = 'read line; echo "$line"'
189+
stdin = "test"
190+
res = ssh.execute_through_host(target, cmd, stdin=stdin, get_pty=True)
191+
assert res.stdin == stdin
192+
chan_makefile.stdin.write.assert_called_once_with(stdin.encode("utf-8"))
193+
chan_makefile.stdin.flush.assert_called_once()
194+
195+
196+
def test_05_execute_closed_stdin(ssh, ssh_transport_channel, get_logger) -> None:
197+
target = "127.0.0.2"
198+
cmd = 'read line; echo "$line"'
199+
stdin = "test"
200+
ssh_transport_channel.closed = True
201+
202+
ssh.execute_through_host(target, cmd, stdin=stdin, get_pty=True)
203+
log = get_logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port))
204+
log.warning.assert_called_once_with("STDIN Send failed: closed channel")

0 commit comments

Comments
 (0)