Skip to content

Commit 19ac84b

Browse files
author
Dan
committed
Updated run_command to use named arguments. Added environment keyword argument to run_command and unittest. Added unittests for run_command parameters. Updated docstrings
1 parent 46d208d commit 19ac84b

File tree

3 files changed

+67
-19
lines changed

3 files changed

+67
-19
lines changed

pssh/pssh_client.py

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,9 @@ def __init__(self, hosts,
335335
self.host_config = host_config if host_config else {}
336336
self.channel_timeout = channel_timeout
337337

338-
def run_command(self, *args, **kwargs):
338+
def run_command(self, command, sudo=False, user=None, stop_on_errors=True,
339+
shell=None, use_shell=True, use_pty=True, host_args=None,
340+
environment=None):
339341
"""Run command on all hosts in parallel, honoring self.pool_size,
340342
and return output buffers.
341343
@@ -351,8 +353,8 @@ def run_command(self, *args, **kwargs):
351353
``stop_on_errors=False`` in which case exceptions are added to host
352354
output instead.
353355
354-
:param args: Positional arguments for command
355-
:type args: tuple
356+
:param command: Command to run
357+
:type command: str
356358
:param sudo: (Optional) Run with sudo. Defaults to False
357359
:type sudo: bool
358360
:param user: (Optional) User to run command as. Requires sudo access \
@@ -382,6 +384,11 @@ def run_command(self, *args, **kwargs):
382384
host list - :py:class:`pssh.exceptions.HostArgumentException` is raised \
383385
otherwise
384386
:type host_args: tuple or list
387+
:param environment: (Optional) Environment variables to be exposed to \
388+
command to be run. This requires that ``AcceptEnv`` server setting \
389+
is enabled - variables will silently not be set otherwise
390+
:type environment: dict
391+
385392
:rtype: Dictionary with host as key and \
386393
:py:class:`pssh.output.HostOutput` as value as per \
387394
:py:func:`pssh.pssh_client.ParallelSSHClient.get_output`
@@ -617,22 +624,25 @@ def run_command(self, *args, **kwargs):
617624
writing to stdin
618625
619626
"""
620-
stop_on_errors = kwargs.pop('stop_on_errors', True)
621-
host_args = kwargs.pop('host_args', None)
622627
output = {}
623628
if host_args:
624629
try:
625630
cmds = [self.pool.spawn(self._exec_command, host,
626-
args[0] % host_args[host_i],
627-
*args[1:], **kwargs)
631+
command % host_args[host_i],
632+
sudo=sudo, user=user, shell=shell,
633+
use_shell=use_shell, use_pty=use_pty,
634+
environment=environment)
628635
for host_i, host in enumerate(self.hosts)]
629636
except IndexError:
630637
raise HostArgumentException(
631638
"Number of host arguments provided does not match "
632639
"number of hosts ")
633640
else:
634641
cmds = [self.pool.spawn(
635-
self._exec_command, host, *args, **kwargs)
642+
self._exec_command, host, command,
643+
sudo=sudo, user=user, shell=shell,
644+
use_shell=use_shell, use_pty=use_pty,
645+
environment=environment)
636646
for host in self.hosts]
637647
for cmd in cmds:
638648
try:
@@ -648,11 +658,14 @@ def _get_host_config_values(self, host):
648658
_password = self.host_config.get(host, {}).get('password', self.password)
649659
_pkey = self.host_config.get(host, {}).get('private_key', self.pkey)
650660
return _user, _port, _password, _pkey
651-
652-
def _exec_command(self, host, *args, **kwargs):
661+
662+
def _exec_command(self, host, command, sudo=False, user=None,
663+
shell=None, use_shell=True, use_pty=True,
664+
environment=None):
653665
"""Make SSHClient, run command on host"""
654666
if not host in self.host_clients or not self.host_clients[host]:
655667
_user, _port, _password, _pkey = self._get_host_config_values(host)
668+
_user = user if user else _user
656669
self.host_clients[host] = SSHClient(host, user=_user,
657670
password=_password,
658671
port=_port, pkey=_pkey,
@@ -667,21 +680,23 @@ def _exec_command(self, host, *args, **kwargs):
667680
allow_agent=self.allow_agent,
668681
agent=self.agent,
669682
channel_timeout=self.channel_timeout)
670-
return self.host_clients[host].exec_command(*args, **kwargs)
683+
return self.host_clients[host].exec_command(
684+
command, sudo=sudo, user=user, shell=shell,
685+
use_shell=use_shell, use_pty=use_pty, environment=environment)
671686

