Skip to content

Commit 0bcb460

Browse files
Danpkittenis
authored andcommitted
Cleanups. Updated docstrings. Made pull file from remote test use sub-directories in the tree. Test cleanups.
1 parent e72b1f3 commit 0bcb460

File tree

3 files changed

+55
-31
lines changed

3 files changed

+55
-31
lines changed

pssh/pssh_client.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -845,18 +845,9 @@ def copy_file(self, local_file, remote_file, recurse=False):
845845

846846
def _copy_file(self, host, local_file, remote_file, recurse):
847847
"""Make sftp client, copy file"""
848-
if not host in self.host_clients or not self.host_clients[host]:
849-
_user, _port, _password, _pkey = self._get_host_config_values(host)
850-
self.host_clients[host] = SSHClient(
851-
host, user=_user, password=_password, port=_port, pkey=_pkey,
852-
forward_ssh_agent=self.forward_ssh_agent,
853-
num_retries=self.num_retries,
854-
timeout=self.timeout,
855-
proxy_host=self.proxy_host,
856-
proxy_port=self.proxy_port,
857-
agent=self.agent,
858-
channel_timeout=self.channel_timeout)
859-
return self.host_clients[host].copy_file(local_file, remote_file, recurse=recurse)
848+
self._make_ssh_client(host)
849+
return self.host_clients[host].copy_file(local_file, remote_file,
850+
recurse=recurse)
860851

861852
def copy_file_to_local(self, remote_file, local_file, recurse=False):
862853
"""Copy remote file to local file in parallel
@@ -883,9 +874,19 @@ def copy_file_to_local(self, remote_file, local_file, recurse=False):
883874

884875
def _copy_file_to_local(self, host, remote_file, local_file, recurse):
885876
"""Make sftp client, copy file to local"""
886-
if not self.host_clients[host]:
887-
self.host_clients[host] = SSHClient(host, user=self.user,
888-
password=self.password,
889-
port=self.port, pkey=self.pkey,
890-
forward_ssh_agent=self.forward_ssh_agent)
891-
return self.host_clients[host].copy_file_to_local(remote_file, '_'.join([local_file, host]), recurse=recurse)
877+
self._make_ssh_client(host)
878+
return self.host_clients[host].copy_file_to_local(
879+
remote_file, '_'.join([local_file, host]), recurse=recurse)
880+
881+
def _make_ssh_client(self, host):
882+
if not host in self.host_clients or not self.host_clients[host]:
883+
_user, _port, _password, _pkey = self._get_host_config_values(host)
884+
self.host_clients[host] = SSHClient(
885+
host, user=_user, password=_password, port=_port, pkey=_pkey,
886+
forward_ssh_agent=self.forward_ssh_agent,
887+
num_retries=self.num_retries,
888+
timeout=self.timeout,
889+
proxy_host=self.proxy_host,
890+
proxy_port=self.proxy_port,
891+
agent=self.agent,
892+
channel_timeout=self.channel_timeout)

pssh/ssh_client.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -397,38 +397,46 @@ def copy_file_to_local(self, remote_file, local_file, recurse=False):
397397
:param recurse: Whether or not to recursively copy directories.
398398
:type recurse: bool
399399
400-
:raises: :mod:'ValueError' when a directory is supplied to remote_file \
400+
:raises: :mod:`ValueError` when a directory is supplied to remote_file \
401401
and recurse is not set
402+
:raises: :mod:`OSError` on OS errors creating directories or file
403+
:raises: :mod:`IOError` on IO errors creating directories or file
402404
"""
403405
sftp = self._make_sftp()
404406
try:
405407
sftp.listdir(remote_file)
406-
remote_dir_exists = True
407-
except IOError or OSError:
408+
except (OSError, IOError):
408409
remote_dir_exists = False
410+
else:
411+
remote_dir_exists = True
409412
if remote_dir_exists and recurse:
410413
return self._copy_dir_to_local(remote_file, local_file)
411414
elif remote_dir_exists and not recurse:
412415
raise ValueError("Recurse must be true if remote_file is a "
413416
"directory.")
414417
destination = self._parent_path_split(local_file)
415-
if not os.path.exists(destination):
416-
try:
417-
os.makedirs(destination)
418-
except OSError:
419-
logger.error("Unable to create local directory structure.")
420-
raise
418+
self._make_local_dir(destination)
421419
try:
420+
import ipdb; ipdb.set_trace()
422421
sftp.get(remote_file, local_file)
423422
except Exception, error:
424423
logger.error("Error occured copying file %s from remote destination %s:%s - %s",
425424
local_file, self.host, remote_file, error)
425+
raise error
426426
else:
427427
logger.info("Copied local file %s from remote destination %s:%s",
428428
local_file, self.host, remote_file)
429429

430-
@staticmethod
431-
def _parent_path_split(file_path):
430+
def _make_local_dir(self, dirpath):
431+
if not os.path.exists(dirpath):
432+
try:
433+
os.makedirs(dirpath)
434+
except OSError:
435+
logger.error("Unable to create local directory structure for "
436+
"directory %s", dirpath)
437+
raise
438+
439+
def _parent_path_split(self, file_path):
432440
try:
433441
destination = [_dir for _dir in file_path.split(os.path.sep)
434442
if _dir][:-1][0]

tests/test_ssh_client.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,11 @@ def test_ssh_client_copy_remote_directory(self):
185185
"""Tests copying a remote directory to the localhost"""
186186
remote_test_directory = 'remote_test_dir'
187187
local_test_directory = 'local_test_dir'
188+
for path in [local_test_path, remote_test_path]:
189+
try:
190+
shutil.rmtree(path)
191+
except OSError:
192+
pass
188193
os.mkdir(remote_test_directory)
189194
test_files = []
190195
for i in range(0, 10):
@@ -263,11 +268,17 @@ def test_ssh_client_sftp_from_remote_directory(self):
263268
test_file_data = 'test'
264269
remote_test_path = 'directory_test_remote'
265270
local_test_path = 'directory_test_local'
271+
for path in [local_test_path, remote_test_path]:
272+
try:
273+
shutil.rmtree(path)
274+
except OSError:
275+
pass
266276
os.mkdir(remote_test_path)
277+
os.mkdir(os.path.join(remote_test_path, 'subdir'))
267278
local_file_paths = []
268279
for i in range(0, 10):
269-
remote_file_path = os.path.join(remote_test_path, 'foo' + str(i))
270-
local_file_path = os.path.join(local_test_path, 'foo' + str(i))
280+
remote_file_path = os.path.join(remote_test_path, 'subdir', 'foo' + str(i))
281+
local_file_path = os.path.join(local_test_path, 'subdir', 'foo' + str(i))
271282
local_file_paths.append(local_file_path)
272283
test_file = open(remote_file_path, 'w')
273284
test_file.write(test_file_data)
@@ -286,6 +297,10 @@ def test_ssh_client_remote_directory_no_recurse(self):
286297
test_file_data = 'test'
287298
remote_test_path = 'directory_test'
288299
local_test_path = 'directory_test_copied'
300+
try:
301+
shutil.rmtree(remote_test_path)
302+
except OSError:
303+
pass
289304
os.mkdir(remote_test_path)
290305
local_file_paths = []
291306
for i in range(0, 10):

0 commit comments

Comments
 (0)