Skip to content

Commit 7c13c0f

Browse files
author
Dan
committed
Refactored embedded to be cleaner. Updated tests to use separate processes for embedded server
1 parent 37d8bd1 commit 7c13c0f

File tree

2 files changed

+247
-253
lines changed

2 files changed

+247
-253
lines changed

embedded_server/embedded_server.py

Lines changed: 100 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,11 @@
4141
import sys
4242
if 'threading' in sys.modules:
4343
del sys.modules['threading']
44-
import gevent
44+
from gevent import monkey
45+
monkey.patch_all()
46+
from multiprocessing import Process
4547
import os
46-
import socket
48+
import gevent
4749
from gevent import socket
4850
from gevent.event import Event
4951
import sys
@@ -65,10 +67,10 @@ class Server(paramiko.ServerInterface):
6567
"""Implements :mod:`paramiko.ServerInterface` to provide an
6668
embedded SSH server implementation.
6769
68-
Start a `Server` with at least a transport and a host key.
70+
Start a `Server` with at least a host private key.
6971
7072
Any SSH2 client with public key or password authentication
71-
is allowed, only. Shell requests are not accepted.
73+
is allowed, only. Interactive shell requests are not accepted.
7274
7375
Implemented:
7476
* Direct tcp-ip channels (tunneling)
@@ -77,19 +79,88 @@ class Server(paramiko.ServerInterface):
7779
* Exec requests (run a command on server)
7880
7981
Not Implemented:
80-
* Shell requests
82+
* Interactive shell requests
8183
"""
8284

83-
def __init__(self, transport, host_key, fail_auth=False,
84-
ssh_exception=False):
85+
def __init__(self, host_key, fail_auth=False,
86+
ssh_exception=False,
87+
socket=None,
88+
port=0,
89+
listen_ip='127.0.0.1',
90+
timeout=None):
91+
if not socket:
92+
self.socket = make_socket(listen_ip, port)
93+
if not self.socket:
94+
msg = "Could not establish listening connection on %s:%s"
95+
logger.error(msg, listen_ip, port)
96+
raise Exception(msg, listen_ip, port)
97+
self.listen_ip = listen_ip
98+
self.listen_port = self.socket.getsockname()[1]
8599
self.event = Event()
86100
self.fail_auth = fail_auth
87101
self.ssh_exception = ssh_exception
88-
self.transport = transport
89102
self.host_key = host_key
90-
transport.load_server_moduli()
91-
transport.add_server_key(self.host_key)
92-
transport.set_subsystem_handler('sftp', paramiko.SFTPServer, StubSFTPServer)
103+
self.transport = None
104+
self.timeout = timeout
105+
106+
def start_listening(self):
107+
try:
108+
self.socket.listen(100)
109+
logger.info('Listening for connection on %s:%s..', self.listen_ip,
110+
self.listen_port)
111+
except Exception as e:
112+
logger.error('*** Listen failed: %s' % (str(e),))
113+
traceback.print_exc()
114+
raise
115+
conn, addr = self.socket.accept()
116+
logger.info('Got connection..')
117+
if self.timeout:
118+
logger.debug("SSH server sleeping for %s then raising socket.timeout",
119+
self.timeout)
120+
gevent.Timeout(self.timeout).start()
121+
self.transport = paramiko.Transport(conn)
122+
self.transport.load_server_moduli()
123+
self.transport.add_server_key(self.host_key)
124+
self.transport.set_subsystem_handler('sftp', paramiko.SFTPServer,
125+
StubSFTPServer)
126+
try:
127+
self.transport.start_server(server=self)
128+
except paramiko.SSHException as e:
129+
logger.exception('SSH negotiation failed')
130+
raise
131+
132+
def run(self):
133+
while True:
134+
try:
135+
self.start_listening()
136+
except Exception:
137+
logger.exception("Error occured starting server")
138+
continue
139+
try:
140+
self.accept_connections()
141+
except Exception as e:
142+
logger.error('*** Caught exception: %s: %s' % (str(e.__class__), str(e),))
143+
traceback.print_exc()
144+
try:
145+
self.transport.close()
146+
except Exception:
147+
pass
148+
raise
149+
150+
def accept_connections(self):
151+
while True:
152+
gevent.sleep(0)
153+
channel = self.transport.accept(20)
154+
if not channel:
155+
logger.error("Could not establish channel")
156+
return
157+
while self.transport.is_active():
158+
logger.debug("Transport active, waiting..")
159+
gevent.sleep(1)
160+
while not channel.send_ready():
161+
gevent.sleep(.2)
162+
channel.close()
163+
gevent.sleep(0)
93164

