2121from warnings import warn
2222
2323from gevent import sleep , spawn , get_hub
24- from gevent .select import POLLIN , POLLOUT
2524from ssh2 .error_codes import LIBSSH2_ERROR_EAGAIN
2625from ssh2 .exceptions import SFTPHandleError , SFTPProtocolError , \
2726 Timeout as SSH2Timeout
@@ -163,11 +162,14 @@ def _connect_proxy(self, proxy_host, proxy_port, proxy_pkey,
163162 return proxy_local_port
164163
165164 def disconnect (self ):
166- """Disconnect session, close socket if needed."""
165+ """Attempt to disconnect session.
166+
167+ Any errors on calling disconnect are suppressed by this function.
168+ """
167169 self ._keepalive_greenlet = None
168170 if self .session is not None :
169171 try :
170- self ._eagain ( self . session . disconnect )
172+ self ._disconnect_eagain ( )
171173 except Exception :
172174 pass
173175 self .session = None
@@ -316,10 +318,13 @@ def close_channel(self, channel):
316318 def _eagain (self , func , * args , ** kwargs ):
317319 return self ._eagain_errcode (func , LIBSSH2_ERROR_EAGAIN , * args , ** kwargs )
318320
321+ def _make_sftp_eagain (self ):
322+ return self ._eagain (self .session .sftp_init )
323+
319324 def _make_sftp (self ):
320325 """Make SFTP client from open transport"""
321326 try :
322- sftp = self ._eagain ( self . session . sftp_init )
327+ sftp = self ._make_sftp_eagain ( )
323328 except Exception as ex :
324329 raise SFTPError (ex )
325330 return sftp
@@ -486,6 +491,27 @@ def copy_remote_file(self, remote_file, local_file, recurse=False,
486491 logger .info ("Copied local file %s from remote destination %s:%s" ,
487492 local_file , self .host , remote_file )
488493
494+ def _scp_recv_recursive (self , remote_file , local_file , sftp , encoding = 'utf-8' ):
495+ try :
496+ self ._eagain (sftp .stat , remote_file )
497+ except (SFTPHandleError , SFTPProtocolError ):
498+ msg = "Remote file or directory %s does not exist"
499+ logger .error (msg , remote_file )
500+ raise SCPError (msg , remote_file )
501+ try :
502+ dir_h = self ._sftp_openfh (sftp .opendir , remote_file )
503+ except SFTPError :
504+ # remote_file is not a dir, scp file
505+ return self .scp_recv (remote_file , local_file , encoding = encoding )
506+ try :
507+ os .makedirs (local_file )
508+ except OSError :
509+ pass
510+ file_list = self ._sftp_readdir (dir_h )
511+ return self ._scp_recv_dir (file_list , remote_file ,
512+ local_file , sftp ,
513+ encoding = encoding )
514+
489515 def scp_recv (self , remote_file , local_file , recurse = False , sftp = None ,
490516 encoding = 'utf-8' ):
491517 """Copy remote file to local host via SCP.
@@ -505,33 +531,13 @@ def scp_recv(self, remote_file, local_file, recurse=False, sftp=None,
505531 enabled.
506532 :type encoding: str
507533
508- :raises: :py:class:`pssh.exceptions.SCPError` when a directory is
509- supplied to ``local_file`` and ``recurse`` is not set.
510534 :raises: :py:class:`pssh.exceptions.SCPError` on errors copying file.
511535 :raises: :py:class:`IOError` on local file IO errors.
512536 :raises: :py:class:`OSError` on local OS errors like permission denied.
513537 """
514538 if recurse :
515539 sftp = self ._make_sftp () if sftp is None else sftp
516- try :
517- self ._eagain (sftp .stat , remote_file )
518- except (SFTPHandleError , SFTPProtocolError ):
519- msg = "Remote file or directory %s does not exist"
520- logger .error (msg , remote_file )
521- raise SCPError (msg , remote_file )
522- try :
523- dir_h = self ._sftp_openfh (sftp .opendir , remote_file )
524- except SFTPError :
525- pass
526- else :
527- try :
528- os .makedirs (local_file )
529- except OSError :
530- pass
531- file_list = self ._sftp_readdir (dir_h )
532- return self ._scp_recv_dir (file_list , remote_file ,
533- local_file , sftp ,
534- encoding = encoding )
540+ return self ._scp_recv_recursive (remote_file , local_file , sftp , encoding = encoding )
535541 elif local_file .endswith ('/' ):
536542 remote_filename = remote_file .rsplit ('/' )[- 1 ]
537543 local_file += remote_filename
@@ -561,11 +567,6 @@ def _scp_recv(self, remote_file, local_file):
561567 continue
562568 total += size
563569 local_fh .write (data )
564- if total != fileinfo .st_size :
565- msg = "Error copying data from remote file %s on host %s. " \
566- "Copied %s out of %s total bytes"
567- raise SCPError (msg , remote_file , self .host , total ,
568- fileinfo .st_size )
569570 finally :
570571 local_fh .close ()
571572 file_chan .close ()
@@ -690,16 +691,12 @@ def poll(self, timeout=None):
690691 Blocks current greenlet only if socket has pending read or write operations
691692 in the appropriate direction.
692693 """
693- timeout = self .timeout if timeout is None else timeout
694- directions = self .session .block_directions ()
695- if directions == 0 :
696- return
697- events = 0
698- if directions & LIBSSH2_SESSION_BLOCK_INBOUND :
699- events = POLLIN
700- if directions & LIBSSH2_SESSION_BLOCK_OUTBOUND :
701- events |= POLLOUT
702- self ._poll_socket (events , timeout = timeout )
694+ self ._poll_errcodes (
695+ self .session .block_directions ,
696+ LIBSSH2_SESSION_BLOCK_INBOUND ,
697+ LIBSSH2_SESSION_BLOCK_OUTBOUND ,
698+ timeout = timeout ,
699+ )
703700
704701 def _eagain_write (self , write_func , data , timeout = None ):
705702 """Write data with given write_func for an ssh2-python session while
0 commit comments