@@ -18,15 +18,24 @@ import (
1818 "gitlab.com/postgres-ai/database-lab/v3/pkg/models"
1919)
2020
21- // ConnectionString builds PostgreSQL connection string.
22- func ConnectionString (host , port , username , dbname , password string ) string {
23- return fmt .Sprintf ("host=%s port=%s user='%s' database='%s' password='%s'" , host , port , username , dbname , password )
24- }
25-
2621const (
2722 availableExtensions = "select name, default_version, coalesce(installed_version,'') from pg_available_extensions " +
2823 "where installed_version is not null"
24+
2925 availableLocales = "select datname, lower(datcollate), lower(datctype) from pg_catalog.pg_database"
26+
27+ availableDBsTemplate = `select datname from pg_catalog.pg_database
28+ where not datistemplate and has_database_privilege('%s', datname, 'CONNECT')`
29+
30+ // maxNumberVerifiedDBs defines the maximum number of databases to verify availability as a database source.
31+ // The DB source instance can contain a large number of databases, so the verification will take a long time.
32+ // Therefore, we introduced a limit on the maximum number of databases to check for suitability as a source.
33+ maxNumberVerifiedDBs = 5
34+ )
35+
36+ var (
37+ errExtensionWarning = errors .New ("extension warning found" )
38+ errLocaleWarning = errors .New ("locale warning found" )
3039)
3140
3241type extension struct {
@@ -41,37 +50,125 @@ type locale struct {
4150 ctype string
4251}
4352
53+ // ConnectionString builds PostgreSQL connection string.
54+ func ConnectionString (host , port , username , dbname , password string ) string {
55+ return fmt .Sprintf ("host=%s port=%s user='%s' database='%s' password='%s'" , host , port , username , dbname , password )
56+ }
57+
58+ // GetDatabaseListQuery provides the query to get the list of databases available for user.
59+ func GetDatabaseListQuery (username string ) string {
60+ return fmt .Sprintf (availableDBsTemplate , username )
61+ }
62+
4463// CheckSource checks the readiness of the source database to dump and restore processes.
4564func CheckSource (ctx context.Context , conf * models.ConnectionTest , imageContent * ImageContent ) (* models.TestConnection , error ) {
46- connStr := ConnectionString (conf .Host , conf .Port , conf .Username , conf .DBName , conf .Password )
65+ conn , tcResponse := checkConnection (ctx , conf , conf .DBName )
66+ if tcResponse != nil {
67+ return tcResponse , nil
68+ }
69+
70+ defer func () {
71+ if err := conn .Close (ctx ); err != nil {
72+ log .Dbg ("failed to close connection:" , err )
73+ }
74+ }()
75+
76+ dbList := conf .DBList
77+
78+ if len (dbList ) == 0 {
79+ dbSourceList , err := getDBList (ctx , conn , conf .Username )
80+ if err != nil {
81+ return nil , err
82+ }
83+
84+ dbList = dbSourceList
85+ }
86+
87+ if len (dbList ) > maxNumberVerifiedDBs {
88+ dbList = dbList [:maxNumberVerifiedDBs ]
89+ tcResponse = & models.TestConnection {
90+ Status : models .TCStatusNotice ,
91+ Result : models .TCResultUnverifiedDB ,
92+ Message : "Too many databases are supposed to be checked. Only the following databases have been verified: " +
93+ strings .Join (dbList , ", " ),
94+ }
95+ }
96+
97+ for _ , dbName := range dbList {
98+ dbConn , listTC := checkConnection (ctx , conf , dbName )
99+ if listTC != nil {
100+ return listTC , nil
101+ }
102+
103+ listTC , err := checkContent (ctx , dbConn , dbName , imageContent )
104+ if err != nil {
105+ return nil , err
106+ }
107+
108+ if listTC != nil {
109+ return listTC , nil
110+ }
111+ }
112+
113+ if tcResponse != nil {
114+ return tcResponse , nil
115+ }
116+
117+ return & models.TestConnection {
118+ Status : models .TCStatusOK ,
119+ Result : models .TCResultOK ,
120+ Message : models .TCMessageOK ,
121+ }, nil
122+ }
123+
124+ func getDBList (ctx context.Context , conn * pgx.Conn , dbUsername string ) ([]string , error ) {
125+ dbList := make ([]string , 0 )
126+
127+ rows , err := conn .Query (ctx , GetDatabaseListQuery (dbUsername ))
128+ if err != nil {
129+ return nil , fmt .Errorf ("failed to perform query listing databases: %w" , err )
130+ }
131+
132+ for rows .Next () {
133+ var dbName string
134+ if err := rows .Scan (& dbName ); err != nil {
135+ return nil , fmt .Errorf ("failed to scan next row in database list result set: %w" , err )
136+ }
137+
138+ dbList = append (dbList , dbName )
139+ }
140+
141+ return dbList , nil
142+ }
143+
144+ func checkConnection (ctx context.Context , conf * models.ConnectionTest , dbName string ) (* pgx.Conn , * models.TestConnection ) {
145+ connStr := ConnectionString (conf .Host , conf .Port , conf .Username , dbName , conf .Password )
47146
48147 conn , err := pgx .Connect (ctx , connStr )
49148 if err != nil {
50149 log .Dbg ("failed to test database connection:" , err )
51150
52- return & models.TestConnection {
151+ return nil , & models.TestConnection {
53152 Status : models .TCStatusError ,
54153 Result : models .TCResultConnectionError ,
55154 Message : err .Error (),
56- }, nil
57- }
58-
59- defer func () {
60- if err := conn .Close (ctx ); err != nil {
61- log .Dbg ("failed to close connection:" , err )
62155 }
63- }()
156+ }
64157
65158 var one int
66159
67160 if err := conn .QueryRow (ctx , "select 1" ).Scan (& one ); err != nil {
68- return & models.TestConnection {
161+ return nil , & models.TestConnection {
69162 Status : models .TCStatusError ,
70163 Result : models .TCResultConnectionError ,
71164 Message : err .Error (),
72- }, nil
165+ }
73166 }
74167
168+ return conn , nil
169+ }
170+
171+ func checkContent (ctx context.Context , conn * pgx.Conn , dbName string , imageContent * ImageContent ) (* models.TestConnection , error ) {
75172 if ! imageContent .IsReady () {
76173 return & models.TestConnection {
77174 Status : models .TCStatusNotice ,
@@ -82,26 +179,30 @@ func CheckSource(ctx context.Context, conf *models.ConnectionTest, imageContent
82179 }
83180
84181 if missing , unsupported , err := checkExtensions (ctx , conn , imageContent .Extensions ()); err != nil {
182+ if err != errExtensionWarning {
183+ return nil , fmt .Errorf ("failed to check database extensions: %w" , err )
184+ }
185+
85186 return & models.TestConnection {
86187 Status : models .TCStatusWarning ,
87188 Result : models .TCResultMissingExtension ,
88- Message : buildExtensionsWarningMessage (missing , unsupported ),
189+ Message : buildExtensionsWarningMessage (dbName , missing , unsupported ),
89190 }, nil
90191 }
91192
92193 if missing , err := checkLocales (ctx , conn , imageContent .Locales (), imageContent .Databases ()); err != nil {
194+ if err != errLocaleWarning {
195+ return nil , fmt .Errorf ("failed to check database locales: %w" , err )
196+ }
197+
93198 return & models.TestConnection {
94199 Status : models .TCStatusWarning ,
95200 Result : models .TCResultMissingLocale ,
96- Message : buildLocalesWarningMessage (missing ),
201+ Message : buildLocalesWarningMessage (dbName , missing ),
97202 }, nil
98203 }
99204
100- return & models.TestConnection {
101- Status : models .TCStatusOK ,
102- Result : models .TCResultOK ,
103- Message : models .TCMessageOK ,
104- }, nil
205+ return nil , nil
105206}
106207
107208func checkExtensions (ctx context.Context , conn * pgx.Conn , imageExtensions map [string ]string ) ([]extension , []extension , error ) {
@@ -140,7 +241,7 @@ func checkExtensions(ctx context.Context, conn *pgx.Conn, imageExtensions map[st
140241 }
141242
142243 if len (missingExtensions ) != 0 || len (unsupportedVersions ) != 0 {
143- return missingExtensions , unsupportedVersions , errors . New ( "extension warning found" )
244+ return missingExtensions , unsupportedVersions , errExtensionWarning
144245 }
145246
146247 return nil , nil , nil
@@ -158,19 +259,19 @@ func toCanonicalSemver(v string) string {
158259 return v
159260}
160261
161- func buildExtensionsWarningMessage (missingExtensions , unsupportedVersions []extension ) string {
262+ func buildExtensionsWarningMessage (dbName string , missingExtensions , unsupportedVersions []extension ) string {
162263 sb := & strings.Builder {}
163264
164265 if len (missingExtensions ) > 0 {
165- sb .WriteString ("There are missing extensions:" )
266+ sb .WriteString ("There are missing extensions in the \" " + dbName + " \" database :" )
166267
167268 formatExtensionList (sb , missingExtensions )
168269
169- sb .WriteRune ( '\n' )
270+ sb .WriteString ( ". \n " )
170271 }
171272
172273 if len (unsupportedVersions ) > 0 {
173- sb .WriteString ("There are extensions with an unsupported version:" )
274+ sb .WriteString ("There are extensions with an unsupported version in the \" " + dbName + " \" database :" )
174275
175276 formatExtensionList (sb , unsupportedVersions )
176277 }
@@ -225,17 +326,17 @@ func checkLocales(ctx context.Context, conn *pgx.Conn, imageLocales, databases m
225326 }
226327
227328 if len (missingLocales ) != 0 {
228- return missingLocales , errors . New ( "locale warning found" )
329+ return missingLocales , errLocaleWarning
229330 }
230331
231332 return nil , nil
232333}
233334
234- func buildLocalesWarningMessage (missingLocales []locale ) string {
335+ func buildLocalesWarningMessage (dbName string , missingLocales []locale ) string {
235336 sb := & strings.Builder {}
236337
237338 if length := len (missingLocales ); length > 0 {
238- sb .WriteString ("There are missing locales:" )
339+ sb .WriteString ("There are missing locales in the \" " + dbName + " \" database :" )
239340
240341 for i , missing := range missingLocales {
241342 sb .WriteString (fmt .Sprintf (" '%s' (collate: %s, ctype: %s)" , missing .name , missing .collate , missing .ctype ))
0 commit comments