2323 )
2424
2525import functools
26+ from typing import Any , Optional
2627
28+ import paramiko .pkey
2729import paramiko .ssh_exception
2830
2931from testinfra .backend import base
3234class IgnorePolicy (paramiko .MissingHostKeyPolicy ):
3335 """Policy for ignoring missing host key."""
3436
35- def missing_host_key (self , client , hostname , key ):
37+ def missing_host_key (
38+ self , client : paramiko .SSHClient , hostname : str , key : paramiko .pkey .PKey
39+ ) -> None :
3640 pass
3741
3842
@@ -41,12 +45,12 @@ class ParamikoBackend(base.BaseBackend):
4145
4246 def __init__ (
4347 self ,
44- hostspec ,
45- ssh_config = None ,
46- ssh_identity_file = None ,
47- timeout = 10 ,
48- * args ,
49- ** kwargs ,
48+ hostspec : str ,
49+ ssh_config : Optional [ str ] = None ,
50+ ssh_identity_file : Optional [ str ] = None ,
51+ timeout : int = 10 ,
52+ * args : Any ,
53+ ** kwargs : Any ,
5054 ):
5155 self .host = self .parse_hostspec (hostspec )
5256 self .ssh_config = ssh_config
@@ -55,7 +59,13 @@ def __init__(
5559 self .timeout = int (timeout )
5660 super ().__init__ (self .host .name , * args , ** kwargs )
5761
58- def _load_ssh_config (self , client , cfg , ssh_config , ssh_config_dir = "~/.ssh" ):
62+ def _load_ssh_config (
63+ self ,
64+ client : paramiko .SSHClient ,
65+ cfg : dict [str , Any ],
66+ ssh_config : paramiko .SSHConfig ,
67+ ssh_config_dir : str = "~/.ssh" ,
68+ ) -> None :
5969 for key , value in ssh_config .lookup (self .host .name ).items ():
6070 if key == "hostname" :
6171 cfg [key ] = value
@@ -85,7 +95,7 @@ def _load_ssh_config(self, client, cfg, ssh_config, ssh_config_dir="~/.ssh"):
8595 self ._load_ssh_config (client , cfg , new_ssh_config , ssh_config_dir )
8696
8797 @functools .cached_property
88- def client (self ):
98+ def client (self ) -> paramiko . SSHClient :
8999 client = paramiko .SSHClient ()
90100 client .set_missing_host_key_policy (paramiko .WarningPolicy ())
91101 cfg = {
@@ -118,11 +128,13 @@ def client(self):
118128
119129 if self .ssh_identity_file :
120130 cfg ["key_filename" ] = self .ssh_identity_file
121- client .connect (** cfg )
131+ client .connect (** cfg ) # type: ignore[arg-type]
122132 return client
123133
124- def _exec_command (self , command ):
125- chan = self .client .get_transport ().open_session ()
134+ def _exec_command (self , command : bytes ) -> tuple [int , bytes , bytes ]:
135+ transport = self .client .get_transport ()
136+ assert transport is not None
137+ chan = transport .open_session ()
126138 if self .get_pty :
127139 chan .get_pty ()
128140 chan .exec_command (command )
@@ -131,17 +143,19 @@ def _exec_command(self, command):
131143 stderr = b"" .join (chan .makefile_stderr ("rb" ))
132144 return rc , stdout , stderr
133145
134- def run (self , command , * args , ** kwargs ) :
146+ def run (self , command : str , * args : str , ** kwargs : Any ) -> base . CommandResult :
135147 command = self .get_command (command , * args )
136- command = self .encode (command )
148+ cmd = self .encode (command )
137149 try :
138- rc , stdout , stderr = self ._exec_command (command )
150+ rc , stdout , stderr = self ._exec_command (cmd )
139151 except paramiko .ssh_exception .SSHException :
140- if not self .client .get_transport ().is_active ():
152+ transport = self .client .get_transport ()
153+ assert transport is not None
154+ if not transport .is_active ():
141155 # try to reinit connection (once)
142156 del self .client
143- rc , stdout , stderr = self ._exec_command (command )
157+ rc , stdout , stderr = self ._exec_command (cmd )
144158 else :
145159 raise
146160
147- return self .result (rc , command , stdout , stderr )
161+ return self .result (rc , cmd , stdout , stderr )
0 commit comments