Skip to content

Commit 56ef8e9

Browse files
author
Dan
committed
Made output decoding codec configurable at run_command, added test
1 parent 55b9775 commit 56ef8e9

File tree

4 files changed

+52
-23
lines changed

4 files changed

+52
-23
lines changed

embedded_server/embedded_server.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ class Server(paramiko.ServerInterface):
103103
"""
104104

105105
def __init__(self, transport, host_key, fail_auth=False,
106-
ssh_exception=False):
106+
ssh_exception=False,
107+
encoding='utf-8'):
107108
paramiko.ServerInterface.__init__(self)
108109
transport.load_server_moduli()
109110
transport.add_server_key(host_key)
@@ -113,6 +114,7 @@ def __init__(self, transport, host_key, fail_auth=False,
113114
self.fail_auth = fail_auth
114115
self.ssh_exception = ssh_exception
115116
self.host_key = host_key
117+
self.encoding = encoding
116118

117119
def check_channel_request(self, kind, chanid):
118120
return paramiko.OPEN_SUCCEEDED
@@ -163,12 +165,11 @@ def check_channel_forward_agent_request(self, channel):
163165
gevent.sleep()
164166
return True
165167

166-
def check_channel_exec_request(self, channel, cmd,
167-
encoding='utf-8'):
168+
def check_channel_exec_request(self, channel, cmd):
168169
logger.debug("Got exec request on channel %s for cmd %s" % (channel, cmd,))
169170
self.event.set()
170171
_env = os.environ
171-
_env['PYTHONIOENCODING'] = encoding
172+
_env['PYTHONIOENCODING'] = self.encoding
172173
if hasattr(channel, 'environment'):
173174
_env.update(channel.environment)
174175
process = gevent.subprocess.Popen(cmd, stdout=gevent.subprocess.PIPE,
@@ -182,7 +183,8 @@ def check_channel_exec_request(self, channel, cmd,
182183
def check_channel_env_request(self, channel, name, value):
183184
if not hasattr(channel, 'environment'):
184185
channel.environment = {}
185-
channel.environment.update({name.decode('utf8'): value.decode('utf8')})
186+
channel.environment.update({
187+
name.decode(self.encoding): value.decode(self.encoding)})
186188
return True
187189

188190
def _read_response(self, channel, process):
@@ -213,7 +215,8 @@ def make_socket(listen_ip, port=0):
213215
return sock
214216

215217
def listen(sock, fail_auth=False, ssh_exception=False,
216-
timeout=None):
218+
timeout=None,
219+
encoding='utf-8'):
217220
"""Run server and given a cmd_to_run, send given
218221
response to client connection. Returns (server, socket) tuple
219222
where server is a joinable server thread and socket is listening
@@ -228,13 +231,16 @@ def listen(sock, fail_auth=False, ssh_exception=False,
228231
return
229232
host, port = sock.getsockname()
230233
logger.info('Listening for connection on %s:%s..', host, port)
231-
return handle_ssh_connection(sock, fail_auth=fail_auth,
232-
timeout=timeout, ssh_exception=ssh_exception)
234+
return handle_ssh_connection(
235+
sock, fail_auth=fail_auth, timeout=timeout, ssh_exception=ssh_exception,
236+
encoding=encoding)
233237

234238
def _handle_ssh_connection(transport, fail_auth=False,
235-
ssh_exception=False):
239+
ssh_exception=False,
240+
encoding='utf-8'):
236241
server = Server(transport, host_key,
237-
fail_auth=fail_auth, ssh_exception=ssh_exception)
242+
fail_auth=fail_auth, ssh_exception=ssh_exception,
243+
encoding=encoding)
238244
try:
239245
transport.start_server(server=server)
240246
except paramiko.SSHException as e:
@@ -261,7 +267,8 @@ def _handle_ssh_connection(transport, fail_auth=False,
261267

262268
def handle_ssh_connection(sock,
263269
fail_auth=False, ssh_exception=False,
264-
timeout=None):
270+
timeout=None,
271+
encoding='utf-8'):
265272
conn, addr = sock.accept()
266273
logger.info('Got connection..')
267274
if timeout:
@@ -271,7 +278,8 @@ def handle_ssh_connection(sock,
271278
try:
272279
transport = paramiko.Transport(conn)
273280
return _handle_ssh_connection(transport, fail_auth=fail_auth,
274-
ssh_exception=ssh_exception)
281+
ssh_exception=ssh_exception,
282+
encoding=encoding)
275283
except Exception as e:
276284
logger.error('*** Caught exception: %s: %s' % (str(e.__class__), str(e),))
277285
traceback.print_exc()
@@ -281,14 +289,18 @@ def handle_ssh_connection(sock,
281289
pass
282290

283291
def start_server(sock, fail_auth=False, ssh_exception=False,
284-
timeout=None):
292+
timeout=None,
293+
encoding='utf-8'):
285294
return gevent.spawn(listen, sock, fail_auth=fail_auth,
286-
timeout=timeout, ssh_exception=ssh_exception)
295+
timeout=timeout, ssh_exception=ssh_exception,
296+
encoding=encoding)
287297

