Skip to content

Commit 1ca6a20

Browse files
author
Dan
committed
Fixed long running command tests. Updated embedded server to correctly handle long running client connections. Added unittest for new get exit codes function. Updated tests for readability
1 parent d4a7eb6 commit 1ca6a20

File tree

3 files changed

+78
-58
lines changed

3 files changed

+78
-58
lines changed

embedded_server/embedded_server.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@
2525
paramiko repository
2626
"""
2727

28-
from gevent import monkey
29-
monkey.patch_all()
3028
import gevent
3129
import os
3230
import socket
@@ -37,8 +35,8 @@
3735
import logging
3836
import paramiko
3937
import time
40-
from stub_sftp import StubSFTPServer
41-
from tunnel import Tunneler
38+
from .stub_sftp import StubSFTPServer
39+
from .tunnel import Tunneler
4240
import gevent.subprocess
4341

4442
logger = logging.getLogger("embedded_server")
@@ -112,7 +110,9 @@ def _read_response(self, channel, process):
112110
process.communicate()
113111
channel.send_exit_status(process.returncode)
114112
logger.debug("Command finished with return code %s", process.returncode)
115-
gevent.sleep(0)
113+
# Let clients consume output from channel before closing
114+
gevent.sleep(.2)
115+
channel.close()
116116

117117
def make_socket(listen_ip, port=0):
118118
"""Make socket on given address and available port chosen by OS"""

pssh/pssh_client.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,16 @@ def get_output(self, cmd, output):
316316
'stderr' : stderr,
317317
'cmd' : cmd, })
318318

319+
def get_exit_codes(self, output):
320+
"""Get exit code for all hosts in output if available.
321+
Output parameter is modified in-place.
322+
323+
:param output: As returned by `self.get_output`
324+
:rtype: None
325+
"""
326+
for host in output:
327+
output[host].update({'exit_code': self.get_exit_code(output[host])})
328+
319329
def get_exit_code(self, host_output):
320330
"""Get exit code from host output if available
321331
:param host_output: Per host output as returned by `self.get_output`

tests/test_pssh_client.py

Lines changed: 63 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -44,59 +44,50 @@ class ParallelSSHClientTest(unittest.TestCase):
4444
def setUp(self):
4545
self.fake_cmd = 'echo me'
4646
self.fake_resp = 'me'
47-
self.long_cmd = lambda lines: 'for (( i=0; i<%s; i+=1 )) do echo $i; done' % (lines,)
47+
self.long_cmd = lambda lines: 'for (( i=0; i<%s; i+=1 )) do echo $i; sleep 1; done' % (lines,)
4848
self.user_key = USER_KEY
49-
self.listen_socket = make_socket('127.0.0.1')
49+
self.host = '127.0.0.1'
50+
self.listen_socket = make_socket(self.host)
5051
self.listen_port = self.listen_socket.getsockname()[1]
5152
self.server = start_server(self.listen_socket)
5253

53-
def long_running_response(self, responses):
54-
i = 0
55-
while True:
56-
if i >= responses:
57-
raise StopIteration
58-
gevent.sleep(0)
59-
yield 'long running response'
60-
gevent.sleep(1)
61-
i += 1
62-
6354
def tearDown(self):
6455
del self.server
6556
del self.listen_socket
6657

