@@ -86,3 +86,94 @@ func TestMySQLDbOverwriteUsersAndGrantsData(t *testing.T) {
8686
8787 rd .Close ()
8888}
89+
90+ func TestMatchesHostPattern (t * testing.T ) {
91+ tests := []struct {
92+ name string
93+ host string
94+ pattern string
95+ expected bool
96+ }{
97+ // Basic wildcard patterns
98+ {"IP wildcard - exact match" , "127.0.0.1" , "127.0.0.%" , true },
99+ {"IP wildcard - different last octet" , "127.0.0.255" , "127.0.0.%" , true },
100+ {"IP wildcard - no match" , "192.168.1.1" , "127.0.0.%" , false },
101+ {"IP wildcard - partial match" , "127.0.1.1" , "127.0.0.%" , false },
102+
103+ // Multiple wildcards
104+ {"Multiple wildcards" , "192.168.1.100" , "192.168.%.%" , true },
105+ {"Multiple wildcards - no match" , "10.0.1.100" , "192.168.%.%" , false },
106+
107+ // Single wildcard at different positions
108+ {"Wildcard first octet" , "10.0.0.1" , "%.0.0.1" , true },
109+ {"Wildcard middle octet" , "192.168.50.1" , "192.%.50.1" , true },
110+ {"Wildcard last octet" , "192.168.1.255" , "192.168.1.%" , true },
111+
112+ // Non-IP patterns
113+ {"Hostname wildcard" , "server1.example.com" , "server%.example.com" , true },
114+ {"Hostname wildcard - no match" , "db1.example.com" , "server%.example.com" , false },
115+ {"Domain wildcard" , "host.subdomain.example.com" , "%.example.com" , true },
116+
117+ // Edge cases
118+ {"Empty pattern" , "127.0.0.1" , "" , false },
119+ {"Pattern without wildcard" , "127.0.0.1" , "127.0.0.1" , false }, // Should return false as it's not a wildcard pattern
120+ {"Just wildcard" , "anything" , "%" , true },
121+ {"Multiple wildcards together" , "test" , "%%" , true },
122+
123+ // Special characters in patterns (should be escaped)
124+ {"Pattern with dots" , "test.host" , "test.%" , true },
125+ {"Pattern with regex chars" , "test[1]" , "test[%]" , true },
126+ }
127+
128+ for _ , tt := range tests {
129+ t .Run (tt .name , func (t * testing.T ) {
130+ result := matchesHostPattern (tt .host , tt .pattern )
131+ require .Equal (t , tt .expected , result , "matchesHostPattern(%q, %q) = %v, want %v" , tt .host , tt .pattern , result , tt .expected )
132+ })
133+ }
134+ }
135+
136+ func TestGetUserWithWildcardAuthentication (t * testing.T ) {
137+ ctx := sql .NewEmptyContext ()
138+ db := CreateEmptyMySQLDb ()
139+ p := & capturingPersistence {}
140+ db .SetPersister (p )
141+
142+ // Add test users with various host patterns
143+ ed := db .Editor ()
144+ db .AddSuperUser (ed , "testuser" , "127.0.0.1" , "password" )
145+ db .AddSuperUser (ed , "localhost_user" , "localhost" , "password" )
146+ db .Persist (ctx , ed )
147+ ed .Close ()
148+
149+ rd := db .Reader ()
150+ defer rd .Close ()
151+
152+ tests := []struct {
153+ name string
154+ username string
155+ host string
156+ expectedUser string
157+ shouldFind bool
158+ }{
159+ // Test specific IP matching (existing functionality)
160+ {"Specific IP - exact match" , "testuser" , "127.0.0.1" , "testuser" , true },
161+ {"Localhost user - normalized" , "localhost_user" , "127.0.0.1" , "localhost_user" , true },
162+ {"Localhost user - ::1" , "localhost_user" , "::1" , "localhost_user" , true },
163+ {"Non-existent user" , "nonexistent" , "127.0.0.1" , "" , false },
164+ }
165+
166+ for _ , tt := range tests {
167+ t .Run (tt .name , func (t * testing.T ) {
168+ user := db .GetUser (rd , tt .username , tt .host , false )
169+
170+ if ! tt .shouldFind {
171+ require .Nil (t , user , "Expected no user to be found for %s@%s" , tt .username , tt .host )
172+ return
173+ }
174+
175+ require .NotNil (t , user , "Expected user to be found for %s@%s" , tt .username , tt .host )
176+ require .Equal (t , tt .expectedUser , user .User , "Expected username %s, got %s" , tt .expectedUser , user .User )
177+ })
178+ }
179+ }
0 commit comments