Skip to content

Commit 930c299

Browse files
authored
Feat: Add -e and -k, finish -X, -r, -L (#447)
1 parent ff6d88f commit 930c299

File tree

11 files changed

+221
-71
lines changed

11 files changed

+221
-71
lines changed

README.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ The `sqlcmd` project aims to be a complete port of the original ODBC sqlcmd to t
125125
- There are new posix-style versions of each flag, such as `--input-file` for `-i`. `sqlcmd -?` will print those parameter names. Those new names do not preserve backward compatibility with ODBC `sqlcmd`. For example, to specify multiple input file names using `--input-file`, the file names must be comma-delimited, not space-delimited.
126126

127127
The following switches have different behavior in this version of `sqlcmd` compared to the original ODBC based `sqlcmd`.
128-
- `-r` requires a 0 or 1 argument
129128
- `-R` switch is ignored. The go runtime does not provide access to user locale information, and it's not readily available through syscall on all supported platforms.
130129
- `-I` switch is ignored; quoted identifiers are always set on. To disable quoted identifier behavior, add `SET QUOTED IDENTIFIER OFF` in your scripts.
131130
- `-N` now takes a string value that can be one of `true`, `false`, or `disable` to specify the encryption choice.
@@ -141,7 +140,6 @@ The following switches have different behavior in this version of `sqlcmd` compa
141140
- If using a single `-i` flag to pass multiple file names, there must be a space after the `-i`. Example: `-i file1.sql file2.sql`
142141
- `-M` switch is ignored. Sqlcmd always enables multi-subnet failover.
143142

144-
145143
### Switches not available in the new sqlcmd (go-sqlcmd) yet
146144

147145
There are a few switches yet to be implemented in the new `sqlcmd` (go-sqlcmd) compared

cmd/sqlcmd/sqlcmd.go

Lines changed: 131 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ type SQLCmdArguments struct {
3838
// Query to run then exit
3939
Query string
4040
Server string
41-
// Disable syscommands with a warning
42-
DisableCmdAndWarn bool
41+
// Disable syscommands with a warning or error
42+
DisableCmd *int
4343
// AuthenticationMethod is new for go-sqlcmd
4444
AuthenticationMethod string
4545
UseAad bool
@@ -55,7 +55,7 @@ type SQLCmdArguments struct {
5555
ErrorSeverityLevel uint8
5656
ErrorLevel int
5757
Format string
58-
ErrorsToStderr int
58+
ErrorsToStderr *int
5959
Headers int
6060
UnicodeOutputFile bool
6161
Version bool
@@ -66,25 +66,60 @@ type SQLCmdArguments struct {
6666
TrimSpaces bool
6767
Password string
6868
DedicatedAdminConnection bool
69-
ListServers bool
69+
ListServers string
70+
RemoveControlCharacters *int
71+
EchoInput bool
72+
QueryTimeout int
7073
// Keep Help at the end of the list
7174
Help bool
7275
}
7376

77+
func (args *SQLCmdArguments) useEnvVars() bool {
78+
return args.DisableCmd == nil
79+
}
80+
81+
func (args *SQLCmdArguments) errorOnBlockedCmd() bool {
82+
return args.DisableCmd != nil && *args.DisableCmd > 0
83+
}
84+
85+
func (args *SQLCmdArguments) warnOnBlockedCmd() bool {
86+
return args.DisableCmd != nil && *args.DisableCmd <= 0
87+
}
88+
89+
func (args *SQLCmdArguments) runStartupScript() bool {
90+
return args.DisableCmd == nil
91+
}
92+
93+
func (args *SQLCmdArguments) getControlCharacterBehavior() sqlcmd.ControlCharacterBehavior {
94+
if args.RemoveControlCharacters == nil {
95+
return sqlcmd.ControlIgnore
96+
}
97+
switch *args.RemoveControlCharacters {
98+
case 1:
99+
return sqlcmd.ControlReplace
100+
case 2:
101+
return sqlcmd.ControlReplaceConsecutive
102+
}
103+
return sqlcmd.ControlRemove
104+
}
105+
74106
const (
75-
sqlcmdErrorPrefix = "Sqlcmd: "
76-
applicationIntent = "application-intent"
77-
errorsToStderr = "errors-to-stderr"
78-
format = "format"
79-
encryptConnection = "encrypt-connection"
80-
screenWidth = "screen-width"
81-
fixedTypeWidth = "fixed-type-width"
82-
variableTypeWidth = "variable-type-width"
107+
sqlcmdErrorPrefix = "Sqlcmd: "
108+
applicationIntent = "application-intent"
109+
errorsToStderr = "errors-to-stderr"
110+
format = "format"
111+
encryptConnection = "encrypt-connection"
112+
screenWidth = "screen-width"
113+
fixedTypeWidth = "fixed-type-width"
114+
variableTypeWidth = "variable-type-width"
115+
disableCmdAndWarn = "disable-cmd-and-warn"
116+
listServers = "list-servers"
117+
removeControlCharacters = "remove-control-characters"
83118
)
84119

85120
// Validate arguments for settings not describe
86121
func (a *SQLCmdArguments) Validate(c *cobra.Command) (err error) {
87-
if a.ListServers {
122+
if a.ListServers != "" {
88123
c.Flags().Visit(func(f *pflag.Flag) {
89124
if f.Shorthand != "L" {
90125
err = localizer.Errorf("The -L parameter can not be used in combination with other parameters.")
@@ -110,6 +145,8 @@ func (a *SQLCmdArguments) Validate(c *cobra.Command) (err error) {
110145
err = rangeParameterError("-Y", fmt.Sprint(*a.FixedTypeWidth), 0, 8000, true)
111146
case a.VariableTypeWidth != nil && (*a.VariableTypeWidth < 0 || *a.VariableTypeWidth > 8000):
112147
err = rangeParameterError("-y", fmt.Sprint(*a.VariableTypeWidth), 0, 8000, true)
148+
case a.QueryTimeout < 0 || a.QueryTimeout > 65534:
149+
err = rangeParameterError("-t", fmt.Sprint(a.QueryTimeout), 0, 65534, true)
113150
}
114151
}
115152
if err != nil {
@@ -170,9 +207,11 @@ func Execute(version string) {
170207
},
171208
Run: func(cmd *cobra.Command, argss []string) {
172209
// emulate -L returning no servers
173-
if args.ListServers {
174-
fmt.Println()
175-
fmt.Println(localizer.Sprintf("Servers:"))
210+
if args.ListServers != "" {
211+
if args.ListServers != "c" {
212+
fmt.Println()
213+
fmt.Println(localizer.Sprintf("Servers:"))
214+
}
176215
fmt.Println(" ;UID:Login ID=?;PWD:Password=?;Trusted_Connection:Use Integrated Security=?;*APP:AppName=?;*WSID:WorkStation ID=?;")
177216
os.Exit(0)
178217
}
@@ -181,7 +220,7 @@ func Execute(version string) {
181220
os.Exit(1)
182221
}
183222

184-
vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn)
223+
vars := sqlcmd.InitializeVariables(args.useEnvVars())
185224
setVars(vars, &args)
186225

187226
if args.Version {
@@ -225,10 +264,11 @@ func Execute(version string) {
225264
}
226265

227266
// We need to rewrite the arguments to add -i and -v in front of each space-delimited value to be Cobra-friendly.
267+
// For flags like -r we need to inject the default value if the user omits it
228268
func convertOsArgs(args []string) (cargs []string) {
229269
flag := ""
230270
first := true
231-
for _, a := range args {
271+
for i, a := range args {
232272
if flag != "" {
233273
// If the user has a file named "-i" the only way they can pass it on the command line
234274
// is with triple quotes: sqlcmd -i """-i""" which will convince the flags parser to
@@ -240,11 +280,34 @@ func convertOsArgs(args []string) (cargs []string) {
240280
}
241281
first = false
242282
}
283+
var defValue string
243284
if isListFlag(a) {
244285
flag = a
245286
first = true
287+
} else {
288+
defValue = checkDefaultValue(args, i)
246289
}
247290
cargs = append(cargs, a)
291+
if defValue != "" {
292+
cargs = append(cargs, defValue)
293+
}
294+
}
295+
return
296+
}
297+
298+
// If args[i] is the given flag and args[i+1] is another flag, returns the value to append after the flag
299+
func checkDefaultValue(args []string, i int) (val string) {
300+
flags := map[rune]string{
301+
'r': "0",
302+
'k': "0",
303+
'L': "|", // | is the sentinel for no value since users are unlikely to use it. It's "reserved" in most shells
304+
'X': "0",
305+
}
306+
if isFlag(args[i]) && (len(args) == i+1 || args[i+1][0] == '-') {
307+
if v, ok := flags[rune(args[i][1])]; ok {
308+
val = v
309+
return
310+
}
248311
}
249312
return
250313
}
@@ -296,6 +359,9 @@ func SetScreenWidthFlags(args *SQLCmdArguments, rootCmd *cobra.Command) {
296359
args.ScreenWidth = getOptionalIntArgument(rootCmd, screenWidth)
297360
args.FixedTypeWidth = getOptionalIntArgument(rootCmd, fixedTypeWidth)
298361
args.VariableTypeWidth = getOptionalIntArgument(rootCmd, variableTypeWidth)
362+
args.DisableCmd = getOptionalIntArgument(rootCmd, disableCmdAndWarn)
363+
args.ErrorsToStderr = getOptionalIntArgument(rootCmd, errorsToStderr)
364+
args.RemoveControlCharacters = getOptionalIntArgument(rootCmd, removeControlCharacters)
299365
}
300366

301367
func setFlags(rootCmd *cobra.Command, args *SQLCmdArguments) {
@@ -313,7 +379,7 @@ func setFlags(rootCmd *cobra.Command, args *SQLCmdArguments) {
313379
rootCmd.Flags().StringVarP(&args.InitialQuery, "initial-query", "q", "", localizer.Sprintf("Executes a query when sqlcmd starts, but does not exit sqlcmd when the query has finished running. Multiple-semicolon-delimited queries can be executed"))
314380
rootCmd.Flags().StringVarP(&args.Query, "query", "Q", "", localizer.Sprintf("Executes a query when sqlcmd starts and then immediately exits sqlcmd. Multiple-semicolon-delimited queries can be executed"))
315381
rootCmd.Flags().StringVarP(&args.Server, "server", "S", "", localizer.Sprintf("%s Specifies the instance of SQL Server to which to connect. It sets the sqlcmd scripting variable %s.", localizer.ConnStrPattern, localizer.ServerEnvVar))
316-
rootCmd.Flags().BoolVarP(&args.DisableCmdAndWarn, "disable-cmd-and-warn", "X", false, localizer.Sprintf("Disables commands that might compromise system security. Sqlcmd issues a warning and continues"))
382+
_ = rootCmd.Flags().IntP(disableCmdAndWarn, "X", 0, localizer.Sprintf("%s Disables commands that might compromise system security. Passing 1 tells sqlcmd to exit when disabled commands are run.", "-X[1]"))
317383
rootCmd.Flags().StringVar(&args.AuthenticationMethod, "authentication-method", "", localizer.Sprintf("Specifies the SQL authentication method to use to connect to Azure SQL Database. One of: ActiveDirectoryDefault, ActiveDirectoryIntegrated, ActiveDirectoryPassword, ActiveDirectoryInteractive, ActiveDirectoryManagedIdentity, ActiveDirectoryServicePrincipal, SqlPassword"))
318384
rootCmd.Flags().BoolVarP(&args.UseAad, "use-aad", "G", false, localizer.Sprintf("Tells sqlcmd to use ActiveDirectory authentication. If no user name is provided, authentication method ActiveDirectoryDefault is used. If a password is provided, ActiveDirectoryPassword is used. Otherwise ActiveDirectoryInteractive is used"))
319385
rootCmd.Flags().BoolVarP(&args.DisableVariableSubstitution, "disable-variable-substitution", "x", false, localizer.Sprintf("Causes sqlcmd to ignore scripting variables. This parameter is useful when a script contains many %s statements that may contain strings that have the same format as regular variables, such as $(variable_name)", localizer.InsertKeyword))
@@ -328,8 +394,7 @@ func setFlags(rootCmd *cobra.Command, args *SQLCmdArguments) {
328394
// Can't use NoOptDefVal until this fix: https://github.com/spf13/cobra/issues/866
329395
//rootCmd.Flags().Lookup(encryptConnection).NoOptDefVal = "true"
330396
rootCmd.Flags().StringVarP(&args.Format, format, "F", "horiz", localizer.Sprintf("Specifies the formatting for results"))
331-
rootCmd.Flags().IntVarP(&args.ErrorsToStderr, errorsToStderr, "r", -1, localizer.Sprintf("Controls which error messages are sent to stdout. Messages that have severity level greater than or equal to this level are sent"))
332-
//rootCmd.Flags().Lookup(errorsToStderr).NoOptDefVal = "0"
397+
_ = rootCmd.Flags().IntP(errorsToStderr, "r", -1, localizer.Sprintf("%s Redirects error messages with severity >= 11 output to stderr. Pass 1 to to redirect all errors including PRINT.", "-r[0 | 1]"))
333398
rootCmd.Flags().IntVar(&args.DriverLoggingLevel, "driver-logging-level", 0, localizer.Sprintf("Level of mssql driver messages to print"))
334399
rootCmd.Flags().BoolVarP(&args.ExitOnError, "exit-on-error", "b", false, localizer.Sprintf("Specifies that sqlcmd exits and returns a %s value when an error occurs", localizer.DosErrorLevel))
335400
rootCmd.Flags().IntVarP(&args.ErrorLevel, "error-level", "m", 0, localizer.Sprintf("Controls which error messages are sent to %s. Messages that have severity level greater than or equal to this level are sent", localizer.StdoutName))
@@ -350,10 +415,13 @@ func setFlags(rootCmd *cobra.Command, args *SQLCmdArguments) {
350415
_ = rootCmd.Flags().IntP(screenWidth, "w", 0, localizer.Sprintf("Specifies the screen width for output"))
351416
_ = rootCmd.Flags().IntP(variableTypeWidth, "y", 256, setScriptVariable("SQLCMDMAXVARTYPEWIDTH"))
352417
_ = rootCmd.Flags().IntP(fixedTypeWidth, "Y", 0, setScriptVariable("SQLCMDMAXFIXEDTYPEWIDTH"))
353-
rootCmd.Flags().BoolVarP(&args.ListServers, "list-servers", "L", false, "List servers")
418+
rootCmd.Flags().StringVarP(&args.ListServers, listServers, "L", "", localizer.Sprintf("%s List servers. Pass %s to omit 'Servers:' output.", "-L[c]", "c"))
354419
rootCmd.Flags().BoolVarP(&args.DedicatedAdminConnection, "dedicated-admin-connection", "A", false, localizer.Sprintf("Dedicated administrator connection"))
355420
_ = rootCmd.Flags().BoolP("enable-quoted-identifiers", "I", true, localizer.Sprintf("Provided for backward compatibility. Quoted identifiers are always enabled"))
356421
_ = rootCmd.Flags().BoolP("client-regional-setting", "R", false, localizer.Sprintf("Provided for backward compatibility. Client regional settings are not used"))
422+
_ = rootCmd.Flags().IntP(removeControlCharacters, "k", 0, localizer.Sprintf("%s Remove control characters from output. Pass 1 to substitute a space per character, 2 for a space per consecutive characters", "-k [1|2]"))
423+
rootCmd.Flags().BoolVarP(&args.EchoInput, "echo-input", "e", false, localizer.Sprintf("Echo input"))
424+
rootCmd.Flags().IntVarP(&args.QueryTimeout, "query-timeout", "t", 0, "Query timeout")
357425
}
358426

359427
func setScriptVariable(v string) string {
@@ -403,7 +471,32 @@ func normalizeFlags(cmd *cobra.Command) error {
403471
err = invalidParameterError("-r", v, "0", "1")
404472
return pflag.NormalizedName("")
405473
}
474+
case disableCmdAndWarn:
475+
switch v {
476+
case "0", "1":
477+
return pflag.NormalizedName(name)
478+
default:
479+
err = invalidParameterError("-X", v, "1")
480+
return pflag.NormalizedName("")
481+
}
482+
case listServers:
483+
switch v {
484+
case "|", "c":
485+
return pflag.NormalizedName(name)
486+
default:
487+
err = invalidParameterError("-L", v, "c")
488+
return pflag.NormalizedName("")
489+
}
490+
case removeControlCharacters:
491+
switch v {
492+
case "0", "1", "2":
493+
return pflag.NormalizedName(name)
494+
default:
495+
err = invalidParameterError("-k", v, "1", "2")
496+
return pflag.NormalizedName("")
497+
}
406498
}
499+
407500
return pflag.NormalizedName(name)
408501
})
409502
if err != nil {
@@ -513,7 +606,7 @@ func setVars(vars *sqlcmd.Variables, args *SQLCmdArguments) {
513606
return ""
514607
},
515608
sqlcmd.SQLCMDUSER: func(a *SQLCmdArguments) string { return a.UserName },
516-
sqlcmd.SQLCMDSTATTIMEOUT: func(a *SQLCmdArguments) string { return "" },
609+
sqlcmd.SQLCMDSTATTIMEOUT: func(a *SQLCmdArguments) string { return fmt.Sprint(a.QueryTimeout) },
517610
sqlcmd.SQLCMDHEADERS: func(a *SQLCmdArguments) string { return fmt.Sprint(a.Headers) },
518611
sqlcmd.SQLCMDCOLSEP: func(a *SQLCmdArguments) string {
519612
if a.ColumnSeparator != "" {
@@ -558,7 +651,7 @@ func setConnect(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments, vars *sq
558651
connect.ApplicationName = "sqlcmd"
559652
if len(args.Password) > 0 {
560653
connect.Password = args.Password
561-
} else if !args.DisableCmdAndWarn {
654+
} else if args.useEnvVars() {
562655
connect.Password = os.Getenv(sqlcmd.SQLCMDPASSWORD)
563656
}
564657
connect.ServerName = args.Server
@@ -576,7 +669,7 @@ func setConnect(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments, vars *sq
576669
connect.UseTrustedConnection = args.UseTrustedConnection
577670
connect.TrustServerCertificate = args.TrustServerCertificate
578671
connect.AuthenticationMethod = args.authenticationMethod(connect.Password != "")
579-
connect.DisableEnvironmentVariables = args.DisableCmdAndWarn
672+
connect.DisableEnvironmentVariables = !args.useEnvVars()
580673
connect.DisableVariableSubstitution = args.DisableVariableSubstitution
581674
connect.ApplicationIntent = args.ApplicationIntent
582675
connect.LoginTimeoutSeconds = args.LoginTimeout
@@ -614,10 +707,10 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) {
614707
defer s.StopCloseHandler()
615708
s.UnicodeOutputFile = args.UnicodeOutputFile
616709

617-
if args.DisableCmdAndWarn {
618-
s.Cmd.DisableSysCommands(false)
710+
if args.DisableCmd != nil {
711+
s.Cmd.DisableSysCommands(args.errorOnBlockedCmd())
619712
}
620-
713+
s.EchoInput = args.EchoInput
621714
if args.BatchTerminator != "GO" {
622715
err = s.Cmd.SetBatchTerminator(args.BatchTerminator)
623716
if err != nil {
@@ -629,25 +722,24 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) {
629722
}
630723

631724
s.Connect = &connectConfig
632-
s.Format = sqlcmd.NewSQLCmdDefaultFormatter(args.TrimSpaces)
725+
s.Format = sqlcmd.NewSQLCmdDefaultFormatter(args.TrimSpaces, args.getControlCharacterBehavior())
633726
if args.OutputFile != "" {
634727
err = s.RunCommand(s.Cmd["OUT"], []string{args.OutputFile})
635728
if err != nil {
636729
return 1, err
637730
}
638-
} else {
731+
} else if args.ErrorsToStderr != nil {
639732
var stderrSeverity uint8 = 11
640-
if args.ErrorsToStderr == 1 {
733+
if *args.ErrorsToStderr == 1 {
641734
stderrSeverity = 0
642735
}
643-
if args.ErrorsToStderr >= 0 {
644-
s.PrintError = func(msg string, severity uint8) bool {
645-
if severity >= stderrSeverity {
646-
s.WriteError(os.Stderr, errors.New(msg+sqlcmd.SqlcmdEol))
647-
return true
648-
}
649-
return false
736+
737+
s.PrintError = func(msg string, severity uint8) bool {
738+
if severity >= stderrSeverity {
739+
s.WriteError(os.Stderr, errors.New(msg+sqlcmd.SqlcmdEol))
740+
return true
650741
}
742+
return false
651743
}
652744
}
653745

@@ -659,7 +751,7 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) {
659751
}
660752

661753
script := vars.StartupScriptFile()
662-
if !args.DisableCmdAndWarn && len(script) > 0 {
754+
if args.runStartupScript() && len(script) > 0 {
663755
f, fileErr := os.Open(script)
664756
if fileErr != nil {
665757
s.WriteError(s.GetError(), sqlcmd.InvalidVariableValue(sqlcmd.SQLCMDINI, script))

0 commit comments

Comments
 (0)