Skip to content

Commit 034f4ee

Browse files
committed
Added providing client private key as a PKey class to SSH clients. Added fail auth flag to fake server password authentication. Made fake server use gevent. Changed tests to use pre-generated test user private key
1 parent a12d0ca commit 034f4ee

File tree

3 files changed

+51
-29
lines changed

3 files changed

+51
-29
lines changed

fake_server/fake_server.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,12 @@
88
paramiko repository
99
"""
1010

11-
# import multiprocessing
12-
import threading
13-
# import gevent
14-
# from gevent import monkey
15-
# monkey.patch_all()
11+
import gevent
12+
from gevent import monkey
13+
monkey.patch_all()
1614
import os
17-
import socket
18-
# from gevent import socket
19-
from threading import Event
15+
from gevent import socket
16+
from gevent.event import Event
2017
import sys
2118
import traceback
2219
import logging
@@ -38,6 +35,7 @@ def check_channel_request(self, kind, chanid):
3835
return paramiko.OPEN_SUCCEEDED
3936

4037
def check_auth_password(self, username, password):
38+
if self.fail_auth: return paramiko.AUTH_FAILED
4139
return paramiko.AUTH_SUCCESSFUL
4240

4341
def check_auth_publickey(self, username, key):
@@ -152,11 +150,7 @@ def handle_ssh_connection(cmd_req_response, sock, fail_auth = False):
152150
return
153151

154152
def start_server(cmd_req_response, sock, fail_auth=False):
155-
t = threading.Thread(target=listen, args=(cmd_req_response, sock,),
156-
kwargs={'fail_auth' : fail_auth})
157-
t.start()
158-
return t
159-
# return gevent.spawn(listen, cmd_req_response, sock, fail_auth=fail_auth)
153+
return gevent.spawn(listen, cmd_req_response, sock, fail_auth=fail_auth)
160154

161155
if __name__ == "__main__":
162156
logging.basicConfig()

pssh.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ class SSHClient(object):
5252
overrides"""
5353

5454
def __init__(self, host,
55-
user=None, password=None, port=None):
55+
user=None, password=None, port=None,
56+
pkey=None):
5657
"""Connect to host honouring any user set configuration in ~/.ssh/config \
5758
or /etc/ssh/ssh_config
5859
@@ -87,6 +88,7 @@ def __init__(self, host,
8788
self.channel = None
8889
self.user = user
8990
self.password = password
91+
self.pkey = pkey
9092
self.port = port if port else 22
9193
self.host = resolved_address
9294
self._connect()
@@ -95,7 +97,8 @@ def _connect(self, retries=1):
9597
"""Connect to host, throw UnknownHost exception on DNS errors"""
9698
try:
9799
self.client.connect(self.host, username=self.user,
98-
password=self.password, port=self.port)
100+
password=self.password, port=self.port,
101+
pkey=self.pkey)
99102
except socket.gaierror, e:
100103
logger.error("Could not resolve host '%s'", self.host,)
101104
while retries < NUM_RETRIES:
@@ -201,12 +204,11 @@ def copy_file(self, local_file, remote_file):
201204
local_file, self.host, remote_file)
202205

203206
class ParallelSSHClient(object):
204-
"""
205-
Uses :mod:`pssh.SSHClient`, performs tasks over SSH on multiple hosts in \
207+
"""Uses :mod:`pssh.SSHClient`, performs tasks over SSH on multiple hosts in \
206208
parallel"""
207209

