Skip to content
This repository was archived by the owner on Sep 11, 2020. It is now read-only.

Commit a4b12e4

Browse files
kuba--mcuadros
authored andcommitted
plumbing/transport: ssh check if list of known_hosts files is empty
Signed-off-by: kuba-- <kuba@sourced.tech>
1 parent d3cec13 commit a4b12e4

File tree

2 files changed

+69
-7
lines changed

2 files changed

+69
-7
lines changed

plumbing/transport/ssh/auth_method.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -236,21 +236,23 @@ func (a *PublicKeysCallback) ClientConfig() (*ssh.ClientConfig, error) {
236236
// NewKnownHostsCallback returns ssh.HostKeyCallback based on a file based on a
237237
// known_hosts file. http://man.openbsd.org/sshd#SSH_KNOWN_HOSTS_FILE_FORMAT
238238
//
239-
// If files is empty, the list of files will be read from the SSH_KNOWN_HOSTS
239+
// If list of files is empty, then it will be read from the SSH_KNOWN_HOSTS
240240
// environment variable, example:
241241
// /home/foo/custom_known_hosts_file:/etc/custom_known/hosts_file
242242
//
243243
// If SSH_KNOWN_HOSTS is not set the following file locations will be used:
244244
// ~/.ssh/known_hosts
245245
// /etc/ssh/ssh_known_hosts
246246
func NewKnownHostsCallback(files ...string) (ssh.HostKeyCallback, error) {
247-
files, err := getDefaultKnownHostsFiles()
248-
if err != nil {
249-
return nil, err
247+
var err error
248+
249+
if len(files) == 0 {
250+
if files, err = getDefaultKnownHostsFiles(); err != nil {
251+
return nil, err
252+
}
250253
}
251254

252-
files, err = filterKnownHostsFiles(files...)
253-
if err != nil {
255+
if files, err = filterKnownHostsFiles(files...); err != nil {
254256
return nil, err
255257
}
256258

plumbing/transport/ssh/auth_method_test.go

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,30 @@
11
package ssh
22

33
import (
4+
"bufio"
45
"fmt"
56
"io/ioutil"
67
"os"
8+
"strings"
79

10+
"golang.org/x/crypto/ssh"
811
"golang.org/x/crypto/ssh/testdata"
912

1013
. "gopkg.in/check.v1"
1114
)
1215

13-
type SuiteCommon struct{}
16+
type (
17+
SuiteCommon struct{}
18+
19+
mockKnownHosts struct{}
20+
)
21+
22+
func (mockKnownHosts) host() string { return "github.com" }
23+
func (mockKnownHosts) knownHosts() []byte {
24+
return []byte(`github.com ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAQEAq2A7hRGmdnm9tUDbO9IDSwBK6TbQa+PXYPCPy6rbTrTtw7PHkccKrpp0yVhp5HdEIcKr6pLlVDBfOLX9QUsyCOV0wzfjIJNlGEYsdlLJizHhbn2mUjvSAHQqZETYP81eFzLQNnPHt4EVVUh7VfDESU84KezmD5QlWpXLmvU31/yMf+Se8xhHTvKSCZIFImWwoG6mbUoWf9nzpIoaSjB+weqqUUmpaaasXVal72J+UX2B+2RPW3RcT0eOzQgqlJL3RKrTJvdsjE3JEAvGq3lGHSZXy28G3skua2SmVi/w4yCE6gbODqnTWlg7+wC604ydGXA8VJiS5ap43JXiUFFAaQ==`)
25+
}
26+
func (mockKnownHosts) Network() string { return "tcp" }
27+
func (mockKnownHosts) String() string { return "github.com:22" }
1428

1529
var _ = Suite(&SuiteCommon{})
1630

@@ -149,3 +163,49 @@ func (*SuiteCommon) TestNewPublicKeysWithInvalidPEM(c *C) {
149163
c.Assert(err, NotNil)
150164
c.Assert(auth, IsNil)
151165
}
166+
167+
func (*SuiteCommon) TestNewKnownHostsCallback(c *C) {
168+
var mock = mockKnownHosts{}
169+
170+
f, err := ioutil.TempFile("", "known-hosts")
171+
c.Assert(err, IsNil)
172+
173+
_, err = f.Write(mock.knownHosts())
174+
c.Assert(err, IsNil)
175+
176+
err = f.Close()
177+
c.Assert(err, IsNil)
178+
179+
defer os.RemoveAll(f.Name())
180+
181+
f, err = os.Open(f.Name())
182+
c.Assert(err, IsNil)
183+
184+
defer f.Close()
185+
186+
var hostKey ssh.PublicKey
187+
scanner := bufio.NewScanner(f)
188+
for scanner.Scan() {
189+
fields := strings.Split(scanner.Text(), " ")
190+
if len(fields) != 3 {
191+
continue
192+
}
193+
if strings.Contains(fields[0], mock.host()) {
194+
var err error
195+
hostKey, _, _, _, err = ssh.ParseAuthorizedKey(scanner.Bytes())
196+
if err != nil {
197+
c.Fatalf("error parsing %q: %v", fields[2], err)
198+
}
199+
break
200+
}
201+
}
202+
if hostKey == nil {
203+
c.Fatalf("no hostkey for %s", mock.host())
204+
}
205+
206+
clb, err := NewKnownHostsCallback(f.Name())
207+
c.Assert(err, IsNil)
208+
209+
err = clb(mock.String(), mock, hostKey)
210+
c.Assert(err, IsNil)
211+
}

0 commit comments

Comments
 (0)