94165
def check_channel_request(self, kind, chanid):
95166
return paramiko.OPEN_SUCCEEDED
@@ -157,7 +228,7 @@ def _read_response(self, channel, process):
157228
channel.send_exit_status(process.returncode)
158229
logger.debug("Command finished with return code %s", process.returncode)
159230
# Let clients consume output from channel before closing
160-
gevent.sleep(.2)
231+
gevent.sleep(.1)
161232
channel.close()
162233

163234
def make_socket(listen_ip, port=0):
@@ -172,92 +243,24 @@ def make_socket(listen_ip, port=0):
172243
return
173244
return sock
174245

175-
def listen(sock, fail_auth=False, ssh_exception=False,
176-
timeout=None):
177-
"""Run server and given a cmd_to_run, send given
178-
response to client connection. Returns (server, socket) tuple
179-
where server is a joinable server thread and socket is listening
180-
socket of server.
181-
"""
182-
listen_ip, listen_port = sock.getsockname()
183-
if not sock:
184-
logger.error("Could not establish listening connection on %s:%s",
185-
listen_ip, listen_port)
186-
return
246+
def start_server(listen_ip, fail_auth=False, ssh_exception=False,
247+
timeout=None,
248+
listen_port=0):
249+
server = Server(host_key, listen_ip=listen_ip, port=listen_port,
250+
fail_auth=fail_auth, ssh_exception=ssh_exception,
251+
timeout=timeout)
187252
try:
188-
sock.listen(100)
189-
logger.info('Listening for connection on %s:%s..', listen_ip,
190-
listen_port)
191-
except Exception as e:
192-
logger.error('*** Listen failed: %s' % (str(e),))
193-
traceback.print_exc()
194-
return
195-
handle_ssh_connection(sock, fail_auth=fail_auth,
196-
timeout=timeout, ssh_exception=ssh_exception)
197-
198-
def _handle_ssh_connection(transport, fail_auth=False,
199-
ssh_exception=False):
200-
server = Server(transport, HOST_KEY,
201-
fail_auth=fail_auth, ssh_exception=ssh_exception)
202-
# server.run()
203-
try:
204-
transport.start_server(server=server)
205-
except paramiko.SSHException as e:
206-
logger.exception('SSH negotiation failed')
207-
except Exception:
208-
logger.exception("Error occured starting server")
209-
return
210-
while True:
211-
gevent.sleep(0)
212-
channel = transport.accept(20)
213-
if not channel:
214-
logger.error("Could not establish channel")
215-
return
216-
while transport.is_active():
217-
logger.debug("Transport active, waiting..")
218-
gevent.sleep(1)
219-
while not channel.send_ready():
220-
gevent.sleep(.2)
221-
channel.close()
222-
gevent.sleep(0)
223-
224-
def handle_ssh_connection(sock,
225-
fail_auth=False, ssh_exception=False,
226-
timeout=None):
227-
conn, addr = sock.accept()
228-
logger.info('Got connection..')
229-
if timeout:
230-
logger.debug("SSH server sleeping for %s then raising socket.timeout",
231-
timeout)
232-
gevent.Timeout(timeout).start()
233-
try:
234-
transport = paramiko.Transport(conn)
235-
_handle_ssh_connection(transport, fail_auth=fail_auth,
236-
ssh_exception=ssh_exception)
237-
except Exception as e:
238-
logger.error('*** Caught exception: %s: %s' % (str(e.__class__), str(e),))
239-
traceback.print_exc()
240-
try:
241-
transport.close()
242-
except:
243-
pass
244-
return
245-
246-
def start_server(sock, fail_auth=False, ssh_exception=False,
247-
timeout=None):
248-
g = gevent.spawn(listen, sock, fail_auth=fail_auth,
249-
timeout=timeout, ssh_exception=ssh_exception)
250-
try:
251-
g.join()
253+
server.run()
252254
except KeyboardInterrupt:
253255
sys.exit(0)
254256

255-
if __name__ == "__main__":
256-
logging.basicConfig()
257-
logger.setLevel(logging.DEBUG)
258-
sock = make_socket('127.0.0.1')
259-
server = start_server(sock)
260-
try:
261-
server.join()
262-
except KeyboardInterrupt:
263-
sys.exit(0)
257+
def start_server_process(listen_ip, fail_auth=False, ssh_exception=False,
258+
timeout=None, listen_port=0):
259+
server = Process(target=start_server, args=(listen_ip,),
260+
kwargs={
261+
'listen_port': listen_port,
262+
'fail_auth': fail_auth,
263+
'ssh_exception': ssh_exception,
264+
'timeout': timeout,
265+
})
266+
return server

0 commit comments

Comments
 (0)