208210
def __init__(self, hosts,
209-
user=None, password=None, port=None,
211+
user=None, password=None, port=None, pkey=None,
210212
pool_size=10):
211213
"""
212214
:param hosts: Hosts to connect to
@@ -265,6 +267,7 @@ def __init__(self, hosts,
265267
self.user = user
266268
self.password = password
267269
self.port = port
270+
self.pkey = pkey
268271
# To hold host clients
269272
self.host_clients = dict((host, None) for host in hosts)
270273

@@ -304,7 +307,7 @@ def _exec_command(self, host, *args, **kwargs):
304307
if not self.host_clients[host]:
305308
self.host_clients[host] = SSHClient(host, user=self.user,
306309
password=self.password,
307-
port=self.port)
310+
port=self.port, pkey=self.pkey)
308311
return self.host_clients[host].exec_command(*args, **kwargs)
309312

310313
def get_stdout(self, greenlet):

tests/test_pssh_client.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,37 @@
55
import unittest
66
from pssh import ParallelSSHClient, UnknownHostException, \
77
AuthenticationException, ConnectionErrorException, _setup_logger
8-
from fake_server.fake_server import start_server, make_socket, logger as server_logger
8+
from fake_server.fake_server import start_server, make_socket, logger as server_logger, \
9+
paramiko_logger
910
import random
1011
import logging
1112
import gevent
1213
import threading
14+
import paramiko
15+
import os
1316

14-
_setup_logger(server_logger)
17+
# _setup_logger(server_logger)
18+
# _setup_logger(paramiko_logger)
19+
20+
USER_KEY = paramiko.RSAKey.from_private_key_file(
21+
os.path.sep.join([os.path.dirname(__file__), 'test_client_private_key']))
1522

1623
class ParallelSSHClientTest(unittest.TestCase):
1724

1825
def setUp(self):
1926
self.fake_cmd = 'fake cmd'
2027
self.fake_resp = 'fake response'
28+
self.user_key = USER_KEY
29+
self.listen_socket = make_socket('127.0.0.1')
30+
self.listen_port = self.listen_socket.getsockname()[1]
2131

32+
def tearDown(self):
33+
del self.listen_socket
34+
2235
def test_pssh_client_exec_command(self):
23-
sock = make_socket('127.0.0.1')
24-
listen_port = sock.getsockname()[1]
25-
server = start_server({ self.fake_cmd : self.fake_resp }, sock)
26-
client = ParallelSSHClient(['127.0.0.1'], port=listen_port)
36+
server = start_server({ self.fake_cmd : self.fake_resp }, self.listen_socket)
37+
client = ParallelSSHClient(['127.0.0.1'], port=self.listen_port,
38+
pkey=self.user_key)
2739
cmd = client.exec_command(self.fake_cmd)[0]
2840
output = client.get_stdout(cmd)
2941
expected = {'127.0.0.1' : {'exit_code' : 0}}
@@ -33,11 +45,10 @@ def test_pssh_client_exec_command(self):
3345
server.join()
3446

3547
def test_pssh_client_auth_failure(self):
36-
sock = make_socket('127.0.0.1')
37-
listen_port = sock.getsockname()[1]
3848
server = start_server({ self.fake_cmd : self.fake_resp },
39-
sock, fail_auth=True)
40-
client = ParallelSSHClient(['127.0.0.1'], port=listen_port)
49+
self.listen_socket, fail_auth=True)
50+
client = ParallelSSHClient(['127.0.0.1'], port=self.listen_port,
51+
pkey=self.user_key)
4152
cmd = client.exec_command(self.fake_cmd)[0]
4253
# Handle exception
4354
try:
@@ -47,3 +58,17 @@ def test_pssh_client_auth_failure(self):
4758
pass
4859
del client
4960
server.join()
61+
62+
def test_pssh_client_exec_command_password(self):
63+
"""Test password authentication. Fake server accepts any password
64+
even empty string"""
65+
server = start_server({ self.fake_cmd : self.fake_resp }, self.listen_socket)
66+
client = ParallelSSHClient(['127.0.0.1'], port=self.listen_port,
67+
password='')
68+
cmd = client.exec_command(self.fake_cmd)[0]
69+
output = client.get_stdout(cmd)
70+
expected = {'127.0.0.1' : {'exit_code' : 0}}
71+
self.assertEqual(expected, output,
72+
msg = "Got unexpected command output - %s" % (output,))
73+
del client
74+
server.join()

0 commit comments

Comments
 (0)