|
16 | 16 | # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA |
17 | 17 |
|
18 | 18 | import os |
| 19 | +import pytest |
19 | 20 | import shutil |
20 | 21 | import subprocess |
21 | 22 | import tempfile |
22 | 23 | from datetime import datetime |
23 | | -from hashlib import sha256 |
24 | | -from tempfile import NamedTemporaryFile |
25 | | -from unittest.mock import MagicMock, call, patch |
26 | | - |
27 | | -import pytest |
28 | 24 | from gevent import sleep, spawn, Timeout as GTimeout, socket |
| 25 | +from hashlib import sha256 |
29 | 26 | from pytest import raises |
30 | 27 | from ssh2.exceptions import (SocketDisconnectError, BannerRecvError, SocketRecvError, |
31 | 28 | AgentConnectionError, AgentListIdentitiesError, |
32 | 29 | AgentAuthenticationError, AgentGetIdentityError, SFTPProtocolError, |
33 | 30 | AuthenticationError as SSH2AuthenticationError, |
34 | 31 | ) |
35 | 32 | from ssh2.session import Session |
| 33 | +from tempfile import NamedTemporaryFile |
| 34 | +from unittest.mock import MagicMock, call, patch |
36 | 35 |
|
37 | 36 | from pssh.clients.native import SSHClient |
38 | 37 | from pssh.exceptions import (AuthenticationException, ConnectionErrorException, |
@@ -272,16 +271,14 @@ class _SSHClient(SSHClient): |
272 | 271 | pkey=self.user_key, |
273 | 272 | num_retries=1, |
274 | 273 | allow_agent=False) |
275 | | - client.disconnect() |
276 | 274 | client.pkey = None |
277 | | - del client.session |
278 | | - del client.sock |
| 275 | + client._disconnect() |
279 | 276 | client._connect(self.host, self.port) |
280 | 277 | client._init_session() |
281 | 278 | client.IDENTITIES = (self.user_key,) |
282 | 279 | # Default identities auth only should succeed |
283 | 280 | client._identity_auth() |
284 | | - client.disconnect() |
| 281 | + client._disconnect() |
285 | 282 | client._connect(self.host, self.port) |
286 | 283 | client._init_session() |
287 | 284 | # Auth should succeed |
@@ -360,9 +357,37 @@ def test_handshake_fail(self): |
360 | 357 | client = SSHClient(self.host, port=self.port, |
361 | 358 | pkey=self.user_key, |
362 | 359 | num_retries=1) |
363 | | - client.session.disconnect() |
| 360 | + client.eagain(client.session.disconnect) |
364 | 361 | self.assertRaises((SocketDisconnectError, BannerRecvError, SocketRecvError), client._init_session) |
365 | 362 |
|
| 363 | + @patch('gevent.socket.socket') |
| 364 | + @patch('pssh.clients.native.single.Session') |
| 365 | + def test_sock_shutdown_fail(self, mock_sess, mock_sock): |
| 366 | + sess = MagicMock() |
| 367 | + sock = MagicMock() |
| 368 | + mock_sess.return_value = sess |
| 369 | + mock_sock.return_value = sock |
| 370 | + |
| 371 | + hand_mock = MagicMock() |
| 372 | + sess.handshake = hand_mock |
| 373 | + retries = 2 |
| 374 | + client = SSHClient(self.host, port=self.port, |
| 375 | + num_retries=retries, |
| 376 | + timeout=.1, |
| 377 | + retry_delay=.1, |
| 378 | + _auth_thread_pool=False, |
| 379 | + allow_agent=False, |
| 380 | + ) |
| 381 | + self.assertIsInstance(client, SSHClient) |
| 382 | + hand_mock.side_effect = AuthenticationError |
| 383 | + sock.closed = False |
| 384 | + sock.detach = MagicMock() |
| 385 | + sock.detach.side_effect = Exception |
| 386 | + self.assertRaises(AuthenticationError, client._init_session) |
| 387 | + self.assertEqual(sock.detach.call_count, retries) |
| 388 | + client._disconnect() |
| 389 | + self.assertIsNone(client.sock) |
| 390 | + |
366 | 391 | def test_stdout_parsing(self): |
367 | 392 | dir_list = os.listdir(os.path.expanduser('~')) |
368 | 393 | host_out = self.client.run_command('ls -la') |
@@ -438,15 +463,30 @@ def test_multiple_clients_exec_terminates_channels(self): |
438 | 463 | # and break subsequent sessions even on different socket and |
439 | 464 | # session |
440 | 465 | def scope_killer(): |
441 | | - for _ in range(5): |
| 466 | + for _ in range(20): |
442 | 467 | client = SSHClient(self.host, port=self.port, |
443 | 468 | pkey=self.user_key, |
444 | 469 | num_retries=1, |
445 | 470 | allow_agent=False) |
446 | 471 | host_out = client.run_command(self.cmd) |
447 | 472 | output = list(host_out.stdout) |
448 | 473 | self.assertListEqual(output, [self.resp]) |
449 | | - client.disconnect() |
| 474 | + |
| 475 | + scope_killer() |
| 476 | + |
| 477 | + def test_multiple_clients_exec_terminates_channels_explicit_disc(self): |
| 478 | + # Explicit disconnects should not affect subsequent connections |
| 479 | + def scope_killer(): |
| 480 | + for _ in range(20): |
| 481 | + client = SSHClient(self.host, port=self.port, |
| 482 | + pkey=self.user_key, |
| 483 | + num_retries=1, |
| 484 | + allow_agent=False) |
| 485 | + host_out = client.run_command(self.cmd) |
| 486 | + output = list(host_out.stdout) |
| 487 | + self.assertListEqual(output, [self.resp]) |
| 488 | + client._disconnect() |
| 489 | + |
450 | 490 | scope_killer() |
451 | 491 |
|
452 | 492 | def test_agent_auth_exceptions(self): |
@@ -1036,7 +1076,8 @@ def _make_sftp(): |
1036 | 1076 | client._make_sftp_eagain = _make_sftp |
1037 | 1077 | self.assertRaises(SFTPError, client._make_sftp) |
1038 | 1078 |
|
1039 | | - def test_disconnect_exc(self): |
| 1079 | + @patch('pssh.clients.native.single.Session') |
| 1080 | + def test_disconnect_exc(self, mock_sess): |
1040 | 1081 | class DiscError(Exception): |
1041 | 1082 | pass |
1042 | 1083 |
|
@@ -1091,5 +1132,36 @@ def test_many_short_lived_commands(self): |
1091 | 1132 | duration = end.total_seconds() |
1092 | 1133 | self.assertTrue(duration < timeout * 0.9, msg=f"Duration of instant cmd is {duration}") |
1093 | 1134 |
|
1094 | | - # TODO |
1095 | | - # * read output callback |
| 1135 | + def test_output_client_scope(self): |
| 1136 | + """Output objects should keep client alive while they are in scope even if client is not.""" |
| 1137 | + def make_client_run(): |
| 1138 | + client = SSHClient(self.host, port=self.port, |
| 1139 | + pkey=self.user_key, |
| 1140 | + num_retries=1, |
| 1141 | + allow_agent=False, |
| 1142 | + ) |
| 1143 | + host_out = client.run_command("%s; exit 1" % (self.cmd,)) |
| 1144 | + return host_out |
| 1145 | + |
| 1146 | + output = make_client_run() |
| 1147 | + stdout = list(output.stdout) |
| 1148 | + self.assertListEqual(stdout, [self.resp]) |
| 1149 | + self.assertEqual(output.exit_code, 1) |
| 1150 | + |
| 1151 | + def test_output_client_scope_disconnect(self): |
| 1152 | + """Calling deprecated .disconnect on client that also goes out of scope should not break reading |
| 1153 | + any unread output.""" |
| 1154 | + def make_client_run(): |
| 1155 | + client = SSHClient(self.host, port=self.port, |
| 1156 | + pkey=self.user_key, |
| 1157 | + num_retries=1, |
| 1158 | + allow_agent=False, |
| 1159 | + ) |
| 1160 | + host_out = client.run_command("%s; exit 1" % (self.cmd,)) |
| 1161 | + client.disconnect() |
| 1162 | + return host_out |
| 1163 | + |
| 1164 | + output = make_client_run() |
| 1165 | + stdout = list(output.stdout) |
| 1166 | + self.assertListEqual(stdout, [self.resp]) |
| 1167 | + self.assertEqual(output.exit_code, 1) |
0 commit comments