Skip to content

Commit 2690b1a

Browse files
Shadospkittenis
authored andcommitted
Add per-host copy_args to parallel copy functions (#111)
Add per-host copy_args to parallel copy functions
1 parent b932902 commit 2690b1a

File tree

4 files changed

+185
-20
lines changed

4 files changed

+185
-20
lines changed

pssh/base_pssh.py

Lines changed: 55 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def get_exit_code(self, host_output):
201201
return
202202
return self._get_exit_code(host_output.channel)
203203

204-
def copy_file(self, local_file, remote_file, recurse=False):
204+
def copy_file(self, local_file, remote_file, recurse=False, copy_args=None):
205205
"""Copy local file to remote file in parallel
206206
207207
This function returns a list of greenlets which can be
@@ -224,11 +224,19 @@ def copy_file(self, local_file, remote_file, recurse=False):
224224
:type remote_file: str
225225
:param recurse: Whether or not to descend into directories recursively.
226226
:type recurse: bool
227+
:param copy_args: (Optional) format local_file and remote_file strings
228+
with per-host arguments in ``copy_args``. ``copy_args`` length must
229+
equal length of host list -
230+
:py:class:`pssh.exceptions.HostArgumentException` is raised otherwise
231+
:type copy_args: tuple or list
232+
227233
:rtype: List(:py:class:`gevent.Greenlet`) of greenlets for remote copy
228234
commands
229235
230236
:raises: :py:class:`ValueError` when a directory is supplied to
231237
local_file and recurse is not set
238+
:raises: :py:class:`pssh.exceptions.HostArgumentException` on number of
239+
per-host copy arguments not equal to number of hosts
232240
:raises: :py:class:`IOError` on I/O errors writing files
233241
:raises: :py:class:`OSError` on OS errors like permission denied
234242
@@ -238,9 +246,21 @@ def copy_file(self, local_file, remote_file, recurse=False):
238246
created as long as permissions allow.
239247
240248
"""
241-
return [self.pool.spawn(self._copy_file, host, local_file, remote_file,
242-
{'recurse': recurse})
243-
for host in self.hosts]
249+
if copy_args:
250+
try:
251+
return [self.pool.spawn(self._copy_file, host,
252+
local_file % copy_args[host_i],
253+
remote_file % copy_args[host_i],
254+
{'recurse': recurse})
255+
for host_i, host in enumerate(self.hosts)]
256+
except IndexError:
257+
raise HostArgumentException(
258+
"Number of per-host copy arguments provided does not match "
259+
"number of hosts")
260+
else:
261+
return [self.pool.spawn(self._copy_file, host, local_file,
262+
remote_file, {'recurse': recurse})
263+
for host in self.hosts]
244264

245265
def _copy_file(self, host, local_file, remote_file, recurse=False):
246266
"""Make sftp client, copy file"""
@@ -249,7 +269,7 @@ def _copy_file(self, host, local_file, remote_file, recurse=False):
249269
recurse=recurse)
250270

