Skip to content

Commit 164f8e1

Browse files
authored
implement :xml (#336)
1 parent 5f902b7 commit 164f8e1

File tree

6 files changed

+87
-14
lines changed

6 files changed

+87
-14
lines changed

internal/color/color.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ const (
3030
TextTypeError
3131
// TextTypeWarning is for warning messages
3232
TextTypeWarning
33+
// TextTypeXml indicates the content is XML
34+
TextTypeXml
3335
)
3436

3537
var typeMap map[TextType]chroma.TokenType = map[TextType]chroma.TokenType{
@@ -85,6 +87,10 @@ func (c *chromaColorizer) Write(w io.Writer, s string, scheme string, t TextType
8587
if err = quick.Highlight(w, s, "transact-sql", "terminal16m", scheme); err != nil {
8688
_, err = w.Write([]byte(s))
8789
}
90+
case TextTypeXml:
91+
if err = quick.Highlight(w, s, "xml", "terminal16m", scheme); err != nil {
92+
_, err = w.Write([]byte(s))
93+
}
8894
default:
8995
tokens := chroma.Literator(chroma.Token{
9096
Type: typeMap[t], Value: s})

internal/color/color_test.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ func TestWrite(t *testing.T) {
5454
args: args{s: "warn", t: TextTypeWarning},
5555
wantW: "\x1b[3mwarn\x1b[0m",
5656
},
57+
{
58+
name: "XML",
59+
args: args{s: "<node>value</node>", t: TextTypeXml},
60+
wantW: "\x1b[1m\x1b[38;2;0;128;0m<node>\x1b[0mvalue\x1b[1m\x1b[38;2;0;128;0m</node>\x1b[0m",
61+
},
5762
}
5863

5964
for _, tt := range tests {

pkg/sqlcmd/commands.go

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,11 @@ func newCommands() Commands {
108108
action: onerrorCommand,
109109
name: "ONERROR",
110110
},
111+
"XML": {
112+
regex: regexp.MustCompile(`(?im)^[\t ]*?:XML(?:[ \t]+(.*$)|$)`),
113+
action: xmlCommand,
114+
name: "XML",
115+
},
111116
}
112117
}
113118

@@ -368,10 +373,16 @@ func listCommand(s *Sqlcmd, args []string, line uint) (err error) {
368373
}
369374
output := s.GetOutput()
370375
if cmd == "color" {
376+
sample := "select 'literal' as literal, 100 as number from [sys].[tables]"
377+
clr := color.TextTypeTSql
378+
if s.Format.IsXmlMode() {
379+
sample = `<node att="attValue"/><node>value</node>`
380+
clr = color.TextTypeXml
381+
}
371382
// ignoring errors since it's not critical output
372383
for _, style := range s.colorizer.Styles() {
373384
_, _ = output.Write([]byte(style + ": "))
374-
_ = s.colorizer.Write(output, "select 'literal' as literal, 100 as number from [sys].[tables]", style, color.TextTypeTSql)
385+
_ = s.colorizer.Write(output, sample, style, clr)
375386
_, _ = output.Write([]byte(SqlcmdEol))
376387
}
377388
return
@@ -507,6 +518,22 @@ func onerrorCommand(s *Sqlcmd, args []string, line uint) error {
507518
return nil
508519
}
509520

521+
func xmlCommand(s *Sqlcmd, args []string, line uint) error {
522+
if len(args) != 1 || args[0] == "" {
523+
return InvalidCommandError("XML", line)
524+
}
525+
params := strings.TrimSpace(args[0])
526+
// "OFF" and "ON" are documented as the allowed values.
527+
// ODBC sqlcmd treats any value other than "ON" the same as "OFF".
528+
// So we will too.
529+
if strings.EqualFold(params, "on") {
530+
s.Format.XmlMode(true)
531+
} else {
532+
s.Format.XmlMode(false)
533+
}
534+
return nil
535+
}
536+
510537
func resolveArgumentVariables(s *Sqlcmd, arg []rune, failOnUnresolved bool) (string, error) {
511538
var b *strings.Builder
512539
end := len(arg)

pkg/sqlcmd/commands_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ func TestCommandParsing(t *testing.T) {
5151
{`:!!notepad`, "EXEC", []string{"notepad"}},
5252
{` !! dir c:\`, "EXEC", []string{` dir c:\`}},
5353
{`!!dir c:\`, "EXEC", []string{`dir c:\`}},
54+
{`:XML ON `, "XML", []string{`ON `}},
5455
}
5556

5657
for _, test := range commands {
@@ -187,6 +188,7 @@ func TestListCommandUsesColorizer(t *testing.T) {
187188
func TestListColorPrintsStyleSamples(t *testing.T) {
188189
vars := InitializeVariables(false)
189190
s := New(nil, "", vars)
191+
s.Format = NewSQLCmdDefaultFormatter(false)
190192
// force colorizer on
191193
s.colorizer = color.New(true)
192194
buf := &memoryBuffer{buf: new(bytes.Buffer)}

pkg/sqlcmd/format.go

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ type Formatter interface {
3636
AddMessage(string)
3737
// AddError is called for each error encountered during batch execution
3838
AddError(err error)
39+
// XmlMode enables or disables XML rendering mode
40+
XmlMode(enable bool)
41+
// IsXmlMode returns whether XML mode is enabled
42+
IsXmlMode() bool
3943
}
4044

4145
// ControlCharacterBehavior specifies the text handling required for control characters in the output
@@ -77,6 +81,7 @@ type sqlCmdFormatterType struct {
7781
format string
7882
maxColNameLen int
7983
colorizer color.Colorizer
84+
xml bool
8085
}
8186

8287
// NewSQLCmdDefaultFormatter returns a Formatter that mimics the original ODBC-based sqlcmd formatter
@@ -119,7 +124,7 @@ func (f *sqlCmdFormatterType) writeOut(s string, t color.TextType) {
119124
}
120125
}
121126

122-
// Stores the settings to use for processing the current batch
127+
// BeginBatch stores the settings to use for processing the current batch
123128
// TODO: add a third io.Writer for messages when we add -r support
124129
func (f *sqlCmdFormatterType) BeginBatch(_ string, vars *Variables, out io.Writer, err io.Writer) {
125130
f.out = out
@@ -138,17 +143,19 @@ func (f *sqlCmdFormatterType) EndBatch() {
138143
func (f *sqlCmdFormatterType) BeginResultSet(cols []*sql.ColumnType) {
139144
f.rowcount = 0
140145
f.columnDetails, f.maxColNameLen = calcColumnDetails(cols, f.vars.MaxFixedColumnWidth(), f.vars.MaxVarColumnWidth())
141-
if f.vars.RowsBetweenHeaders() > -1 && f.format == "horizontal" {
146+
if f.vars.RowsBetweenHeaders() > -1 && f.format == "horizontal" && !f.xml {
142147
f.printColumnHeadings()
143148
}
144149
}
145150

146-
// Writes a blank line to the designated output writer
151+
// EndResultSet writes a blank line to the designated output writer
147152
func (f *sqlCmdFormatterType) EndResultSet() {
148-
f.writeOut(SqlcmdEol, color.TextTypeNormal)
153+
if !f.xml {
154+
f.writeOut(SqlcmdEol, color.TextTypeNormal)
155+
}
149156
}
150157

151-
// Writes the current row to the designated output writer
158+
// AddRow writes the current row to the designated output writer
152159
func (f *sqlCmdFormatterType) AddRow(row *sql.Rows) string {
153160
retval := ""
154161
values, err := f.scanRow(row)
@@ -157,7 +164,9 @@ func (f *sqlCmdFormatterType) AddRow(row *sql.Rows) string {
157164
return retval
158165
}
159166
retval = values[0]
160-
if f.format == "horizontal" {
167+
if f.xml {
168+
f.printColumnValue(retval, 0)
169+
} else if f.format == "horizontal" {
161170
// values are the full values, look at the displaywidth of each column and truncate accordingly
162171
for i, v := range values {
163172
if i > 0 {
@@ -176,7 +185,6 @@ func (f *sqlCmdFormatterType) AddRow(row *sql.Rows) string {
176185
}
177186
f.writeOut(SqlcmdEol, color.TextTypeNormal)
178187
return retval
179-
180188
}
181189

182190
func (f *sqlCmdFormatterType) addVerticalRow(values []string) {
@@ -193,12 +201,14 @@ func (f *sqlCmdFormatterType) addVerticalRow(values []string) {
193201
}
194202
}
195203

196-
// Writes a non-error message to the designated message writer
204+
// AddMessage writes a non-error message to the designated message writer
197205
func (f *sqlCmdFormatterType) AddMessage(msg string) {
198-
f.mustWriteOut(msg+SqlcmdEol, color.TextTypeWarning)
206+
if !f.xml {
207+
f.mustWriteOut(msg+SqlcmdEol, color.TextTypeWarning)
208+
}
199209
}
200210

201-
// Writes an error to the designated err Writer
211+
// AddError writes an error to the designated err Writer
202212
func (f *sqlCmdFormatterType) AddError(err error) {
203213
print := true
204214
b := new(strings.Builder)
@@ -217,6 +227,16 @@ func (f *sqlCmdFormatterType) AddError(err error) {
217227
}
218228
}
219229

230+
// XmlMode enables or disables XML mode
231+
func (f *sqlCmdFormatterType) XmlMode(enable bool) {
232+
f.xml = enable
233+
}
234+
235+
// IsXmlMode returns whether XML mode is enabled
236+
func (f *sqlCmdFormatterType) IsXmlMode() bool {
237+
return f.xml
238+
}
239+
220240
// Prints column headings based on columnDetail, variables, and command line arguments
221241
func (f *sqlCmdFormatterType) printColumnHeadings() {
222242
names := new(strings.Builder)
@@ -535,7 +555,7 @@ func (f *sqlCmdFormatterType) printColumnValue(val string, col int) {
535555

536556
s.WriteString(val)
537557
r := []rune(val)
538-
if f.format == "horizontal" {
558+
if !f.xml && f.format == "horizontal" {
539559
if !f.removeTrailingSpaces {
540560
if f.vars.MaxVarColumnWidth() != 0 || !isLargeVariableType(&c.col) {
541561
padding := c.displayWidth - min64(c.displayWidth, int64(len(r)))
@@ -551,11 +571,15 @@ func (f *sqlCmdFormatterType) printColumnValue(val string, col int) {
551571

552572
r = []rune(s.String())
553573
}
554-
if c.displayWidth > 0 && int64(len(r)) > c.displayWidth {
574+
if !f.xml && (c.displayWidth > 0 && int64(len(r)) > c.displayWidth) {
555575
s.Reset()
556576
s.WriteString(string(r[:c.displayWidth]))
557577
}
558-
f.writeOut(s.String(), color.TextTypeCell)
578+
clr := color.TextTypeCell
579+
if f.xml {
580+
clr = color.TextTypeXml
581+
}
582+
f.writeOut(s.String(), clr)
559583
}
560584

561585
func (f *sqlCmdFormatterType) mustWriteOut(s string, t color.TextType) {

pkg/sqlcmd/format_test.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,12 @@ func TestFormatterColorizer(t *testing.T) {
149149
assert.NoError(t, err, "runSqlCmd returned error")
150150
assert.Equal(t, "\x1b[38;2;0;128;0mname\x1b[0m"+SqlcmdEol+SqlcmdEol+"\x1b[3m(1 row affected)"+SqlcmdEol+"\x1b[0m", buf.buf.String())
151151
}
152+
153+
func TestFormatterXmlMode(t *testing.T) {
154+
s, buf := setupSqlCmdWithMemoryOutput(t)
155+
defer buf.Close()
156+
s.Format.XmlMode(true)
157+
err := runSqlCmd(t, s, []string{"select name from sys.databases where name='master' for xml auto ", "GO"})
158+
assert.NoError(t, err, "runSqlCmd returned error")
159+
assert.Equal(t, `<sys.databases name="master"/>`+SqlcmdEol, buf.buf.String())
160+
}

0 commit comments

Comments
 (0)