1919
2020"""Unittests for :mod:`pssh.ParallelSSHClient` class"""
2121
22+
2223import unittest
23- from pssh import ParallelSSHClient , UnknownHostException , \
24- AuthenticationException , ConnectionErrorException , SSHException , logger as pssh_logger
25- from pssh .utils import load_private_key
26- from embedded_server .embedded_server import start_server , make_socket , \
27- logger as server_logger , paramiko_logger
28- from embedded_server .fake_agent import FakeAgent
2924import random
3025import logging
31- import gevent
32- import paramiko
3326import os
3427import warnings
3528import shutil
3629import sys
3730
31+ import gevent
32+ from pssh import ParallelSSHClient , UnknownHostException , \
33+ AuthenticationException , ConnectionErrorException , SSHException , \
34+ logger as pssh_logger
35+ from pssh .exceptions import HostArgumentException
36+ from pssh .utils import load_private_key
37+ from embedded_server .embedded_server import start_server , make_socket , \
38+ logger as server_logger , paramiko_logger
39+ from embedded_server .fake_agent import FakeAgent
40+ from paramiko import RSAKey
3841
3942PKEY_FILENAME = os .path .sep .join ([os .path .dirname (__file__ ), 'test_client_private_key' ])
40- USER_KEY = paramiko . RSAKey .from_private_key_file (PKEY_FILENAME )
43+ USER_KEY = RSAKey .from_private_key_file (PKEY_FILENAME )
4144
4245server_logger .setLevel (logging .DEBUG )
4346pssh_logger .setLevel (logging .DEBUG )
@@ -228,7 +231,7 @@ def test_pssh_client_ssh_exception(self):
228231 client = ParallelSSHClient ([self .host ],
229232 user = 'fakey' , password = 'fakey' ,
230233 port = listen_port ,
231- pkey = paramiko . RSAKey .generate (1024 ),
234+ pkey = RSAKey .generate (1024 ),
232235 )
233236 self .assertRaises (SSHException , client .run_command , self .fake_cmd )
234237 del client
@@ -710,7 +713,7 @@ def test_ssh_exception(self):
710713 hosts = [host ]
711714 client = ParallelSSHClient (hosts , port = port ,
712715 user = 'fakey' , password = 'fakey' ,
713- pkey = paramiko . RSAKey .generate (1024 ))
716+ pkey = RSAKey .generate (1024 ))
714717 output = client .run_command (self .fake_cmd , stop_on_errors = False )
715718 gevent .sleep (1 )
716719 client .pool .join ()
@@ -826,7 +829,6 @@ def test_pssh_client_override_allow_agent_authentication(self):
826829 expected_exit_code = 0
827830 expected_stdout = [self .fake_resp ]
828831 expected_stderr = []
829-
830832 stdout = list (output [self .host ]['stdout' ])
831833 stderr = list (output [self .host ]['stderr' ])
832834 exit_code = output [self .host ]['exit_code' ]
@@ -848,5 +850,69 @@ def test_get_exit_codes_bad_output(self):
848850 self .assertFalse (self .client .get_exit_codes ({}))
849851 self .assertFalse (self .client .get_exit_code ({}))
850852
853+ def test_per_host_tuple_args (self ):
854+ server2_socket = make_socket ('127.0.0.2' , port = self .listen_port )
855+ server2_port = server2_socket .getsockname ()[1 ]
856+ server2 = start_server (server2_socket )
857+ server3_socket = make_socket ('127.0.0.3' , port = self .listen_port )
858+ server3_port = server3_socket .getsockname ()[1 ]
859+ server3 = start_server (server3_socket )
860+ hosts = [self .host , '127.0.0.2' , '127.0.0.3' ]
861+ host_args = ('arg1' , 'arg2' , 'arg3' )
862+ cmd = 'echo %s'
863+ client = ParallelSSHClient (hosts , port = self .listen_port ,
864+ pkey = self .user_key )
865+ output = client .run_command (cmd , host_args = host_args )
866+ for i , host in enumerate (hosts ):
867+ expected = [host_args [i ]]
868+ stdout = list (output [host ]['stdout' ])
869+ self .assertEqual (expected , stdout )
870+ self .assertTrue (output [host ]['exit_code' ] == 0 )
871+ host_args = (('arg1' , 'arg2' ), ('arg3' , 'arg4' ), ('arg5' , 'arg6' ),)
872+ cmd = 'echo %s %s'
873+ output = client .run_command (cmd , host_args = host_args )
874+ for i , host in enumerate (hosts ):
875+ expected = ["%s %s" % host_args [i ]]
876+ stdout = list (output [host ]['stdout' ])
877+ self .assertEqual (expected , stdout )
878+ self .assertTrue (output [host ]['exit_code' ] == 0 )
879+ self .assertRaises (HostArgumentException , client .run_command ,
880+ cmd , host_args = [host_args [0 ]])
881+
882+ def test_per_host_dict_args (self ):
883+ server2_socket = make_socket ('127.0.0.2' , port = self .listen_port )
884+ server2_port = server2_socket .getsockname ()[1 ]
885+ server2 = start_server (server2_socket )
886+ server3_socket = make_socket ('127.0.0.3' , port = self .listen_port )
887+ server3_port = server3_socket .getsockname ()[1 ]
888+ server3 = start_server (server3_socket )
889+ hosts = [self .host , '127.0.0.2' , '127.0.0.3' ]
890+ hosts_gen = (h for h in hosts )
891+ host_args = [dict (zip (('host_arg1' , 'host_arg2' ,),
892+ ('arg1-%s' % (i ,), 'arg2-%s' % (i ,),)))
893+ for i , _ in enumerate (hosts )]
894+ cmd = 'echo %(host_arg1)s %(host_arg2)s'
895+ client = ParallelSSHClient (hosts , port = self .listen_port ,
896+ pkey = self .user_key )
897+ output = client .run_command (cmd , host_args = host_args )
898+ for i , host in enumerate (hosts ):
899+ expected = ["%(host_arg1)s %(host_arg2)s" % host_args [i ]]
900+ stdout = list (output [host ]['stdout' ])
901+ self .assertEqual (expected , stdout )
902+ self .assertTrue (output [host ]['exit_code' ] == 0 )
903+ self .assertRaises (HostArgumentException , client .run_command ,
904+ cmd , host_args = [host_args [0 ]])
905+ # Host list generator should work also
906+ client .hosts = hosts_gen
907+ output = client .run_command (cmd , host_args = host_args )
908+ for i , host in enumerate (hosts ):
909+ expected = ["%(host_arg1)s %(host_arg2)s" % host_args [i ]]
910+ stdout = list (output [host ]['stdout' ])
911+ self .assertEqual (expected , stdout )
912+ self .assertTrue (output [host ]['exit_code' ] == 0 )
913+ client .hosts = (h for h in hosts )
914+ self .assertRaises (HostArgumentException , client .run_command ,
915+ cmd , host_args = [host_args [0 ]])
916+
851917if __name__ == '__main__' :
852918 unittest .main ()
0 commit comments