@@ -6,6 +6,8 @@ package config
66import (
77 "fmt"
88 "github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig"
9+ "github.com/microsoft/go-sqlcmd/internal/container"
10+ "github.com/microsoft/go-sqlcmd/internal/sql"
911 "strings"
1012
1113 "github.com/microsoft/go-sqlcmd/internal/cmdparser"
@@ -17,6 +19,8 @@ import (
1719// ConnectionStrings implements the `sqlcmd config connection-strings` command
1820type ConnectionStrings struct {
1921 cmdparser.Cmd
22+
23+ database string
2024}
2125
2226func (c * ConnectionStrings ) DefineCommand (... cmdparser.CommandOptions ) {
@@ -36,6 +40,13 @@ func (c *ConnectionStrings) DefineCommand(...cmdparser.CommandOptions) {
3640 }
3741
3842 c .Cmd .DefineCommand (options )
43+
44+ c .AddFlag (cmdparser.FlagOptions {
45+ String : & c .database ,
46+ Name : "database" ,
47+ DefaultString : "" ,
48+ Shorthand : "d" ,
49+ Usage : "Database for the connection string (default is taken from the T/SQL login)" })
3950}
4051
4152// run generates connection strings for the current context in multiple formats.
@@ -48,12 +59,27 @@ func (c *ConnectionStrings) run() {
4859 "ADO.NET" : "Server=tcp:%s,%d;Initial Catalog=%s;Persist Security Info=False;User ID=%s;Password=%s;MultipleActiveResultSets=False;Encrypt=True;TrustServerCertificate=%s;Connection Timeout=30;" ,
4960 "JDBC" : "jdbc:sqlserver://%s:%d;database=%s;user=%s;password=%s;encrypt=true;trustServerCertificate=%s;loginTimeout=30;" ,
5061 "ODBC" : "Driver={ODBC Driver 18 for SQL Server};Server=tcp:%s,%d;Database=%s;Uid=%s;Pwd=%s;Encrypt=yes;TrustServerCertificate=%s;Connection Timeout=30;" ,
51- "GO" : "sqlserver://%s:%s@%s,%d?database=master ;encrypt=true;trustServerCertificate=%s;dial+timeout=30" ,
62+ "GO" : "sqlserver://%s:%s@%s,%d?database=%s ;encrypt=true;trustServerCertificate=%s;dial+timeout=30" ,
5263 "SQLCMD" : "sqlcmd -S %s,%d -U %s" ,
5364 }
5465
5566 endpoint , user := config .CurrentContext ()
5667
68+ if c .database == "" {
69+ if endpoint .AssetDetails != nil && endpoint .AssetDetails .ContainerDetails != nil {
70+ controller := container .NewController ()
71+ if controller .ContainerRunning (endpoint .AssetDetails .ContainerDetails .Id ) {
72+ s := sql .New (sql.SqlOptions {})
73+ s .Connect (endpoint , user , sql.ConnectOptions {Interactive : false })
74+ c .database = s .ScalarString ("PRINT DB_NAME()" )
75+ } else {
76+ c .database = "master"
77+ }
78+ } else {
79+ c .database = "master"
80+ }
81+ }
82+
5783 if user != nil {
5884 for k , v := range connectionStringFormats {
5985 if k == "GO" {
@@ -63,23 +89,25 @@ func (c *ConnectionStrings) run() {
6389 secret .Decode (user .BasicAuth .Password , user .BasicAuth .PasswordEncrypted ),
6490 endpoint .EndpointDetails .Address ,
6591 endpoint .EndpointDetails .Port ,
92+ c .database ,
6693 c .stringForBoolean (c .trustServerCertificate (endpoint ), k ))
6794 } else if k == "SQLCMD" {
6895 format := pal .CmdLineWithEnvVars (
6996 []string {"SQLCMDPASSWORD=%s" },
70- "sqlcmd -S %s,%d -U %s" ,
97+ "sqlcmd -S %s,%d -U %s -d %s " ,
7198 )
7299
73100 connectionStringFormats [k ] = fmt .Sprintf (format ,
74101 secret .Decode (user .BasicAuth .Password , user .BasicAuth .PasswordEncrypted ),
75102 endpoint .EndpointDetails .Address ,
76103 endpoint .EndpointDetails .Port ,
77- user .BasicAuth .Username )
104+ user .BasicAuth .Username ,
105+ c .database )
78106 } else {
79107 connectionStringFormats [k ] = fmt .Sprintf (v ,
80108 endpoint .EndpointDetails .Address ,
81109 endpoint .EndpointDetails .Port ,
82- "master" ,
110+ c . database ,
83111 user .BasicAuth .Username ,
84112 secret .Decode (user .BasicAuth .Password , user .BasicAuth .PasswordEncrypted ),
85113 c .stringForBoolean (c .trustServerCertificate (endpoint ), k ))
0 commit comments