251271
def copy_remote_file(self, remote_file, local_file, recurse=False,
252-
suffix_separator='_', **kwargs):
272+
suffix_separator='_', copy_args=None, **kwargs):
253273
"""Copy remote file(s) in parallel as
254274
<local_file><suffix_separator><host>
255275
@@ -281,13 +301,22 @@ def copy_remote_file(self, remote_file, local_file, recurse=False,
281301
filename and host, defaults to ``_``. For example, for a
282302
``local_file`` value of ``myfile`` and default separator the
283303
resulting filename will be ``myfile_myhost`` for the file from
284-
host ``myhost``
304+
host ``myhost``. ``suffix_separator`` has no meaning if
305+
``copy_args`` is provided
285306
:type suffix_separator: str
307+
:param copy_args: (Optional) format remote_file and local_file strings
308+
with per-host arguments in ``copy_args``. ``copy_args`` length must
309+
equal length of host list -
310+
:py:class:`pssh.exceptions.HostArgumentException` is raised otherwise
311+
:type copy_args: tuple or list
312+
286313
:rtype: list(:py:class:`gevent.Greenlet`) of greenlets for remote copy
287314
commands
288315
289316
:raises: :py:class:`ValueError` when a directory is supplied to
290317
local_file and recurse is not set
318+
:raises: :py:class:`pssh.exceptions.HostArgumentException` on number of
319+
per-host copy arguments not equal to number of hosts
291320
:raises: :py:class:`IOError` on I/O errors writing files
292321
:raises: :py:class:`OSError` on OS errors like permission denied
293322
@@ -300,17 +329,29 @@ def copy_remote_file(self, remote_file, local_file, recurse=False,
300329
filepath separated by ``suffix_separator``.
301330
302331
"""
303-
return [self.pool.spawn(
304-
self._copy_remote_file, host, remote_file,
305-
local_file, recurse, suffix_separator=suffix_separator,
306-
**kwargs)
307-
for host in self.hosts]
332+
if copy_args:
333+
try:
334+
return [self.pool.spawn(
335+
self._copy_remote_file, host,
336+
remote_file % copy_args[host_i],
337+
local_file % copy_args[host_i], {'recurse': recurse},
338+
**kwargs)
339+
for host_i, host in enumerate(self.hosts)]
340+
except IndexError:
341+
raise HostArgumentException(
342+
"Number of per-host copy arguments provided does not match "
343+
"number of hosts")
344+
else:
345+
return [self.pool.spawn(
346+
self._copy_remote_file, host, remote_file,
347+
suffix_separator.join([local_file, host]), recurse,
348+
**kwargs)
349+
for host in self.hosts]
308350

309351
def _copy_remote_file(self, host, remote_file, local_file, recurse,
310-
suffix_separator='_', **kwargs):
352+
**kwargs):
311353
"""Make sftp client, copy file to local"""
312-
file_w_suffix = suffix_separator.join([local_file, host])
313354
self._make_ssh_client(host)
314355
return self.host_clients[host].copy_remote_file(
315-
remote_file, file_w_suffix, recurse=recurse,
356+
remote_file, local_file, recurse=recurse,
316357
**kwargs)

pssh/exceptions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class SSHException(Exception):
4040

4141

4242
class HostArgumentException(Exception):
43-
"""Raised on errors with per-host command arguments"""
43+
"""Raised on errors with per-host arguments to parallel functions"""
4444
pass
4545

4646

pssh/pssh2_client.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ def _make_ssh_client(self, host):
312312
num_retries=self.num_retries, timeout=self.timeout,
313313
allow_agent=self.allow_agent, retry_delay=self.retry_delay)
314314

315-
def copy_file(self, local_file, remote_file, recurse=False):
315+
def copy_file(self, local_file, remote_file, recurse=False, copy_args=None):
316316
"""Copy local file to remote file in parallel
317317
318318
This function returns a list of greenlets which can be
@@ -335,12 +335,19 @@ def copy_file(self, local_file, remote_file, recurse=False):
335335
:type remote_file: str
336336
:param recurse: Whether or not to descend into directories recursively.
337337
:type recurse: bool
338+
:param copy_args: (Optional) format local_file and remote_file strings
339+
with per-host arguments in ``copy_args``. ``copy_args`` length must
340+
equal length of host list -
341+
:py:class:`pssh.exceptions.HostArgumentException` is raised otherwise
342+
:type copy_args: tuple or list
338343
339344
:rtype: list(:py:class:`gevent.Greenlet`) of greenlets for remote copy
340345
commands
341346
342347
:raises: :py:class:`ValueError` when a directory is supplied to
343348
local_file and recurse is not set
349+
:raises: :py:class:`pssh.exceptions.HostArgumentException` on number of
350+
per-host copy arguments not equal to number of hosts
344351
:raises: :py:class:`pss.exceptions.SFTPError` on SFTP initialisation
345352
errors
346353
:raises: :py:class:`pssh.exceptions.SFTPIOError` on I/O errors writing
@@ -354,10 +361,11 @@ def copy_file(self, local_file, remote_file, recurse=False):
354361
355362
"""
356363
return BaseParallelSSHClient.copy_file(
357-
self, local_file, remote_file, recurse=recurse)
364+
self, local_file, remote_file, recurse=recurse, copy_args=copy_args)
358365

