Skip to content

Commit bc02ce7

Browse files
committed
Refactor
1 parent 9b5a054 commit bc02ce7

File tree

29 files changed

+1021
-526
lines changed

29 files changed

+1021
-526
lines changed

cmd/modern/root/install/mssql-base.go

Lines changed: 67 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,11 @@ import (
99
"fmt"
1010
"github.com/microsoft/go-sqlcmd/internal/cmdparser/dependency"
1111
"github.com/microsoft/go-sqlcmd/internal/tools"
12+
"github.com/microsoft/go-sqlcmd/pkg/mssqlcontainer/ingest"
1213
"os"
1314
"runtime"
1415
"strings"
1516

16-
"github.com/microsoft/go-sqlcmd/pkg/mssqlcontainer"
17-
1817
"github.com/microsoft/go-sqlcmd/cmd/modern/root/open"
1918

2019
"github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig"
@@ -59,9 +58,11 @@ type MssqlBase struct {
5958

6059
port int
6160

62-
usingDatabaseUrl string
63-
openTool string
64-
openFile string
61+
useDatabaseUrl string
62+
useMechanism string
63+
64+
openTool string
65+
openFile string
6566

6667
unitTesting bool
6768

@@ -219,12 +220,26 @@ func (c *MssqlBase) AddFlags(
219220
})
220221

221222
addFlag(cmdparser.FlagOptions{
222-
String: &c.usingDatabaseUrl,
223+
String: &c.useDatabaseUrl,
223224
DefaultString: "",
224225
Name: "using",
225226
Usage: "Download and use database from .bak/.bacpac/.mdf/.7z URL",
226227
})
227228

229+
addFlag(cmdparser.FlagOptions{
230+
String: &c.useDatabaseUrl,
231+
DefaultString: "",
232+
Name: "use",
233+
Usage: "Download and use database from .bak/.bacpac/.mdf/.7z URL",
234+
})
235+
236+
addFlag(cmdparser.FlagOptions{
237+
String: &c.useMechanism,
238+
DefaultString: "",
239+
Name: "use-mechanism",
240+
Usage: "Mechanism to use to make --use database online (attach, restore, dacfx)",
241+
})
242+
228243
addFlag(cmdparser.FlagOptions{
229244
String: &c.openTool,
230245
DefaultString: "",
@@ -281,6 +296,7 @@ func (c *MssqlBase) Run() {
281296
// command-line and the program will exit.
282297
func (c *MssqlBase) createContainer(imageName string, contextName string) {
283298
output := c.Cmd.Output()
299+
controller := container.NewController()
284300
saPassword := c.generatePassword()
285301

286302
env := []string{
@@ -294,8 +310,39 @@ func (c *MssqlBase) createContainer(imageName string, contextName string) {
294310
}
295311

296312
// Do an early exit if url doesn't exist
297-
if c.usingDatabaseUrl != "" {
298-
mssqlcontainer.ValidateUsingUrlExists(c.usingDatabaseUrl, output)
313+
var useDatabase ingest.Ingest
314+
if c.useDatabaseUrl != "" {
315+
useDatabase = ingest.NewIngest(c.useDatabaseUrl, controller, ingest.IngestOptions{
316+
Mechanism: c.useMechanism,
317+
})
318+
319+
if !useDatabase.IsValidFileExtension() {
320+
output.FatalfWithHints(
321+
[]string{
322+
fmt.Sprintf(
323+
"--using must be a path to a file with a %q extension",
324+
strings.Join(useDatabase.ValidFileExtensions(), ", "),
325+
),
326+
},
327+
"%q is not a valid file extension for --using flag", useDatabase.UserProvidedFileExt())
328+
}
329+
330+
if useDatabase.IsRemoteUrl() && !useDatabase.IsValidScheme() {
331+
output.FatalfWithHints(
332+
[]string{
333+
fmt.Sprintf(
334+
"--using URL must one of %q",
335+
strings.Join(useDatabase.ValidSchemes(), ", "),
336+
),
337+
},
338+
"%q is not a valid URL for --using flag", c.useDatabaseUrl)
339+
}
340+
341+
if !useDatabase.SourceFileExists() {
342+
output.FatalfWithHints(
343+
[]string{fmt.Sprintf("File does not exist at URL %q", c.useDatabaseUrl)},
344+
"Unable to download file")
345+
}
299346
}
300347

301348
if c.defaultDatabase != "" {
@@ -304,8 +351,6 @@ func (c *MssqlBase) createContainer(imageName string, contextName string) {
304351
}
305352
}
306353

307-
controller := container.NewController()
308-
309354
if !c.useCached {
310355
c.downloadImage(imageName, output, controller)
311356
}
@@ -344,8 +389,7 @@ func (c *MssqlBase) createContainer(imageName string, contextName string) {
344389
config.CurrentContextName(),
345390
config.GetConfigFileUsed())
346391

347-
controller.ContainerWaitForLogEntry(
348-
containerId, c.errorLogEntryToWaitFor)
392+
controller.ContainerWaitForLogEntry(containerId, c.errorLogEntryToWaitFor)
349393

350394
output.Infof(
351395
"Disabled %q account (and rotated %q password). Creating user %q",
@@ -374,24 +418,23 @@ func (c *MssqlBase) createContainer(imageName string, contextName string) {
374418
Name: "sa"}
375419

376420
c.sql.Connect(endpoint, saUser, sql.ConnectOptions{Database: "master", Interactive: false})
377-
378421
c.createNonSaUser(userName, password)
379422

380423
// Download and restore DB if asked
381-
if c.usingDatabaseUrl != "" {
382-
mssqlcontainer.DownloadAndRestoreDb(
383-
controller,
384-
containerId,
385-
c.usingDatabaseUrl,
386-
userName,
387-
password,
388-
c.sql.Query,
389-
c.Cmd.Output(),
390-
)
424+
if useDatabase != nil {
425+
output.Infof("Copying to container")
426+
useDatabase.CopyToContainer(containerId)
427+
428+
if useDatabase.IsExtractionNeeded() {
429+
output.Infof("Extracting files from archive")
430+
useDatabase.Extract()
431+
}
432+
433+
output.Infof("Bringing database online")
434+
useDatabase.BringOnline(c.sql.Query, userName, password)
391435
}
392436

393437
if c.openTool == "" {
394-
395438
hints := [][]string{}
396439

397440
// TODO: sqlcmd open ads only support on Windows right now, add Mac support

cmd/modern/root/use.go

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,20 @@
44
package root
55

66
import (
7+
"fmt"
78
"github.com/microsoft/go-sqlcmd/internal/cmdparser"
89
"github.com/microsoft/go-sqlcmd/internal/config"
910
"github.com/microsoft/go-sqlcmd/internal/container"
1011
"github.com/microsoft/go-sqlcmd/internal/secret"
1112
"github.com/microsoft/go-sqlcmd/internal/sql"
12-
"github.com/microsoft/go-sqlcmd/pkg/mssqlcontainer"
13+
"github.com/microsoft/go-sqlcmd/pkg/mssqlcontainer/ingest"
1314
)
1415

1516
type Use struct {
1617
cmdparser.Cmd
1718

18-
url string
19+
url string
20+
useMechanism string
1921

2022
sql sql.Sql
2123
}
@@ -39,6 +41,14 @@ func (c *Use) DefineCommand(...cmdparser.CommandOptions) {
3941
String: &c.url,
4042
Name: "url",
4143
Usage: "Name of context to set as current context"})
44+
45+
c.AddFlag(cmdparser.FlagOptions{
46+
String: &c.useMechanism,
47+
DefaultString: "",
48+
Name: "use-mechanism",
49+
Usage: "Mechanism to use to make --use database online (attach, restore, dacfx)",
50+
})
51+
4252
}
4353

4454
func (c *Use) run() {
@@ -64,16 +74,27 @@ func (c *Use) run() {
6474
c.sql = sql.New(sql.SqlOptions{UnitTesting: false})
6575
c.sql.Connect(endpoint, user, sql.ConnectOptions{Database: "master", Interactive: false})
6676

67-
mssqlcontainer.DownloadAndRestoreDb(
68-
controller,
69-
id,
70-
c.url,
77+
useDatabase := ingest.NewIngest(c.url, controller, ingest.IngestOptions{
78+
Mechanism: c.useMechanism,
79+
})
80+
81+
if !useDatabase.SourceFileExists() {
82+
output.FatalfWithHints(
83+
[]string{fmt.Sprintf("File does not exist at URL %q", c.url)},
84+
"Unable to download file to container")
85+
}
86+
87+
useDatabase.CopyToContainer(id)
88+
89+
if useDatabase.IsExtractionNeeded() {
90+
useDatabase.Extract()
91+
}
92+
93+
useDatabase.BringOnline(
94+
c.sql.Query,
7195
user.BasicAuth.Username,
7296
secret.Decode(user.BasicAuth.Password, user.BasicAuth.PasswordEncryption),
73-
c.query,
74-
c.Cmd.Output(),
7597
)
76-
7798
} else {
7899
output.FatalfWithHintExamples([][]string{
79100
{"Create new context with a sql container ", "sqlcmd create mssql"},

internal/uri/factory.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package uri
2+
3+
import (
4+
"net/url"
5+
)
6+
7+
func NewUri(uri string) Uri {
8+
url, err := url.Parse(uri)
9+
if err != nil {
10+
panic(err)
11+
}
12+
13+
return Uri{
14+
uri: uri,
15+
url: url,
16+
}
17+
}

internal/uri/type.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package uri
2+
3+
import "net/url"
4+
5+
type Uri struct {
6+
uri string
7+
url *url.URL
8+
}

internal/uri/uri.go

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
package uri
2+
3+
import (
4+
"path"
5+
"path/filepath"
6+
"strings"
7+
)
8+
9+
func (u Uri) IsLocal() bool {
10+
if len(u.Scheme()) > 2 {
11+
return false
12+
} else {
13+
return true
14+
}
15+
}
16+
17+
func (u Uri) ActualUrl() string {
18+
urlEndIdx := strings.LastIndex(u.uri, ".bak")
19+
if urlEndIdx == -1 {
20+
urlEndIdx = strings.LastIndex(u.uri, ".mdf")
21+
}
22+
if urlEndIdx != -1 {
23+
return u.uri[0:(urlEndIdx + 4)]
24+
}
25+
26+
if urlEndIdx == -1 {
27+
urlEndIdx = strings.LastIndex(u.uri, ".7z")
28+
if urlEndIdx != -1 {
29+
return u.uri[0:(urlEndIdx + 3)]
30+
}
31+
}
32+
33+
if urlEndIdx == -1 {
34+
urlEndIdx = strings.LastIndex(u.uri, ".bacpac")
35+
if urlEndIdx != -1 {
36+
return u.uri[0:(urlEndIdx + 7)]
37+
}
38+
}
39+
40+
return u.uri
41+
}
42+
43+
func (u Uri) Scheme() string {
44+
return u.url.Scheme
45+
}
46+
47+
func (u Uri) FileExtension() string {
48+
_, f := filepath.Split(u.ActualUrl())
49+
return strings.TrimLeft(filepath.Ext(f), ".")
50+
}
51+
52+
func (u Uri) Filename() string {
53+
return filepath.Base(u.ActualUrl())
54+
}
55+
56+
// parseDbName returns the databaseName from --using arg
57+
// It sets database name to the specified database name
58+
// or in absence of it, it is set to the filename without
59+
// extension.
60+
func (u Uri) ParseDbName() string {
61+
if u.uri == "" {
62+
panic("uri is empty")
63+
}
64+
65+
dbToken := path.Base(u.url.Path)
66+
if dbToken != "." && dbToken != "/" {
67+
lastIdx := strings.LastIndex(dbToken, ".bak")
68+
if lastIdx == -1 {
69+
lastIdx = strings.LastIndex(dbToken, ".mdf")
70+
}
71+
if lastIdx != -1 {
72+
//Get file name without extension
73+
fileName := dbToken[0:lastIdx]
74+
lastIdx += 5
75+
if lastIdx >= len(dbToken) {
76+
return fileName
77+
}
78+
//Return database name if it was specified
79+
return dbToken[lastIdx:]
80+
} else {
81+
lastIdx := strings.LastIndex(dbToken, ".bacpac")
82+
if lastIdx != -1 {
83+
//Get file name without extension
84+
fileName := dbToken[0:lastIdx]
85+
lastIdx += 8
86+
if lastIdx >= len(dbToken) {
87+
return fileName
88+
}
89+
//Return database name if it was specified
90+
return dbToken[lastIdx:]
91+
} else {
92+
lastIdx := strings.LastIndex(dbToken, ".7z")
93+
if lastIdx != -1 {
94+
//Get file name without extension
95+
fileName := dbToken[0:lastIdx]
96+
lastIdx += 4
97+
if lastIdx >= len(dbToken) {
98+
return fileName
99+
}
100+
//Return database name if it was specified
101+
return dbToken[lastIdx:]
102+
}
103+
}
104+
}
105+
}
106+
107+
fileName := filepath.Base(u.uri)
108+
return fileName[:len(fileName)-len(filepath.Ext(fileName))]
109+
}
110+
111+
func (u Uri) GetDbNameAsIdentifier() string {
112+
escapedDbName := strings.ReplaceAll(u.ParseDbName(), "'", "''")
113+
return strings.ReplaceAll(escapedDbName, "]", "]]")
114+
}
115+
116+
func (u Uri) GetDbNameAsNonIdentifier() string {
117+
return strings.ReplaceAll(u.ParseDbName(), "]", "]]")
118+
}

0 commit comments

Comments
 (0)