|
1 | 1 | package ssh |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "bufio" |
4 | 5 | "fmt" |
5 | 6 | "io/ioutil" |
6 | 7 | "os" |
| 8 | + "strings" |
7 | 9 |
|
| 10 | + "golang.org/x/crypto/ssh" |
8 | 11 | "golang.org/x/crypto/ssh/testdata" |
9 | 12 |
|
10 | 13 | . "gopkg.in/check.v1" |
11 | 14 | ) |
12 | 15 |
|
13 | | -type SuiteCommon struct{} |
| 16 | +type ( |
| 17 | + SuiteCommon struct{} |
| 18 | + |
| 19 | + mockKnownHosts struct{} |
| 20 | +) |
| 21 | + |
| 22 | +func (mockKnownHosts) host() string { return "github.com" } |
| 23 | +func (mockKnownHosts) knownHosts() []byte { |
| 24 | + return []byte(`github.com ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAQEAq2A7hRGmdnm9tUDbO9IDSwBK6TbQa+PXYPCPy6rbTrTtw7PHkccKrpp0yVhp5HdEIcKr6pLlVDBfOLX9QUsyCOV0wzfjIJNlGEYsdlLJizHhbn2mUjvSAHQqZETYP81eFzLQNnPHt4EVVUh7VfDESU84KezmD5QlWpXLmvU31/yMf+Se8xhHTvKSCZIFImWwoG6mbUoWf9nzpIoaSjB+weqqUUmpaaasXVal72J+UX2B+2RPW3RcT0eOzQgqlJL3RKrTJvdsjE3JEAvGq3lGHSZXy28G3skua2SmVi/w4yCE6gbODqnTWlg7+wC604ydGXA8VJiS5ap43JXiUFFAaQ==`) |
| 25 | +} |
| 26 | +func (mockKnownHosts) Network() string { return "tcp" } |
| 27 | +func (mockKnownHosts) String() string { return "github.com:22" } |
14 | 28 |
|
15 | 29 | var _ = Suite(&SuiteCommon{}) |
16 | 30 |
|
@@ -149,3 +163,49 @@ func (*SuiteCommon) TestNewPublicKeysWithInvalidPEM(c *C) { |
149 | 163 | c.Assert(err, NotNil) |
150 | 164 | c.Assert(auth, IsNil) |
151 | 165 | } |
| 166 | + |
| 167 | +func (*SuiteCommon) TestNewKnownHostsCallback(c *C) { |
| 168 | + var mock = mockKnownHosts{} |
| 169 | + |
| 170 | + f, err := ioutil.TempFile("", "known-hosts") |
| 171 | + c.Assert(err, IsNil) |
| 172 | + |
| 173 | + _, err = f.Write(mock.knownHosts()) |
| 174 | + c.Assert(err, IsNil) |
| 175 | + |
| 176 | + err = f.Close() |
| 177 | + c.Assert(err, IsNil) |
| 178 | + |
| 179 | + defer os.RemoveAll(f.Name()) |
| 180 | + |
| 181 | + f, err = os.Open(f.Name()) |
| 182 | + c.Assert(err, IsNil) |
| 183 | + |
| 184 | + defer f.Close() |
| 185 | + |
| 186 | + var hostKey ssh.PublicKey |
| 187 | + scanner := bufio.NewScanner(f) |
| 188 | + for scanner.Scan() { |
| 189 | + fields := strings.Split(scanner.Text(), " ") |
| 190 | + if len(fields) != 3 { |
| 191 | + continue |
| 192 | + } |
| 193 | + if strings.Contains(fields[0], mock.host()) { |
| 194 | + var err error |
| 195 | + hostKey, _, _, _, err = ssh.ParseAuthorizedKey(scanner.Bytes()) |
| 196 | + if err != nil { |
| 197 | + c.Fatalf("error parsing %q: %v", fields[2], err) |
| 198 | + } |
| 199 | + break |
| 200 | + } |
| 201 | + } |
| 202 | + if hostKey == nil { |
| 203 | + c.Fatalf("no hostkey for %s", mock.host()) |
| 204 | + } |
| 205 | + |
| 206 | + clb, err := NewKnownHostsCallback(f.Name()) |
| 207 | + c.Assert(err, IsNil) |
| 208 | + |
| 209 | + err = clb(mock.String(), mock, hostKey) |
| 210 | + c.Assert(err, IsNil) |
| 211 | +} |
0 commit comments