359366
def copy_remote_file(self, remote_file, local_file, recurse=False,
360-
suffix_separator='_', encoding='utf-8'):
367+
suffix_separator='_', copy_args=None,
368+
encoding='utf-8'):
361369
"""Copy remote file(s) in parallel as
362370
<local_file><suffix_separator><host>
363371
@@ -389,8 +397,14 @@ def copy_remote_file(self, remote_file, local_file, recurse=False,
389397
filename and host, defaults to ``_``. For example, for a
390398
``local_file`` value of ``myfile`` and default separator the
391399
resulting filename will be ``myfile_myhost`` for the file from
392-
host ``myhost``
400+
host ``myhost``. ``suffix_separator`` has no meaning if
401+
``copy_args`` is provided
393402
:type suffix_separator: str
403+
:param copy_args: (Optional) format remote_file and local_file strings
404+
with per-host arguments in ``copy_args``. ``copy_args`` length must
405+
equal length of host list -
406+
:py:class:`pssh.exceptions.HostArgumentException` is raised otherwise
407+
:type copy_args: tuple or list
394408
:param encoding: Encoding to use for file paths.
395409
:type encoding: str
396410
@@ -399,6 +413,8 @@ def copy_remote_file(self, remote_file, local_file, recurse=False,
399413
400414
:raises: :py:class:`ValueError` when a directory is supplied to
401415
local_file and recurse is not set
416+
:raises: :py:class:`pssh.exceptions.HostArgumentException` on number of
417+
per-host copy arguments not equal to number of hosts
402418
:raises: :py:class:`pss.exceptions.SFTPError` on SFTP initialisation
403419
errors
404420
:raises: :py:class:`pssh.exceptions.SFTPIOError` on I/O errors reading
@@ -416,4 +432,5 @@ def copy_remote_file(self, remote_file, local_file, recurse=False,
416432
"""
417433
return BaseParallelSSHClient.copy_remote_file(
418434
self, remote_file, local_file, recurse=recurse,
419-
suffix_separator=suffix_separator, encoding=encoding)
435+
suffix_separator=suffix_separator, copy_args=copy_args,
436+
encoding=encoding)

tests/test_pssh_ssh2_client.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,60 @@ def test_pssh_copy_file(self):
370370
except OSError:
371371
pass
372372

