@@ -110,7 +110,7 @@ func (h *ForwardedUnixHandler) HandleSSHRequest(ctx Context, srv *Server, req *g
110110 return false , nil
111111 }
112112
113- if srv .ReverseUnixForwardingCallback == nil || ! srv . ReverseUnixForwardingCallback ( ctx , reqPayload . SocketPath ) {
113+ if srv .ReverseUnixForwardingCallback == nil {
114114 return false , []byte ("unix forwarding is disabled" )
115115 }
116116
@@ -123,26 +123,11 @@ func (h *ForwardedUnixHandler) HandleSSHRequest(ctx Context, srv *Server, req *g
123123 return false , nil
124124 }
125125
126- // Create socket parent dir if not exists.
127- parentDir := filepath .Dir (addr )
128- err = os .MkdirAll (parentDir , 0700 )
129- if err != nil {
130- // TODO: log mkdir failure
131- return false , nil
132- }
133-
134- // Remove existing socket if it exists. We do not use os.Remove() here
135- // so that directories are kept. Note that it's possible that we will
136- // overwrite a regular file here. Both of these behaviors match OpenSSH,
137- // however, which is why we unlink.
138- err = unlink (addr )
139- if err != nil && ! errors .Is (err , fs .ErrNotExist ) {
140- // TODO: log unlink failure
141- return false , nil
142- }
143-
144- ln , err := net .Listen ("unix" , addr )
126+ ln , err := srv .ReverseUnixForwardingCallback (ctx , addr )
145127 if err != nil {
128+ if errors .Is (err , ErrRejected ) {
129+ return false , []byte ("unix forwarding is disabled" )
130+ }
146131 // TODO: log unix listen failure
147132 return false , nil
148133 }
@@ -227,3 +212,31 @@ func unlink(path string) error {
227212 }
228213 }
229214}
215+
216+ // SimpleUnixReverseForwardingCallback provides a basic implementation for
217+ // ReverseUnixForwardingCallback. The parent directory will be created (with
218+ // os.MkdirAll), and existing files with the same name will be removed.
219+ func SimpleUnixReverseForwardingCallback (_ Context , socketPath string ) (net.Listener , error ) {
220+ // Create socket parent dir if not exists.
221+ parentDir := filepath .Dir (socketPath )
222+ err := os .MkdirAll (parentDir , 0700 )
223+ if err != nil {
224+ return nil , fmt .Errorf ("failed to create parent directory %q for socket %q: %w" , parentDir , socketPath , err )
225+ }
226+
227+ // Remove existing socket if it exists. We do not use os.Remove() here
228+ // so that directories are kept. Note that it's possible that we will
229+ // overwrite a regular file here. Both of these behaviors match OpenSSH,
230+ // however, which is why we unlink.
231+ err = unlink (socketPath )
232+ if err != nil && ! errors .Is (err , fs .ErrNotExist ) {
233+ return nil , fmt .Errorf ("failed to remove existing file in socket path %q: %w" , socketPath , err )
234+ }
235+
236+ ln , err := net .Listen ("unix" , socketPath )
237+ if err != nil {
238+ return nil , fmt .Errorf ("failed to listen on unix socket %q: %w" , socketPath , err )
239+ }
240+
241+ return ln , err
242+ }
0 commit comments