288298
def start_server_from_ip(ip, port=0,
289299
fail_auth=False, ssh_exception=False,
290-
timeout=None):
300+
timeout=None,
301+
encoding='utf-8'):
291302
server_sock = make_socket(ip, port=port)
292303
server = start_server(server_sock, fail_auth=fail_auth,
293-
ssh_exception=ssh_exception, timeout=timeout)
304+
ssh_exception=ssh_exception, timeout=timeout,
305+
encoding=encoding)
294306
return server, server_sock.getsockname()[1]

pssh/pssh_client.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,8 @@ def __init__(self, hosts,
336336
self.channel_timeout = channel_timeout
337337

338338
def run_command(self, command, sudo=False, user=None, stop_on_errors=True,
339-
shell=None, use_shell=True, use_pty=True, host_args=None):
339+
shell=None, use_shell=True, use_pty=True, host_args=None,
340+
encoding='utf-8'):
340341
"""Run command on all hosts in parallel, honoring self.pool_size,
341342
and return output buffers.
342343
@@ -383,6 +384,9 @@ def run_command(self, command, sudo=False, user=None, stop_on_errors=True,
383384
host list - :py:class:`pssh.exceptions.HostArgumentException` is raised \
384385
otherwise
385386
:type host_args: tuple or list
387+
:param encoding: Encoding to use for output. Must be valid
388+
`Python codec <https://docs.python.org/2.7/library/codecs.html>`_
389+
:type encoding: str
386390
387391
:rtype: Dictionary with host as key and \
388392
:py:class:`pssh.output.HostOutput` as value as per \
@@ -639,7 +643,7 @@ def run_command(self, command, sudo=False, user=None, stop_on_errors=True,
639643
for host in self.hosts]
640644
for cmd in cmds:
641645
try:
642-
self.get_output(cmd, output)
646+
self.get_output(cmd, output, encoding=encoding)
643647
except Exception:
644648
if stop_on_errors:
645649
raise
@@ -676,7 +680,7 @@ def _exec_command(self, host, command, sudo=False, user=None,
676680
command, sudo=sudo, user=user, shell=shell,
677681
use_shell=use_shell, use_pty=use_pty)
678682

679-
def get_output(self, cmd, output):
683+
def get_output(self, cmd, output, encoding='utf-8'):
680684
"""Get output from command.
681685
682686
:param cmd: Command to get output from
@@ -733,10 +737,12 @@ def get_output(self, cmd, output):
733737
raise
734738
stdout = self.host_clients[host].read_output_buffer(
735739
stdout, callback=self.get_exit_codes,
736-
callback_args=(output,))
740+
callback_args=(output,),
741+
encoding=encoding)
737742
stderr = self.host_clients[host].read_output_buffer(
738743
stderr, prefix='\t[err]', callback=self.get_exit_codes,
739-
callback_args=(output,))
744+
callback_args=(output,),
745+
encoding=encoding)
740746
self._update_host_output(output, host, self._get_exit_code(channel),
741747
channel, stdout, stderr, stdin, cmd)
742748

pssh/ssh_client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,8 @@ def exec_command(self, command, sudo=False, user=None,
262262

263263
def read_output_buffer(self, output_buffer, prefix='',
264264
callback=None,
265-
callback_args=None):
265+
callback_args=None,
266+
encoding='utf-8'):
266267
"""Read from output buffers and log to host_logger
267268
268269
:param output_buffer: Iterator containing buffer
@@ -274,7 +275,7 @@ def read_output_buffer(self, output_buffer, prefix='',
274275
:param callback_args: Arguments for call back function
275276
:type callback_args: tuple"""
276277
for line in output_buffer:
277-
output = line.strip().decode('utf8')
278+
output = line.strip().decode(encoding)
278279
host_logger.info("[%s]%s\t%s", self.host, prefix, output,)
279280
yield output
280281
if callback:

tests/test_pssh_client.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -947,6 +947,16 @@ def test_ssh_client_utf_encoding(self):
947947
self.assertEqual(expected, stdout,
948948
msg="Got unexpected unicode output %s - expected %s" % (
949949
stdout, expected,))
950+
utf16_server, server_port = start_server_from_ip(
951+
self.host, encoding='utf-16')
952+
client = ParallelSSHClient([self.host], port=server_port,
953+
pkey=self.user_key)
954+
# File is already set to utf-8, cannot use utf-16 only representations
955+
# Using ascii characters encoded as utf-16 instead
956+
output = client.run_command(self.fake_cmd, encoding='utf-16')
957+
stdout = list(output[self.host]['stdout'])
958+
# import ipdb; ipdb.set_trace()
959+
self.assertEqual([self.fake_resp.decode('utf-16')], stdout)
950960

951961
def test_pty(self):
952962
cmd = "exit 0"

0 commit comments

Comments
 (0)