672687
def get_output(self, cmd, output):
673688
"""Get output from command.
674-
689+
675690
:param cmd: Command to get output from
676691
:type cmd: :py:class:`gevent.Greenlet`
677692
:param output: Dictionary containing \
678693
:py:class:`pssh.output.HostOutput` values to be updated with output \
679694
from cmd
680695
:type output: dict
681696
:rtype: None
682-
697+
683698
`output` parameter is modified in-place and has the following structure
684-
699+
685700
::
686701
687702
{'myhost1':
@@ -705,7 +720,8 @@ def get_output(self, cmd, output):
705720
for line in output[host].stdout:
706721
print(line)
707722
<stdout>
708-
# Get exit code after command has finished
723+
# Get exit code for a particular host's output after command
724+
# has finished
709725
self.get_exit_code(output[host])
710726
0
711727
@@ -763,7 +779,7 @@ def finished(self, output):
763779
:rtype: bool
764780
"""
765781
for host in output:
766-
chan = output[host]['channel']
782+
chan = output[host].channel
767783
if chan is not None and not chan.closed:
768784
return False
769785
return True
@@ -787,7 +803,7 @@ def get_exit_code(self, host_output):
787803
if not 'channel' in host_output:
788804
logger.error("%s does not look like host output..", host_output,)
789805
return
790-
channel = host_output['channel']
806+
channel = host_output.channel
791807
return self._get_exit_code(channel)
792808

793809
def _get_exit_code(self, channel):

pssh/ssh_client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def _connect(self, client, host, port, sock=None, retries=1,
203203
def exec_command(self, command, sudo=False, user=None,
204204
shell=None,
205205
use_shell=True, use_pty=True,
206-
**kwargs):
206+
environment=None):
207207
"""Wrapper to :py:func:`paramiko.SSHClient.exec_command`
208208
209209
Opens a new SSH session with a new pty and runs command before yielding
@@ -241,6 +241,8 @@ def exec_command(self, command, sudo=False, user=None,
241241
channel.get_pty()
242242
if self.channel_timeout:
243243
channel.settimeout(self.channel_timeout)
244+
if environment:
245+
channel.update_environment(environment)
244246
stdout, stderr, stdin = channel.makefile('rb'), channel.makefile_stderr('rb'), \
245247
channel.makefile('wb')
246248
for _char in ['\\', '"', '$', '`']:
@@ -256,7 +258,7 @@ def exec_command(self, command, sudo=False, user=None,
256258
else:
257259
_command += '"%s"' % (command,)
258260
logger.debug("Running parsed command %s on %s", _command, self.host)
259-
channel.exec_command(_command, **kwargs)
261+
channel.exec_command(_command)
260262
logger.debug("Command started")
261263
sleep(0)
262264
return channel, self.host, stdout, stderr, stdin

tests/test_pssh_client.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -976,5 +976,35 @@ def test_output_attributes(self):
976976
self.assertTrue(hasattr(output[self.host], 'exception'))
977977
self.assertTrue(hasattr(output[self.host], 'exit_code'))
978978

979+
def test_run_command_user_sudo(self):
980+
user = 'cmd_user'
981+
output = self.client.run_command(self.fake_cmd, user=user)
982+
self.client.join(output)
983+
stderr = list(output[self.host].stderr)
984+
self.assertTrue(len(stderr) > 0)
985+
self.assertTrue(user in stderr[0])
986+
987+
def test_run_command_shell(self):
988+
output = self.client.run_command(self.fake_cmd, shell="bash -c")
989+
self.client.join(output)
990+
stdout = list(output[self.host].stdout)
991+
self.assertEqual(stdout, [self.fake_resp])
992+
993+
def test_run_command_no_shell(self):
994+
output = self.client.run_command('id', use_shell=False)
995+
self.client.join(output)
996+
stdout = list(output[self.host].stdout)
997+
self.assertTrue(len(stdout) > 0)
998+
self.assertTrue(output[self.host].exit_code == 0)
999+
1000+
def test_run_command_environment(self):
1001+
env = {'ENV_VARIABLE': 'env value'}
1002+
output = self.client.run_command('echo ${ENV_VARIABLE}',
1003+
environment=env)
1004+
self.client.join(output)
1005+
stdout = list(output[self.host].stdout)
1006+
expected = [env.values()[0]]
1007+
self.assertEqual(stdout, expected)
1008+
9791009
if __name__ == '__main__':
9801010
unittest.main()

0 commit comments

Comments
 (0)