Skip to content

Commit 2d1dd6c

Browse files
committed
Override agent before call to connect
1 parent 8a7cfe3 commit 2d1dd6c

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

pssh.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ class SSHClient(object):
6969

7070
def __init__(self, host,
7171
user=None, password=None, port=None,
72-
pkey=None, forward_ssh_agent=True):
72+
pkey=None, forward_ssh_agent=True,
73+
_agent=None):
7374
"""Connect to host honouring any user set configuration in ~/.ssh/config \
7475
or /etc/ssh/ssh_config
7576
@@ -90,6 +91,11 @@ def __init__(self, host,
9091
equivalent to `ssh -A` from the `ssh` command line utility.
9192
Defaults to True if not set.
9293
:type forward_ssh_agent: bool
94+
:param _agent: (Optional) Override SSH agent object with the provided. \
95+
This allows for overriding of the default paramiko behaviour of \
96+
connecting to local SSH agent to lookup keys with our own SSH agent.
97+
Only really useful for testing, hence the internal variable prefix.
98+
:type _agent: :mod:`paramiko.agent.Agent`
9399
:raises: :mod:`pssh.AuthenticationException` on authentication error
94100
:raises: :mod:`pssh.UnknownHostException` on DNS resolution error
95101
:raises: :mod:`pssh.ConnectionErrorException` on error connecting"""
@@ -120,6 +126,8 @@ def __init__(self, host,
120126
self.pkey = pkey
121127
self.port = port if port else 22
122128
self.host = resolved_address
129+
if _agent:
130+
self.client._agent = _agent
123131
self._connect()
124132

125133
def _connect(self, retries=1):

tests/test_ssh_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ def test_ssh_agent_authentication(self):
8585
agent.add_key(USER_KEY)
8686
server = start_server({ self.fake_cmd : self.fake_resp },
8787
self.listen_socket)
88-
client = SSHClient('127.0.0.1', port=self.listen_port)
89-
client.client._agent = agent
88+
client = SSHClient('127.0.0.1', port=self.listen_port,
89+
_agent=agent)
9090
channel, host, _stdout, _stderr = client.exec_command(self.fake_cmd)
9191
output = (line.strip() for line in _stdout)
9292
channel.close()

0 commit comments

Comments
 (0)