@@ -5,6 +5,12 @@ package install
55
66import (
77 "fmt"
8+ "net/url"
9+ "path"
10+ "path/filepath"
11+ "runtime"
12+ "strings"
13+
814 "github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig"
915 "github.com/microsoft/go-sqlcmd/internal/cmdparser"
1016 "github.com/microsoft/go-sqlcmd/internal/config"
@@ -15,10 +21,6 @@ import (
1521 "github.com/microsoft/go-sqlcmd/internal/secret"
1622 "github.com/microsoft/go-sqlcmd/internal/sql"
1723 "github.com/spf13/viper"
18- "net/url"
19- "path/filepath"
20- "runtime"
21- "strings"
2224)
2325
2426// MssqlBase provide base support for installing SQL Server and all of its
@@ -395,7 +397,8 @@ func (c *MssqlBase) createContainer(imageName string, contextName string) {
395397
396398func (c * MssqlBase ) validateUsingUrlExists () {
397399 output := c .Cmd .Output ()
398- u , err := url .Parse (c .usingDatabaseUrl )
400+ databaseUrl := extractUrl (c .usingDatabaseUrl )
401+ u , err := url .Parse (databaseUrl )
399402 c .CheckErr (err )
400403
401404 if u .Scheme != "http" && u .Scheme != "https" {
@@ -406,9 +409,17 @@ func (c *MssqlBase) validateUsingUrlExists() {
406409 "%q is not a valid URL for --using flag" , c .usingDatabaseUrl )
407410 }
408411
412+ if u .Path == "" {
413+ output .FatalfWithHints (
414+ []string {
415+ "--using URL must have a path to .bak file" ,
416+ },
417+ "%q is not a valid URL for --using flag" , c .usingDatabaseUrl )
418+ }
419+
409420 // At the moment we only support attaching .bak files, but we should
410421 // support .bacpacs and .mdfs in the future
411- if _ , file := filepath .Split (c . usingDatabaseUrl ); filepath .Ext (file ) != ".bak" {
422+ if _ , file := filepath .Split (u . Path ); filepath .Ext (file ) != ".bak" {
412423 output .FatalfWithHints (
413424 []string {
414425 "--using file URL must be a .bak file" ,
@@ -417,7 +428,7 @@ func (c *MssqlBase) validateUsingUrlExists() {
417428 }
418429
419430 // Verify the url actually exists, and early exit if it doesn't
420- urlExists (c . usingDatabaseUrl , output )
431+ urlExists (databaseUrl , output )
421432}
422433
423434func (c * MssqlBase ) query (commandText string ) {
@@ -468,31 +479,73 @@ CHECK_POLICY=OFF`
468479 }
469480}
470481
482+ func getDbNameAsIdentifier (dbName string ) string {
483+ escapedDbNAme := strings .ReplaceAll (dbName , "'" , "''" )
484+ return strings .ReplaceAll (escapedDbNAme , "]" , "]]" )
485+ }
486+
487+ func getDbNameAsNonIdentifier (dbName string ) string {
488+ return strings .ReplaceAll (dbName , "]" , "]]" )
489+ }
490+
491+ //parseDbName returns the databaseName from --using arg
492+ // It sets database name to the specified database name
493+ // or in absence of it, it is set to the filename without
494+ // extension.
495+ func parseDbName (usingDbUrl string ) string {
496+ u , _ := url .Parse (usingDbUrl )
497+ dbToken := path .Base (u .Path )
498+ if dbToken != "." && dbToken != "/" {
499+ lastIdx := strings .LastIndex (dbToken , ".bak" )
500+ if lastIdx != - 1 {
501+ //Get file name without extension
502+ fileName := dbToken [0 :lastIdx ]
503+ lastIdx += 5
504+ if lastIdx >= len (dbToken ) {
505+ return fileName
506+ }
507+ //Return database name if it was specified
508+ return dbToken [lastIdx :]
509+ }
510+ }
511+ return ""
512+ }
513+
514+ func extractUrl (usingArg string ) string {
515+ urlEndIdx := strings .LastIndex (usingArg , ".bak" )
516+ if urlEndIdx != - 1 {
517+ return usingArg [0 :(urlEndIdx + 4 )]
518+ }
519+ return usingArg
520+ }
521+
471522func (c * MssqlBase ) downloadAndRestoreDb (
472523 controller * container.Controller ,
473524 containerId string ,
474525 userName string ,
475526) {
476527 output := c .Cmd .Output ()
528+ databaseName := parseDbName (c .usingDatabaseUrl )
529+ databaseUrl := extractUrl (c .usingDatabaseUrl )
477530
478- u , err := url .Parse (c .usingDatabaseUrl )
479- c .CheckErr (err )
480- _ , file := filepath .Split (c .usingDatabaseUrl )
481- fileNameWithNoExt := strings .TrimSuffix (file , filepath .Ext (file ))
531+ _ , file := filepath .Split (databaseUrl )
482532
483533 // Download file from URL into container
484- output .Infof ("Downloading %s from %s " , file , u . Hostname () )
534+ output .Infof ("Downloading %s" , file )
485535
486536 temporaryFolder := "/var/opt/mssql/backup"
487537
488538 controller .DownloadFile (
489539 containerId ,
490- c . usingDatabaseUrl ,
540+ databaseUrl ,
491541 temporaryFolder ,
492542 )
493543
494544 // Restore database from file
495- output .Infof ("Restoring database %s" , fileNameWithNoExt )
545+ output .Infof ("Restoring database %s" , databaseName )
546+
547+ dbNameAsIdentifier := getDbNameAsIdentifier (databaseName )
548+ dbNameAsNonIdentifier := getDbNameAsNonIdentifier (databaseName )
496549
497550 text := `SET NOCOUNT ON;
498551
@@ -535,12 +588,12 @@ WHERE IsPresent = 1
535588SET @sql = SUBSTRING(@sql, 1, LEN(@sql)-1)
536589EXEC(@sql)`
537590
538- c .query (fmt .Sprintf (text , temporaryFolder , file , fileNameWithNoExt , temporaryFolder , file ))
591+ c .query (fmt .Sprintf (text , temporaryFolder , file , dbNameAsIdentifier , temporaryFolder , file ))
539592
540593 alterDefaultDb := fmt .Sprintf (
541594 "ALTER LOGIN [%s] WITH DEFAULT_DATABASE = [%s]" ,
542595 userName ,
543- fileNameWithNoExt )
596+ dbNameAsNonIdentifier )
544597 c .query (alterDefaultDb )
545598}
546599
0 commit comments