Skip to content

Commit a9c02e2

Browse files
author
Dan
committed
Suppress gevent printing exceptions to stderr - we are handling them in ParallelSSH code. Resolves #30
1 parent bc0e018 commit a9c02e2

File tree

3 files changed

+41
-14
lines changed

3 files changed

+41
-14
lines changed

fake_server/fake_server.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,21 +47,29 @@
4747
host_key = paramiko.RSAKey(filename = os.path.sep.join([os.path.dirname(__file__), 'rsa.key']))
4848

4949
class Server (paramiko.ServerInterface):
50-
def __init__(self, transport, cmd_req_response = {}, fail_auth=False):
50+
def __init__(self, transport, cmd_req_response = {}, fail_auth=False,
51+
ssh_exception=False):
5152
self.event = Event()
5253
self.cmd_req_response = cmd_req_response
5354
self.fail_auth = fail_auth
55+
self.ssh_exception = ssh_exception
5456
self.transport = transport
5557

5658
def check_channel_request(self, kind, chanid):
5759
return paramiko.OPEN_SUCCEEDED
5860

5961
def check_auth_password(self, username, password):
60-
if self.fail_auth: return paramiko.AUTH_FAILED
62+
if self.fail_auth:
63+
return paramiko.AUTH_FAILED
64+
if self.ssh_exception:
65+
raise paramiko.SSHException()
6166
return paramiko.AUTH_SUCCESSFUL
6267

6368
def check_auth_publickey(self, username, key):
64-
if self.fail_auth: return paramiko.AUTH_FAILED
69+
if self.fail_auth:
70+
return paramiko.AUTH_FAILED
71+
if self.ssh_exception:
72+
raise paramiko.SSHException()
6573
return paramiko.AUTH_SUCCESSFUL
6674

6775
def get_allowed_auths(self, username):
@@ -131,7 +139,7 @@ def make_socket(listen_ip, port=0):
131139
return
132140
return sock
133141

134-
def listen(cmd_req_response, sock, fail_auth=False,
142+
def listen(cmd_req_response, sock, fail_auth=False, ssh_exception=False,
135143
timeout=None):
136144
"""Run a fake ssh server and given a cmd_to_run, send given \
137145
response to client connection. Returns (server, socket) tuple \
@@ -149,17 +157,18 @@ def listen(cmd_req_response, sock, fail_auth=False,
149157
traceback.print_exc()
150158
return
151159
handle_ssh_connection(cmd_req_response, sock, fail_auth=fail_auth,
152-
timeout=timeout)
160+
timeout=timeout, ssh_exception=ssh_exception)
153161

154-
def _handle_ssh_connection(cmd_req_response, transport, fail_auth=False):
162+
def _handle_ssh_connection(cmd_req_response, transport, fail_auth=False,
163+
ssh_exception=False):
155164
try:
156165
transport.load_server_moduli()
157166
except:
158167
return
159168
transport.add_server_key(host_key)
160169
transport.set_subsystem_handler('sftp', paramiko.SFTPServer, StubSFTPServer)
161170
server = Server(transport, cmd_req_response=cmd_req_response,
162-
fail_auth=fail_auth)
171+
fail_auth=fail_auth, ssh_exception=ssh_exception)
163172
try:
164173
transport.start_server(server=server)
165174
except paramiko.SSHException, e:
@@ -180,7 +189,7 @@ def _handle_ssh_connection(cmd_req_response, transport, fail_auth=False):
180189
channel.close()
181190

182191
def handle_ssh_connection(cmd_req_response, sock,
183-
fail_auth=False,
192+
fail_auth=False, ssh_exception=False,
184193
timeout=None):
185194
conn, addr = sock.accept()
186195
logger.info('Got connection..')
@@ -190,7 +199,8 @@ def handle_ssh_connection(cmd_req_response, sock,
190199
gevent.Timeout(timeout).start()
191200
try:
192201
transport = paramiko.Transport(conn)
193-
_handle_ssh_connection(cmd_req_response, transport, fail_auth=fail_auth)
202+
_handle_ssh_connection(cmd_req_response, transport, fail_auth=fail_auth,
203+
ssh_exception=ssh_exception)
194204
except Exception, e:
195205
logger.error('*** Caught exception: %s: %s' % (str(e.__class__), str(e),))
196206
traceback.print_exc()
@@ -200,10 +210,10 @@ def handle_ssh_connection(cmd_req_response, sock,
200210
pass
201211
return
202212

203-
def start_server(cmd_req_response, sock, fail_auth=False,
213+
def start_server(cmd_req_response, sock, fail_auth=False, ssh_exception=False,
204214
timeout=None):
205215
return gevent.spawn(listen, cmd_req_response, sock, fail_auth=fail_auth,
206-
timeout=timeout)
216+
timeout=timeout, ssh_exception=ssh_exception)
207217

208218
if __name__ == "__main__":
209219
logging.basicConfig()

pssh.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
from gevent import monkey
3232
monkey.patch_all()
3333
import gevent.pool
34+
import gevent.hub
35+
gevent.hub.Hub.NOT_ERROR=(Exception,)
3436
import warnings
3537
from socket import gaierror as sock_gaierror, error as sock_error
3638
import logging
@@ -202,12 +204,12 @@ def _connect(self, client, host, port, sock=None, retries=1):
202204
str(error_type), self.host, self.port,
203205
retries, self.num_retries,)
204206
except paramiko.AuthenticationException, ex:
205-
raise AuthenticationException(ex)
207+
raise AuthenticationException(ex.message)
206208
# SSHException is more general so should be below other types
207209
# of SSH failure
208210
except paramiko.SSHException, ex:
209211
logger.error("General SSH error - %s", ex)
210-
raise SSHException(ex)
212+
raise SSHException(ex.message)
211213

212214
def exec_command(self, command, sudo=False, user=None, **kwargs):
213215
"""Wrapper to :mod:`paramiko.SSHClient.exec_command`

tests/test_pssh_client.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import unittest
2323
from pssh import ParallelSSHClient, UnknownHostException, \
24-
AuthenticationException, ConnectionErrorException, logger as pssh_logger
24+
AuthenticationException, ConnectionErrorException, SSHException, logger as pssh_logger
2525
from fake_server.fake_server import start_server, make_socket, \
2626
logger as server_logger, paramiko_logger
2727
import random
@@ -190,6 +190,21 @@ def test_pssh_client_auth_failure(self):
190190
del client
191191
server.join()
192192

193+
def test_pssh_client_ssh_exception(self):
194+
server = start_server({ self.fake_cmd : self.fake_resp },
195+
self.listen_socket,
196+
ssh_exception=True)
197+
client = ParallelSSHClient(['127.0.0.1'], port=self.listen_port,
198+
pkey=self.user_key)
199+
# Handle exception
200+
try:
201+
client.run_command(self.fake_cmd)
202+
raise Exception("Expected SSHException, got none")
203+
except SSHException, ex:
204+
pass
205+
del client
206+
server.join()
207+
193208
def test_pssh_client_timeout(self):
194209
server_timeout=0.2
195210
client_timeout=server_timeout-0.1

0 commit comments

Comments
 (0)