@@ -17,8 +17,8 @@ use std::os::unix::io::RawFd as RawSocket;
1717#[ cfg( windows) ]
1818use std:: os:: windows:: io:: RawSocket ;
1919use std:: ptr:: null_mut;
20- use std:: sync:: Once ;
2120use std:: sync:: { Arc , Mutex , MutexGuard } ;
21+ use std:: sync:: { Once , Weak } ;
2222use std:: time:: Duration ;
2323
2424mod channel;
@@ -72,9 +72,12 @@ fn initialize() -> SshResult<()> {
7272}
7373
7474pub ( crate ) struct SessionHolder {
75+ outer : Weak < Mutex < SessionHolder > > ,
7576 sess : sys:: ssh_session ,
7677 callbacks : sys:: ssh_callbacks_struct ,
7778 auth_callback : Option < Box < dyn FnMut ( & str , bool , bool , Option < String > ) -> SshResult < String > > > ,
79+ channel_open_request_auth_agent_callback :
80+ Option < Box < dyn FnMut ( Channel ) -> Result < ( ) , RequestAuthAgentError > > > ,
7881}
7982unsafe impl Send for SessionHolder { }
8083
@@ -197,11 +200,16 @@ impl Session {
197200 channel_open_request_x11_function : None ,
198201 channel_open_request_auth_agent_function : None ,
199202 } ;
200- let sess = Arc :: new ( Mutex :: new ( SessionHolder {
201- sess,
202- callbacks,
203- auth_callback : None ,
204- } ) ) ;
203+ let sess = Arc :: new_cyclic ( |outer| {
204+ let outer = outer. clone ( ) ;
205+ Mutex :: new ( SessionHolder {
206+ outer,
207+ sess,
208+ callbacks,
209+ auth_callback : None ,
210+ channel_open_request_auth_agent_callback : None ,
211+ } )
212+ } ) ;
205213
206214 {
207215 let mut sess = sess. lock ( ) . unwrap ( ) ;
@@ -274,6 +282,55 @@ impl Session {
274282 }
275283 }
276284
285+ unsafe extern "C" fn bridge_channel_open_request_auth_agent_callback (
286+ session : sys:: ssh_session ,
287+ userdata : * mut :: std:: os:: raw:: c_void ,
288+ ) -> sys:: ssh_channel {
289+ let result = std:: panic:: catch_unwind ( || -> SshResult < sys:: ssh_channel > {
290+ let sess: & mut SessionHolder = & mut * ( userdata as * mut SessionHolder ) ;
291+ assert ! (
292+ std:: ptr:: eq( session, sess. sess) ,
293+ "invalid callback invocation: session mismatch"
294+ ) ;
295+ let cb = sess
296+ . channel_open_request_auth_agent_callback
297+ . as_mut ( )
298+ . unwrap ( ) ;
299+ let chan = unsafe { sys:: ssh_channel_new ( session) } ;
300+ if chan. is_null ( ) {
301+ return Err ( sess
302+ . last_error ( )
303+ . unwrap_or_else ( || Error :: fatal ( "ssh_channel_new failed" ) ) ) ;
304+ }
305+ match cb ( Channel :: new ( & sess. outer . upgrade ( ) . unwrap ( ) , chan) ) {
306+ // SAFETY: We steal a *mut sys::ssh_channel_struct here and let libssh
307+ // temporarily "borrows" it for an unspecified amount of time.
308+ // libssh is guaranteed to finish using it before returning from the outermost
309+ // libssh function call that triggered this callback. As such function call
310+ // always happens with Session locked and dropping Channel needs to lock the
311+ // session first, we can be sure that this *mut sys::ssh_channel_struct will not
312+ // be freed while libssh is still using it.
313+ Ok ( _) => Ok ( chan) ,
314+ Err ( RequestAuthAgentError ( err, mut chan_obj) ) => {
315+ unsafe { sys:: ssh_channel_free ( chan_obj. chan_inner ) } ;
316+ chan_obj. chan_inner = std:: ptr:: null_mut ( ) ;
317+ Err ( err)
318+ }
319+ }
320+ } ) ;
321+ match result {
322+ Err ( err) => {
323+ eprintln ! ( "Panic in request auth agent callback: {:?}" , err) ;
324+ std:: ptr:: null_mut ( )
325+ }
326+ Ok ( Err ( err) ) => {
327+ eprintln ! ( "Error in request auth agent callback: {:#}" , err) ;
328+ std:: ptr:: null_mut ( )
329+ }
330+ Ok ( Ok ( chan) ) => chan,
331+ }
332+ }
333+
277334 /// Sets a callback that is used by libssh when it needs to prompt
278335 /// for the passphrase during public key authentication.
279336 /// This is NOT used for password or keyboard interactive authentication.
@@ -326,6 +383,32 @@ impl Session {
326383 sess. callbacks . auth_function = Some ( Self :: bridge_auth_callback) ;
327384 }
328385
386+ /// Sets a callback that is used by libssh when the remote side requests a new channel
387+ /// for SSH agent forwarding.
388+ /// The callback has the signature:
389+ ///
390+ /// ```no_run
391+ /// use libssh_rs::RequestAuthAgentResult;
392+ /// fn callback(channel: Channel) -> RequestAuthAgentResult {
393+ /// unimplemented!()
394+ /// }
395+ /// ```
396+ ///
397+ /// The callback should decide whether to allow the agent forward and if so, take ownership of
398+ /// the channel (and further move it elsewhere to handle agent protocol within). Otherwise or
399+ /// in case of an error, the callback should return the channel back as it is not possible to
400+ /// drop it in the callback.
401+ pub fn set_channel_open_request_auth_agent_callback < F > ( & self , callback : F )
402+ where
403+ F : FnMut ( Channel ) -> Result < ( ) , RequestAuthAgentError > + ' static ,
404+ {
405+ let mut sess = self . lock_session ( ) ;
406+ sess. channel_open_request_auth_agent_callback
407+ . replace ( Box :: new ( callback) ) ;
408+ sess. callbacks . channel_open_request_auth_agent_function =
409+ Some ( Self :: bridge_channel_open_request_auth_agent_callback) ;
410+ }
411+
329412 /// Create a new channel.
330413 /// Channels are used to handle I/O for commands and forwarded streams.
331414 pub fn new_channel ( & self ) -> SshResult < Channel > {
0 commit comments