6758
def test_pssh_client_exec_command(self):
68-
client = ParallelSSHClient(['127.0.0.1'], port=self.listen_port,
59+
client = ParallelSSHClient([self.host], port=self.listen_port,
6960
pkey=self.user_key)
7061
cmd = client.exec_command(self.fake_cmd)[0]
7162
output = client.get_stdout(cmd)
72-
expected = {'127.0.0.1' : {'exit_code' : 0}}
63+
expected = {self.host : {'exit_code' : 0}}
7364
self.assertEqual(expected, output,
7465
msg="Got unexpected command output - %s" % (output,))
75-
self.assertTrue(output['127.0.0.1']['exit_code'] == 0)
66+
self.assertTrue(output[self.host]['exit_code'] == 0)
7667

7768
def test_pssh_client_no_stdout_non_zero_exit_code(self):
78-
client = ParallelSSHClient(['127.0.0.1'], port=self.listen_port,
69+
client = ParallelSSHClient([self.host], port=self.listen_port,
7970
pkey=self.user_key)
8071
output = client.run_command('exit 1')
8172
expected_exit_code = 1
82-
exit_code = output['127.0.0.1']['exit_code']
73+
exit_code = output[self.host]['exit_code']
8374
client.pool.join()
8475
self.assertEqual(expected_exit_code, exit_code,
8576
msg="Got unexpected exit code - %s, expected %s" %
8677
(exit_code,
8778
expected_exit_code,))
8879

8980
def test_pssh_client_exec_command_get_buffers(self):
90-
client = ParallelSSHClient(['127.0.0.1'], port=self.listen_port,
81+
client = ParallelSSHClient([self.host], port=self.listen_port,
9182
pkey=self.user_key)
9283
cmd = client.exec_command(self.fake_cmd)[0]
9384
output = client.get_stdout(cmd, return_buffers=True)
9485
expected_exit_code = 0
9586
expected_stdout = [self.fake_resp]
9687
expected_stderr = []
97-
exit_code = output['127.0.0.1']['exit_code']
98-
stdout = list(output['127.0.0.1']['stdout'])
99-
stderr = list(output['127.0.0.1']['stderr'])
88+
exit_code = output[self.host]['exit_code']
89+
stdout = list(output[self.host]['stdout'])
90+
stderr = list(output[self.host]['stderr'])
10091
self.assertEqual(expected_exit_code, exit_code,
10192
msg="Got unexpected exit code - %s, expected %s" %
10293
(exit_code,
@@ -111,15 +102,15 @@ def test_pssh_client_exec_command_get_buffers(self):
111102
expected_stderr,))
112103

113104
def test_pssh_client_run_command_get_output(self):
114-
client = ParallelSSHClient(['127.0.0.1'], port=self.listen_port,
105+
client = ParallelSSHClient([self.host], port=self.listen_port,
115106
pkey=self.user_key)
116107
output = client.run_command(self.fake_cmd)
117108
expected_exit_code = 0
118109
expected_stdout = [self.fake_resp]
119110
expected_stderr = []
120-
exit_code = output['127.0.0.1']['exit_code']
121-
stdout = list(output['127.0.0.1']['stdout'])
122-
stderr = list(output['127.0.0.1']['stderr'])
111+
exit_code = output[self.host]['exit_code']
112+
stdout = list(output[self.host]['stdout'])
113+
stderr = list(output[self.host]['stderr'])
123114
self.assertEqual(expected_exit_code, exit_code,
124115
msg="Got unexpected exit code - %s, expected %s" %
125116
(exit_code,
@@ -134,7 +125,7 @@ def test_pssh_client_run_command_get_output(self):
134125
expected_stderr,))
135126

136127
def test_pssh_client_run_command_get_output_explicit(self):
137-
client = ParallelSSHClient(['127.0.0.1'], port=self.listen_port,
128+
client = ParallelSSHClient([self.host], port=self.listen_port,
138129
pkey=self.user_key)
139130
out = client.run_command(self.fake_cmd)
140131
cmds = [cmd for host in out for cmd in [out[host]['cmd']]]
@@ -144,9 +135,9 @@ def test_pssh_client_run_command_get_output_explicit(self):
144135
expected_exit_code = 0
145136
expected_stdout = [self.fake_resp]
146137
expected_stderr = []
147-
exit_code = output['127.0.0.1']['exit_code']
148-
stdout = list(output['127.0.0.1']['stdout'])
149-
stderr = list(output['127.0.0.1']['stderr'])
138+
exit_code = output[self.host]['exit_code']
139+
stdout = list(output[self.host]['stdout'])
140+
stderr = list(output[self.host]['stderr'])
150141
self.assertEqual(expected_exit_code, exit_code,
151142
msg="Got unexpected exit code - %s, expected %s" %
152143
(exit_code,
@@ -163,21 +154,21 @@ def test_pssh_client_run_command_get_output_explicit(self):
163154

164155
def test_pssh_client_run_long_command(self):
165156
expected_lines = 5
166-
client = ParallelSSHClient(['127.0.0.1'], port=self.listen_port,
157+
client = ParallelSSHClient([self.host], port=self.listen_port,
167158
pkey=self.user_key)
168159
output = client.run_command(self.long_cmd(expected_lines))
169-
self.assertTrue('127.0.0.1' in output, msg="Got no output for command")
170-
stdout = list(output['127.0.0.1']['stdout'])
160+
self.assertTrue(self.host in output, msg="Got no output for command")
161+
stdout = list(output[self.host]['stdout'])
171162
self.assertTrue(len(stdout) == expected_lines,
172163
msg="Expected %s lines of response, got %s" % (
173164
expected_lines, len(stdout)))
174165
del client
175166

176167
def test_pssh_client_auth_failure(self):
177-
listen_socket = make_socket('127.0.0.1')
168+
listen_socket = make_socket(self.host)
178169
listen_port = listen_socket.getsockname()[1]
179170
server = start_server(listen_socket, fail_auth=True)
180-
client = ParallelSSHClient(['127.0.0.1'], port=listen_port,
171+
client = ParallelSSHClient([self.host], port=listen_port,
181172
pkey=self.user_key)
182173
cmd = client.exec_command(self.fake_cmd)[0]
183174
# Handle exception
@@ -195,7 +186,7 @@ def test_pssh_client_hosts_list_part_failure(self):
195186
server2_socket = make_socket('127.0.0.2', port=self.listen_port)
196187
server2_port = server2_socket.getsockname()[1]
197188
server2 = start_server(server2_socket, fail_auth=True)
198-
hosts = ['127.0.0.1', '127.0.0.2']
189+
hosts = [self.host, '127.0.0.2']
199190
client = ParallelSSHClient(hosts,
200191
port=self.listen_port,
201192
pkey=self.user_key,
@@ -219,11 +210,11 @@ def test_pssh_client_hosts_list_part_failure(self):
219210
server2.kill()
220211

221212
def test_pssh_client_ssh_exception(self):
222-
listen_socket = make_socket('127.0.0.1')
213+
listen_socket = make_socket(self.host)
223214
listen_port = listen_socket.getsockname()[1]
224215
server = start_server(listen_socket,
225216
ssh_exception=True)
226-
client = ParallelSSHClient(['127.0.0.1'],
217+
client = ParallelSSHClient([self.host],
227218
user='fakey', password='fakey',
228219
port=listen_port,
229220
pkey=paramiko.RSAKey.generate(1024),
@@ -238,13 +229,13 @@ def test_pssh_client_ssh_exception(self):
238229
server.join()
239230

240231
def test_pssh_client_timeout(self):
241-
listen_socket = make_socket('127.0.0.1')
232+
listen_socket = make_socket(self.host)
242233
listen_port = listen_socket.getsockname()[1]
243234
server_timeout=0.2
244235
client_timeout=server_timeout-0.1
245236
server = start_server(listen_socket,
246237
timeout=server_timeout)
247-
client = ParallelSSHClient(['127.0.0.1'], port=listen_port,
238+
client = ParallelSSHClient([self.host], port=listen_port,
248239
pkey=self.user_key,
249240
timeout=client_timeout)
250241
output = client.run_command(self.fake_cmd)
@@ -258,7 +249,7 @@ def test_pssh_client_timeout(self):
258249
raise server.exception
259250
except gevent.Timeout:
260251
pass
261-
# chan_timeout = output['127.0.0.1']['channel'].gettimeout()
252+
# chan_timeout = output[self.host]['channel'].gettimeout()
262253
# self.assertEqual(client_timeout, chan_timeout,
263254
# msg="Channel timeout %s does not match requested timeout %s" %(
264255
# chan_timeout, client_timeout,))
@@ -268,34 +259,53 @@ def test_pssh_client_timeout(self):
268259
def test_pssh_client_exec_command_password(self):
269260
"""Test password authentication. Fake server accepts any password
270261
even empty string"""
271-
client = ParallelSSHClient(['127.0.0.1'], port=self.listen_port,
262+
client = ParallelSSHClient([self.host], port=self.listen_port,
272263
password='')
273264
cmd = client.exec_command(self.fake_cmd)[0]
274265
output = client.get_stdout(cmd)
275-
expected = {'127.0.0.1' : {'exit_code' : 0}}
266+
expected = {self.host : {'exit_code' : 0}}
276267
self.assertEqual(expected, output,
277268
msg="Got unexpected command output - %s" % (output,))
278269
del client
279270

280271
def test_pssh_client_long_running_command(self):
281272
expected_lines = 5
282-
client = ParallelSSHClient(['127.0.0.1'], port=self.listen_port,
273+
client = ParallelSSHClient([self.host], port=self.listen_port,
283274
pkey=self.user_key)
284275
cmd = client.exec_command(self.long_cmd(expected_lines))[0]
285276
output = client.get_stdout(cmd, return_buffers=True)
286-
self.assertTrue('127.0.0.1' in output, msg="Got no output for command")
287-
stdout = list(output['127.0.0.1']['stdout'])
277+
self.assertTrue(self.host in output, msg="Got no output for command")
278+
stdout = list(output[self.host]['stdout'])
288279
self.assertTrue(len(stdout) == expected_lines,
289280
msg="Expected %s lines of response, got %s" % (
290281
expected_lines, len(stdout)))
291282
del client
283+
284+
def test_pssh_client_long_running_command_exit_codes(self):
285+
expected_lines = 5
286+
client = ParallelSSHClient([self.host], port=self.listen_port,
287+
pkey=self.user_key)
288+
output = client.run_command(self.long_cmd(expected_lines))
289+
self.assertTrue(self.host in output, msg="Got no output for command")
290+
self.assertTrue(not output[self.host]['exit_code'],
291+
msg="Got exit code %s for still running cmd.." % (
292+
output[self.host]['exit_code'],))
293+
# Embedded server is also asynchronous and in the same thread
294+
# as our client so need to sleep for duration of server connection
295+
gevent.sleep(expected_lines)
296+
client.pool.join()
297+
client.get_exit_codes(output)
298+
self.assertTrue(output[self.host]['exit_code'] == 0,
299+
msg="Got non-zero exit code %s" % (
300+
output[self.host]['exit_code'],))
301+
del client
292302

293303
def test_pssh_client_retries(self):
294304
"""Test connection error retries"""
295-
listen_socket = make_socket('127.0.0.1')
305+
listen_socket = make_socket(self.host)
296306
listen_port = listen_socket.getsockname()[1]
297307
expected_num_tries = 2
298-
client = ParallelSSHClient(['127.0.0.1'], port=listen_port,
308+
client = ParallelSSHClient([self.host], port=listen_port,
299309
pkey=self.user_key,
300310
num_retries=expected_num_tries)
301311
self.assertRaises(ConnectionErrorException, client.run_command, 'blah')
@@ -320,7 +330,7 @@ def test_pssh_copy_file(self):
320330
test_file.close()
321331
server = start_server({ self.fake_cmd : self.fake_resp },
322332
self.listen_socket)
323-
client = ParallelSSHClient(['127.0.0.1'], port=self.listen_port,
333+
client = ParallelSSHClient([self.host], port=self.listen_port,
324334
pkey=self.user_key)
325335
cmds = client.copy_file(local_filename, remote_filename)
326336
cmds[0].get()
@@ -361,7 +371,7 @@ def test_pssh_hosts_more_than_pool_size(self):
361371
server2_socket = make_socket('127.0.0.2', port=self.listen_port)
362372
server2_port = server2_socket.getsockname()[1]
363373
server2 = start_server(server2_socket)
364-
hosts = ['127.0.0.1', '127.0.0.2']
374+
hosts = [self.host, '127.0.0.2']
365375
client = ParallelSSHClient(hosts,
366376
port=self.listen_port,
367377
pkey=self.user_key,
@@ -390,15 +400,15 @@ def test_ssh_proxy(self):
390400
proxy_server_port = proxy_server_socket.getsockname()[1]
391401
# server = start_server(self.listen_socket)
392402
proxy_server = start_server(proxy_server_socket)
393-
client = ParallelSSHClient(['127.0.0.1'], port=self.listen_port,
403+
client = ParallelSSHClient([self.host], port=self.listen_port,
394404
pkey=self.user_key,
395405
proxy_host='127.0.0.2',
396406
proxy_port=proxy_server_port
397407
)
398408
# gevent.sleep(1)
399409
# import ipdb; ipdb.set_trace()
400410
output = client.run_command(self.fake_cmd)
401-
stdout = list(output['127.0.0.1']['stdout'])
411+
stdout = list(output[self.host]['stdout'])
402412
expected_stdout = [self.fake_resp]
403413
self.assertEqual(expected_stdout, stdout,
404414
msg="Got unexpected stdout - %s, expected %s" %
@@ -409,10 +419,10 @@ def test_ssh_proxy(self):
409419

410420
def test_bash_variable_substitution(self):
411421
"""Test bash variables work correctly"""
412-
client = ParallelSSHClient(['127.0.0.1'], port=self.listen_port,
422+
client = ParallelSSHClient([self.host], port=self.listen_port,
413423
pkey=self.user_key)
414424
command = """for i in 1 2 3; do echo $i; done"""
415-
output = list(client.run_command(command)['127.0.0.1']['stdout'])
425+
output = list(client.run_command(command)[self.host]['stdout'])
416426
expected = ['1','2','3']
417427
self.assertEqual(output, expected,
418428
msg="Unexpected output from bash variable substitution %s - should be %s" % (

0 commit comments

Comments
 (0)