Skip to content

Commit cdf23ac

Browse files
author
Dan
committed
Adjusted pool size to be minimum of requested pool size or number of hosts. Added test for pool size
1 parent c0e31d5 commit cdf23ac

File tree

2 files changed

+48
-19
lines changed

2 files changed

+48
-19
lines changed

pssh.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -305,29 +305,33 @@ def __init__(self, hosts,
305305
"""
306306
:param hosts: Hosts to connect to
307307
:type hosts: list(str)
308-
:param user: (Optional) User to login as. Defaults to logged in user or\
308+
:param user: (Optional) User to login as. Defaults to logged in user or \
309309
user from ~/.ssh/config or /etc/ssh/ssh_config if set
310310
:type user: str
311-
:param password: (Optional) Password to use for login. Defaults to\
311+
:param password: (Optional) Password to use for login. Defaults to \
312312
no password
313313
:type password: str
314-
:param port: (Optional) Port number to use for SSH connection. Defaults\
314+
:param port: (Optional) Port number to use for SSH connection. Defaults \
315315
to None which uses SSH default
316316
:type port: int
317317
:param pkey: (Optional) Client's private key to be used to connect with
318318
:type pkey: :mod:`paramiko.PKey`
319-
:param num_retries: (Optional) Number of retries for connection attempts\
319+
:param num_retries: (Optional) Number of retries for connection attempts \
320320
before the client gives up. Defaults to 3.
321321
:type num_retries: int
322-
:param timeout: (Optional) Number of seconds to timout connection attempts\
323-
before the client gives up. Defaults to 10.
322+
:param timeout: (Optional) Number of seconds to timout connection \
323+
attempts before the client gives up. Defaults to 10.
324324
:type timeout: int
325325
:param forward_ssh_agent: (Optional) Turn on SSH agent forwarding - \
326326
equivalent to `ssh -A` from the `ssh` command line utility. \
327327
Defaults to True if not set.
328328
:type forward_ssh_agent: bool
329329
:param pool_size: (Optional) Greenlet pool size. Controls on how many\
330-
hosts to execute tasks in parallel. Defaults to 10
330+
hosts to execute tasks in parallel. Defaults to number of hosts or 10, \
331+
whichever is lower. Pool size will be *equal to* number of hosts if number\
332+
of hosts is lower than the pool size specified as that would only \
333+
increase overhead with no benefits.
334+
331335
:type pool_size: int
332336
333337
**Example**
@@ -387,8 +391,8 @@ def __init__(self, hosts,
387391
388392
Connection is terminated.
389393
"""
390-
self.pool = gevent.pool.Pool(size=pool_size)
391-
self.pool_size = pool_size
394+
self.pool_size = len(hosts) if len(hosts) < pool_size else pool_size
395+
self.pool = gevent.pool.Pool(size=self.pool_size)
392396
self.hosts = hosts
393397
self.user = user
394398
self.password = password

tests/test_pssh_client.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121

2222
import unittest
2323
from pssh import ParallelSSHClient, UnknownHostException, \
24-
AuthenticationException, ConnectionErrorException
25-
from fake_server.fake_server import start_server, make_socket, logger as server_logger, \
26-
paramiko_logger
24+
AuthenticationException, ConnectionErrorException
25+
from fake_server.fake_server import start_server, make_socket, \
26+
logger as server_logger, paramiko_logger
2727
import random
2828
import logging
2929
import gevent
@@ -165,8 +165,9 @@ def test_pssh_client_run_long_command(self):
165165
output = client.run_command(self.long_running_cmd)
166166
self.assertTrue('127.0.0.1' in output, msg="Got no output for command")
167167
stdout = list(output['127.0.0.1']['stdout'])
168-
self.assertTrue(len(stdout) == expected_lines, msg="Expected %s lines of response, got %s" %
169-
(expected_lines, len(stdout)))
168+
self.assertTrue(len(stdout) == expected_lines,
169+
msg="Expected %s lines of response, got %s" % (
170+
expected_lines, len(stdout)))
170171
del client
171172
server.kill()
172173

@@ -198,7 +199,8 @@ def test_pssh_client_timeout(self):
198199
gevent.sleep(0.5)
199200
cmd.get()
200201
if not server.exception:
201-
raise Exception("Expected gevent.Timeout from socket timeout, got none")
202+
raise Exception(
203+
"Expected gevent.Timeout from socket timeout, got none")
202204
raise server.exception
203205
except gevent.Timeout:
204206
pass
@@ -217,7 +219,7 @@ def test_pssh_client_exec_command_password(self):
217219
output = client.get_stdout(cmd)
218220
expected = {'127.0.0.1' : {'exit_code' : 0}}
219221
self.assertEqual(expected, output,
220-
msg = "Got unexpected command output - %s" % (output,))
222+
msg="Got unexpected command output - %s" % (output,))
221223
del client
222224
server.join()
223225

@@ -232,8 +234,9 @@ def test_pssh_client_long_running_command(self):
232234
output = client.get_stdout(cmd)
233235
self.assertTrue('127.0.0.1' in output, msg="Got no output for command")
234236
stdout = list(output['127.0.0.1']['stdout'])
235-
self.assertTrue(len(stdout) == expected_lines, msg="Expected %s lines of response, got %s" %
236-
(expected_lines, len(stdout)))
237+
self.assertTrue(len(stdout) == expected_lines,
238+
msg="Expected %s lines of response, got %s" % (
239+
expected_lines, len(stdout)))
237240
del client
238241
server.kill()
239242

@@ -242,7 +245,8 @@ def test_pssh_client_retries(self):
242245
expected_num_tries = 2
243246
with self.assertRaises(ConnectionErrorException) as cm:
244247
client = ParallelSSHClient(['127.0.0.1'], port=self.listen_port,
245-
pkey=self.user_key, num_retries=expected_num_tries)
248+
pkey=self.user_key,
249+
num_retries=expected_num_tries)
246250
cmd = client.exec_command('blah')[0]
247251
cmd.get()
248252
num_tries = cm.exception.args[-1:][0]
@@ -273,3 +277,24 @@ def test_pssh_copy_file(self):
273277
os.unlink(filepath)
274278
del client
275279
server.join()
280+
281+
def test_pssh_pool_size(self):
282+
"""Test pool size logic"""
283+
hosts = ['host-%01d' % d for d in xrange(5)]
284+
client = ParallelSSHClient(hosts)
285+
expected, actual = len(hosts), client.pool.size
286+
self.assertEqual(expected, actual,
287+
msg="Expected pool size to be %s, got %s" % (
288+
expected, actual,))
289+
hosts = ['host-%01d' % d for d in xrange(15)]
290+
client = ParallelSSHClient(hosts)
291+
expected, actual = client.pool_size, client.pool.size
292+
self.assertEqual(expected, actual,
293+
msg="Expected pool size to be %s, got %s" % (
294+
expected, actual,))
295+
hosts = ['host-%01d' % d for d in xrange(15)]
296+
client = ParallelSSHClient(hosts, pool_size=len(hosts)+5)
297+
expected, actual = len(hosts), client.pool.size
298+
self.assertEqual(expected, actual,
299+
msg="Expected pool size to be %s, got %s" % (
300+
expected, actual,))

0 commit comments

Comments
 (0)