|
9 | 9 | package mysql |
10 | 10 |
|
11 | 11 | import ( |
| 12 | + "bytes" |
12 | 13 | "crypto/tls" |
13 | 14 | "database/sql" |
14 | 15 | "database/sql/driver" |
15 | 16 | "fmt" |
16 | 17 | "io" |
17 | 18 | "io/ioutil" |
| 19 | + "log" |
18 | 20 | "net" |
19 | 21 | "net/url" |
20 | 22 | "os" |
@@ -1018,7 +1020,7 @@ func TestFoundRows(t *testing.T) { |
1018 | 1020 |
|
1019 | 1021 | func TestStrict(t *testing.T) { |
1020 | 1022 | // ALLOW_INVALID_DATES to get rid of stricter modes - we want to test for warnings, not errors |
1021 | | - relaxedDsn := dsn + "&sql_mode=ALLOW_INVALID_DATES" |
| 1023 | + relaxedDsn := dsn + "&sql_mode='ALLOW_INVALID_DATES,NO_AUTO_CREATE_USER'" |
1022 | 1024 | // make sure the MySQL version is recent enough with a separate connection |
1023 | 1025 | // before running the test |
1024 | 1026 | conn, err := MySQLDriver{}.Open(relaxedDsn) |
@@ -1643,7 +1645,7 @@ func TestSqlInjection(t *testing.T) { |
1643 | 1645 |
|
1644 | 1646 | dsns := []string{ |
1645 | 1647 | dsn, |
1646 | | - dsn + "&sql_mode=NO_BACKSLASH_ESCAPES", |
| 1648 | + dsn + "&sql_mode='NO_BACKSLASH_ESCAPES,NO_AUTO_CREATE_USER'", |
1647 | 1649 | } |
1648 | 1650 | for _, testdsn := range dsns { |
1649 | 1651 | runTests(t, testdsn, createTest("1 OR 1=1")) |
@@ -1673,9 +1675,56 @@ func TestInsertRetrieveEscapedData(t *testing.T) { |
1673 | 1675 |
|
1674 | 1676 | dsns := []string{ |
1675 | 1677 | dsn, |
1676 | | - dsn + "&sql_mode=NO_BACKSLASH_ESCAPES", |
| 1678 | + dsn + "&sql_mode='NO_BACKSLASH_ESCAPES,NO_AUTO_CREATE_USER'", |
1677 | 1679 | } |
1678 | 1680 | for _, testdsn := range dsns { |
1679 | 1681 | runTests(t, testdsn, testData) |
1680 | 1682 | } |
1681 | 1683 | } |
| 1684 | + |
| 1685 | +func TestUnixSocketAuthFail(t *testing.T) { |
| 1686 | + runTests(t, dsn, func(dbt *DBTest) { |
| 1687 | + // Save the current logger so we can restore it. |
| 1688 | + oldLogger := errLog |
| 1689 | + |
| 1690 | + // Set a new logger so we can capture its output. |
| 1691 | + buffer := bytes.NewBuffer(make([]byte, 0, 64)) |
| 1692 | + newLogger := log.New(buffer, "prefix: ", 0) |
| 1693 | + SetLogger(newLogger) |
| 1694 | + |
| 1695 | + // Restore the logger. |
| 1696 | + defer SetLogger(oldLogger) |
| 1697 | + |
| 1698 | + // Make a new DSN that uses the MySQL socket file and a bad password, which |
| 1699 | + // we can make by simply appending any character to the real password. |
| 1700 | + badPass := pass + "x" |
| 1701 | + socket := "" |
| 1702 | + if prot == "unix" { |
| 1703 | + socket = addr |
| 1704 | + } else { |
| 1705 | + // Get socket file from MySQL. |
| 1706 | + err := dbt.db.QueryRow("SELECT @@socket").Scan(&socket) |
| 1707 | + if err != nil { |
| 1708 | + t.Fatalf("Error on SELECT @@socket: %s", err.Error()) |
| 1709 | + } |
| 1710 | + } |
| 1711 | + t.Logf("socket: %s", socket) |
| 1712 | + badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s&strict=true", user, badPass, socket, dbname) |
| 1713 | + db, err := sql.Open("mysql", badDSN) |
| 1714 | + if err != nil { |
| 1715 | + t.Fatalf("Error connecting: %s", err.Error()) |
| 1716 | + } |
| 1717 | + defer db.Close() |
| 1718 | + |
| 1719 | + // Connect to MySQL for real. This will cause an auth failure. |
| 1720 | + err = db.Ping() |
| 1721 | + if err == nil { |
| 1722 | + t.Error("expected Ping() to return an error") |
| 1723 | + } |
| 1724 | + |
| 1725 | + // The driver should not log anything. |
| 1726 | + if actual := buffer.String(); actual != "" { |
| 1727 | + t.Errorf("expected no output, got %q", actual) |
| 1728 | + } |
| 1729 | + }) |
| 1730 | +} |
0 commit comments