@@ -17,8 +17,8 @@ use std::os::unix::io::RawFd as RawSocket;
1717#[ cfg( windows) ]
1818use std:: os:: windows:: io:: RawSocket ;
1919use std:: ptr:: null_mut;
20- use std:: sync:: Once ;
2120use std:: sync:: { Arc , Mutex , MutexGuard } ;
21+ use std:: sync:: { Once , Weak } ;
2222use std:: time:: Duration ;
2323
2424mod channel;
@@ -72,9 +72,12 @@ fn initialize() -> SshResult<()> {
7272}
7373
7474pub ( crate ) struct SessionHolder {
75+ outer : Weak < Mutex < SessionHolder > > ,
7576 sess : sys:: ssh_session ,
7677 callbacks : sys:: ssh_callbacks_struct ,
7778 auth_callback : Option < Box < dyn FnMut ( & str , bool , bool , Option < String > ) -> SshResult < String > > > ,
79+ channel_open_request_auth_agent_callback :
80+ Option < Box < dyn FnMut ( Channel ) -> RequestAuthAgentResult > > ,
7881}
7982unsafe impl Send for SessionHolder { }
8083
@@ -197,11 +200,16 @@ impl Session {
197200 channel_open_request_x11_function : None ,
198201 channel_open_request_auth_agent_function : None ,
199202 } ;
200- let sess = Arc :: new ( Mutex :: new ( SessionHolder {
201- sess,
202- callbacks,
203- auth_callback : None ,
204- } ) ) ;
203+ let sess = Arc :: new_cyclic ( |outer| {
204+ let outer = outer. clone ( ) ;
205+ Mutex :: new ( SessionHolder {
206+ outer,
207+ sess,
208+ callbacks,
209+ auth_callback : None ,
210+ channel_open_request_auth_agent_callback : None ,
211+ } )
212+ } ) ;
205213
206214 {
207215 let mut sess = sess. lock ( ) . unwrap ( ) ;
@@ -274,6 +282,66 @@ impl Session {
274282 }
275283 }
276284
285+ unsafe extern "C" fn bridge_channel_open_request_auth_agent_callback (
286+ session : sys:: ssh_session ,
287+ userdata : * mut :: std:: os:: raw:: c_void ,
288+ ) -> sys:: ssh_channel {
289+ let result = std:: panic:: catch_unwind ( || -> SshResult < sys:: ssh_channel > {
290+ let sess: & mut SessionHolder = & mut * ( userdata as * mut SessionHolder ) ;
291+ assert ! (
292+ std:: ptr:: eq( session, sess. sess) ,
293+ "invalid callback invocation: session mismatch"
294+ ) ;
295+ let cb = sess
296+ . channel_open_request_auth_agent_callback
297+ . as_mut ( )
298+ . unwrap ( ) ;
299+ let chan = unsafe { sys:: ssh_channel_new ( session) } ;
300+ if chan. is_null ( ) {
301+ return Err ( sess
302+ . last_error ( )
303+ . unwrap_or_else ( || Error :: fatal ( "ssh_channel_new failed" ) ) ) ;
304+ }
305+ match cb ( Channel :: new ( & sess. outer . upgrade ( ) . unwrap ( ) , chan) ) {
306+ // SAFETY: We steal a *mut sys::ssh_channel_struct here and let libssh
307+ // temporarily "borrows" it for an unspecified amount of time.
308+ // libssh is guaranteed to finish using it before returning from the outermost
309+ // libssh function call that triggered this callback. As such function call
310+ // always happens with Session locked and dropping Channel needs to lock the
311+ // session first, we can be sure that this *mut sys::ssh_channel_struct will not
312+ // be freed while libssh is still using it.
313+ RequestAuthAgentResult :: Accept => Ok ( chan) ,
314+ RequestAuthAgentResult :: Reject ( mut chan_obj) => {
315+ unsafe { sys:: ssh_channel_free ( chan_obj. chan_inner ) } ;
316+ chan_obj. chan_inner = std:: ptr:: null_mut ( ) ;
317+ Err ( Error :: RequestDenied ( "request auth agent" . to_string ( ) ) )
318+ }
319+ RequestAuthAgentResult :: Err ( mut chan_obj, err) => {
320+ unsafe { sys:: ssh_channel_free ( chan_obj. chan_inner ) } ;
321+ chan_obj. chan_inner = std:: ptr:: null_mut ( ) ;
322+ Err ( err)
323+ }
324+ }
325+ } ) ;
326+ match result {
327+ Err ( err) => {
328+ eprintln ! (
329+ "Panic in channel open request auth agent callback: {:?}" ,
330+ err
331+ ) ;
332+ std:: ptr:: null_mut ( )
333+ }
334+ Ok ( Err ( err) ) => {
335+ eprintln ! (
336+ "Error in channel open request auth agent callback: {:#}" ,
337+ err
338+ ) ;
339+ std:: ptr:: null_mut ( )
340+ }
341+ Ok ( Ok ( chan) ) => chan,
342+ }
343+ }
344+
277345 /// Sets a callback that is used by libssh when it needs to prompt
278346 /// for the passphrase during public key authentication.
279347 /// This is NOT used for password or keyboard interactive authentication.
@@ -326,6 +394,32 @@ impl Session {
326394 sess. callbacks . auth_function = Some ( Self :: bridge_auth_callback) ;
327395 }
328396
397+ /// Sets a callback that is used by libssh when the remote side requests a new channel
398+ /// for SSH agent forwarding.
399+ /// The callback has the signature:
400+ ///
401+ /// ```no_run
402+ /// use libssh_rs::RequestAuthAgentResult;
403+ /// fn callback(channel: Channel) -> RequestAuthAgentResult {
404+ /// unimplemented!()
405+ /// }
406+ /// ```
407+ ///
408+ /// The callback should decide whether to allow the agent forward and if so, take ownership of
409+ /// the channel (and further move it elsewhere to handle agent protocol within). Otherwise or
410+ /// in case of an error, the callback should return the channel back as it is not possible to
411+ /// drop it in the callback.
412+ pub fn set_channel_open_request_auth_agent_callback < F > ( & self , callback : F )
413+ where
414+ F : FnMut ( Channel ) -> RequestAuthAgentResult + ' static ,
415+ {
416+ let mut sess = self . lock_session ( ) ;
417+ sess. channel_open_request_auth_agent_callback
418+ . replace ( Box :: new ( callback) ) ;
419+ sess. callbacks . channel_open_request_auth_agent_function =
420+ Some ( Self :: bridge_channel_open_request_auth_agent_callback) ;
421+ }
422+
329423 /// Create a new channel.
330424 /// Channels are used to handle I/O for commands and forwarded streams.
331425 pub fn new_channel ( & self ) -> SshResult < Channel > {
@@ -1421,6 +1515,12 @@ pub struct InteractiveAuthInfo {
14211515 pub prompts : Vec < InteractiveAuthPrompt > ,
14221516}
14231517
1518+ pub enum RequestAuthAgentResult {
1519+ Accept ,
1520+ Reject ( Channel ) ,
1521+ Err ( Channel , Error ) ,
1522+ }
1523+
14241524/// A utility function that will prompt the user for input
14251525/// via the console/tty.
14261526///
0 commit comments