373+
def test_pssh_copy_file_per_host_args(self):
374+
"""Test parallel copy file with per-host arguments"""
375+
host2, host3 = '127.0.0.6', '127.0.0.7'
376+
server2 = OpenSSHServer(host2)
377+
server3 = OpenSSHServer(host3)
378+
servers = [server2, server3]
379+
for server in servers:
380+
server.start_server()
381+
time.sleep(1)
382+
hosts = [self.host, host2, host3]
383+
384+
local_file_prefix = 'test_file_'
385+
remote_file_prefix = 'test_remote_'
386+
387+
copy_args = [dict(zip(('local_file', 'remote_file',),
388+
(local_file_prefix + str(i + 1),
389+
remote_file_prefix + str(i + 1),)
390+
))
391+
for i, _ in enumerate(hosts)]
392+
393+
test_file_data = 'test'
394+
for i, _ in enumerate(hosts):
395+
test_file = open(local_file_prefix + str(i + 1), 'w')
396+
test_file.writelines([test_file_data + os.linesep])
397+
test_file.close()
398+
399+
client = ParallelSSHClient(hosts, port=self.port, pkey=self.user_key,
400+
num_retries=2)
401+
greenlets = client.copy_file('%(local_file)s', '%(remote_file)s',
402+
copy_args=copy_args)
403+
gevent.joinall(greenlets)
404+
405+
self.assertRaises(HostArgumentException, client.copy_file,
406+
'%(local_file)s', '%(remote_file)s',
407+
copy_args=[copy_args[0]])
408+
try:
409+
for i, _ in enumerate(hosts):
410+
remote_file_abspath = os.path.expanduser(
411+
'~/' + remote_file_prefix + str(i + 1))
412+
self.assertTrue(os.path.isfile(remote_file_abspath))
413+
remote_file_data = open(
414+
remote_file_abspath, 'r').readlines()
415+
self.assertEqual(
416+
remote_file_data[0].strip(), test_file_data)
417+
except Exception:
418+
raise
419+
finally:
420+
for i, _ in enumerate(hosts):
421+
remote_file_abspath = os.path.expanduser(
422+
'~/' + remote_file_prefix + str(i + 1))
423+
local_file_path = local_file_prefix + str(i + 1)
424+
os.unlink(remote_file_abspath)
425+
os.unlink(local_file_path)
426+
373427
def test_pssh_client_directory_relative_path(self):
374428
"""Tests copying multiple directories with SSH client. Copy all the files from
375429
local directory to server, then make sure they are all present."""
@@ -581,6 +635,59 @@ def test_pssh_copy_remote_file(self):
581635
shutil.rmtree(new_local_copied_dir)
582636
shutil.rmtree(remote_test_path_abs)
583637

638+
def test_pssh_copy_remote_file_per_host_args(self):
639+
"""Test parallel remote copy file with per-host arguments"""
640+
host2, host3 = '127.0.0.10', '127.0.0.11'
641+
server2 = OpenSSHServer(host2)
642+
server3 = OpenSSHServer(host3)
643+
servers = [server2, server3]
644+
for server in servers:
645+
server.start_server()
646+
time.sleep(1)
647+
hosts = [self.host, host2, host3]
648+
649+
remote_file_prefix = 'test_file_'
650+
local_file_prefix = 'test_local_'
651+
652+
copy_args = [dict(zip(('remote_file', 'local_file',),
653+
(remote_file_prefix + str(i + 1),
654+
local_file_prefix + str(i + 1),)
655+
))
656+
for i, _ in enumerate(hosts)]
657+
658+
test_file_data = 'test'
659+
for i, _ in enumerate(hosts):
660+
remote_file_abspath = os.path.expanduser(
661+
'~/' + remote_file_prefix + str(i + 1))
662+
test_file = open(remote_file_abspath, 'w')
663+
test_file.writelines([test_file_data + os.linesep])
664+
test_file.close()
665+
666+
client = ParallelSSHClient(hosts, port=self.port, pkey=self.user_key,
667+
num_retries=2)
668+
greenlets = client.copy_remote_file('%(remote_file)s', '%(local_file)s',
669+
copy_args=copy_args)
670+
gevent.joinall(greenlets)
671+
672+
self.assertRaises(HostArgumentException, client.copy_remote_file,
673+
'%(remote_file)s', '%(local_file)s',
674+
copy_args=[copy_args[0]])
675+
try:
676+
for i, _ in enumerate(hosts):
677+
local_file_path = local_file_prefix + str(i + 1)
678+
self.assertTrue(os.path.isfile(local_file_path))
679+
local_file_data = open(local_file_path, 'r').readlines()
680+
self.assertEqual(local_file_data[0].strip(), test_file_data)
681+
except Exception:
682+
raise
683+
finally:
684+
for i, _ in enumerate(hosts):
685+
remote_file_abspath = os.path.expanduser(
686+
'~/' + remote_file_prefix + str(i + 1))
687+
local_file_path = local_file_prefix + str(i + 1)
688+
os.unlink(remote_file_abspath)
689+
os.unlink(local_file_path)
690+
584691
def test_pssh_pool_size(self):
585692
"""Test setting pool size to non default values"""
586693
hosts = ['host-%01d' % d for d in range(5)]

0 commit comments

Comments
 (0)