|
22 | 22 | import unittest |
23 | 23 | from pssh import ParallelSSHClient, UnknownHostException, \ |
24 | 24 | AuthenticationException, ConnectionErrorException, SSHException, logger as pssh_logger |
| 25 | +from pssh.utils import load_private_key |
25 | 26 | from embedded_server.embedded_server import start_server, make_socket, \ |
26 | 27 | logger as server_logger, paramiko_logger |
27 | 28 | from embedded_server.fake_agent import FakeAgent |
|
34 | 35 | import warnings |
35 | 36 | import shutil |
36 | 37 |
|
37 | | -USER_KEY = paramiko.RSAKey.from_private_key_file( |
38 | | - os.path.sep.join([os.path.dirname(__file__), 'test_client_private_key'])) |
| 38 | +PKEY_FILENAME = os.path.sep.join([os.path.dirname(__file__), 'test_client_private_key']) |
| 39 | +USER_KEY = paramiko.RSAKey.from_private_key_file(PKEY_FILENAME) |
39 | 40 |
|
40 | 41 | server_logger.setLevel(logging.DEBUG) |
41 | 42 | pssh_logger.setLevel(logging.DEBUG) |
@@ -648,3 +649,44 @@ def test_escaped_quotes(self): |
648 | 649 | self.assertEqual(expected, stdout, |
649 | 650 | msg="Got unexpected output. Expected %s, got %s" % ( |
650 | 651 | expected, stdout,)) |
| 652 | + |
| 653 | + def test_host_config(self): |
| 654 | + """Test per-host configuration functionality of ParallelSSHClient""" |
| 655 | + hosts = ['127.0.0.%01d' % n for n in xrange(1,3)] |
| 656 | + host_config = dict.fromkeys(hosts) |
| 657 | + servers = [] |
| 658 | + user = 'overriden_user' |
| 659 | + password = 'overriden_pass' |
| 660 | + for host in hosts: |
| 661 | + _socket = make_socket(host) |
| 662 | + port = _socket.getsockname()[1] |
| 663 | + host_config[host] = {} |
| 664 | + host_config[host]['port'] = port |
| 665 | + host_config[host]['user'] = user |
| 666 | + host_config[host]['password'] = password |
| 667 | + server = start_server(_socket, fail_auth=hosts.index(host)) |
| 668 | + servers.append((server, port)) |
| 669 | + pkey_data = load_private_key(PKEY_FILENAME) |
| 670 | + host_config[hosts[0]]['private_key'] = pkey_data |
| 671 | + client = ParallelSSHClient(hosts, host_config=host_config) |
| 672 | + output = client.run_command(self.fake_cmd, stop_on_errors=False) |
| 673 | + client.join(output) |
| 674 | + for host in hosts: |
| 675 | + self.assertTrue(host in output) |
| 676 | + try: |
| 677 | + raise output[hosts[1]]['exception'] |
| 678 | + except AuthenticationException, ex: |
| 679 | + pass |
| 680 | + else: |
| 681 | + raise AssertionError("Expected AutnenticationException on host %s", |
| 682 | + hosts[0]) |
| 683 | + self.assertFalse(output[hosts[1]]['exit_code'], |
| 684 | + msg="Execution failed on host %s" % (hosts[1],)) |
| 685 | + self.assertTrue(client.host_clients[hosts[0]].user == user, |
| 686 | + msg="Host config user override failed") |
| 687 | + self.assertTrue(client.host_clients[hosts[0]].password == password, |
| 688 | + msg="Host config password override failed") |
| 689 | + self.assertTrue(client.host_clients[hosts[0]].pkey == pkey_data, |
| 690 | + msg="Host config pkey override failed") |
| 691 | + for (server, _) in servers: |
| 692 | + server.kill() |
0 commit comments