Skip to content

Commit bc81728

Browse files
committed
Merge pull request #45 from Caid11/sftp_recursive
Allow recursive copy via SFTP on `copy_file` function - resolves #21
2 parents 9720304 + 8df29c9 commit bc81728

File tree

2 files changed

+64
-1
lines changed

2 files changed

+64
-1
lines changed

pssh/ssh_client.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,16 @@ def mkdir(self, sftp, directory):
287287
return self.mkdir(sftp, sub_dirs)
288288
return True
289289

290-
def copy_file(self, local_file, remote_file):
290+
def _copy_dir(self, local_dir, remote_dir):
291+
"""Call copy_file on every file in the specified directory, copying
292+
them to the specified remote directory."""
293+
file_list = os.listdir(local_dir)
294+
for file_name in file_list:
295+
local_path = os.path.join(local_dir, file_name)
296+
remote_path = os.path.join(remote_dir, file_name)
297+
self.copy_file(local_path, remote_path, recurse=True)
298+
299+
def copy_file(self, local_file, remote_file, recurse=False):
291300
"""Copy local file to host via SFTP/SCP
292301
293302
Copy is done natively using SFTP/SCP version 2 protocol, no scp command \
@@ -297,7 +306,17 @@ def copy_file(self, local_file, remote_file):
297306
:type local_file: str
298307
:param remote_file: Remote filepath on remote host to copy file to
299308
:type remote_file: str
309+
:param recurse: Whether or not to descend into directories recursively.
310+
:type recurse: bool
311+
312+
:raises: :mod:'ValueError' when a directory is supplied to local_file \
313+
and recurse is not set
300314
"""
315+
if os.path.isdir(local_file) and recurse:
316+
return self._copy_dir(local_file, remote_file)
317+
elif os.path.isdir(local_file) and not recurse:
318+
raise ValueError("Recurse must be true if local_file is a "
319+
"directory.")
301320
sftp = self._make_sftp()
302321
destination = [_dir for _dir in remote_file.split(os.path.sep)
303322
if _dir][:-1][0]

tests/test_ssh_client.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import gevent
2424
import socket
2525
import time
26+
import shutil
2627
import unittest
2728
from pssh import SSHClient, ParallelSSHClient, UnknownHostException, AuthenticationException,\
2829
logger, ConnectionErrorException, UnknownHostException, SSHException
@@ -141,6 +142,49 @@ def test_ssh_client_sftp(self):
141142
os.rmdir(dirpath)
142143
del client
143144

145+
def test_ssh_client_directory(self):
146+
"""Tests copying directories with SSH client. Copy all the files from
147+
local directory to server, then make sure they are all present."""
148+
test_file_data = 'test'
149+
local_test_path = 'directory_test'
150+
remote_test_path = 'directory_test_copied'
151+
os.mkdir(local_test_path)
152+
remote_file_paths = []
153+
for i in range(0, 10):
154+
local_file_path = os.path.join(local_test_path, 'foo' + str(i))
155+
remote_file_path = os.path.join(remote_test_path, 'foo' + str(i))
156+
remote_file_paths.append(remote_file_path)
157+
test_file = open(local_file_path, 'w')
158+
test_file.write(test_file_data)
159+
test_file.close()
160+
client = SSHClient(self.host, port=self.listen_port,
161+
pkey=self.user_key)
162+
client.copy_file(local_test_path, remote_test_path, recurse=True)
163+
for path in remote_file_paths:
164+
self.assertTrue(os.path.isfile(path))
165+
shutil.rmtree(local_test_path)
166+
shutil.rmtree(remote_test_path)
167+
168+
def test_ssh_client_directory_no_recurse(self):
169+
"""Tests copying directories with SSH client. Copy all the files from
170+
local directory to server, then make sure they are all present."""
171+
test_file_data = 'test'
172+
local_test_path = 'directory_test'
173+
remote_test_path = 'directory_test_copied'
174+
os.mkdir(local_test_path)
175+
remote_file_paths = []
176+
for i in range(0, 10):
177+
local_file_path = os.path.join(local_test_path, 'foo' + str(i))
178+
remote_file_path = os.path.join(remote_test_path, 'foo' + str(i))
179+
remote_file_paths.append(remote_file_path)
180+
test_file = open(local_file_path, 'w')
181+
test_file.write(test_file_data)
182+
test_file.close()
183+
client = SSHClient(self.host, port=self.listen_port,
184+
pkey=self.user_key)
185+
self.assertRaises(ValueError, client.copy_file, local_test_path, remote_test_path)
186+
shutil.rmtree(local_test_path)
187+
144188
def test_ssh_agent_authentication(self):
145189
"""Test authentication via SSH agent.
146190
Do not provide public key to use when creating SSHClient,

0 commit comments

Comments
 (0)