Skip to content

Commit 656b385

Browse files
Luis Silvaluisfvieirasilva
authored andcommitted
refactor(shell): create internal/shell with internal functions from pkg/shell
The idea is that our tests could use internal/shell package to avoid open a new connection to the database everytime
1 parent a66498b commit 656b385

File tree

7 files changed

+250
-214
lines changed

7 files changed

+250
-214
lines changed

internal/cmd/root.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"github.com/spf13/cobra"
99

1010
"github.com/libsql/libsql-shell-go/pkg/shell"
11+
"github.com/libsql/libsql-shell-go/pkg/shell/enums"
1112
)
1213

1314
type RootArgs struct {
@@ -28,7 +29,7 @@ func NewRootCmd() *cobra.Command {
2829
InF: cmd.InOrStdin(),
2930
OutF: cmd.OutOrStdout(),
3031
ErrF: cmd.ErrOrStderr(),
31-
HistoryMode: shell.PerDatabaseHistory,
32+
HistoryMode: enums.PerDatabaseHistory,
3233
HistoryName: "libsql",
3334
QuietMode: rootArgs.quiet,
3435
}

pkg/shell/history.go renamed to internal/shell/history.go

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,17 @@ import (
77
"path/filepath"
88

99
"github.com/libsql/libsql-shell-go/internal/db"
10+
"github.com/libsql/libsql-shell-go/pkg/shell/enums"
1011
"github.com/libsql/libsql-shell-go/pkg/shell/shellerrors"
1112
)
1213

13-
type HistoryMode int64
14-
15-
const (
16-
SingleHistory HistoryMode = 0
17-
PerDatabaseHistory HistoryMode = 1
18-
LocalHistory HistoryMode = 2
19-
)
20-
21-
func GetHistoryFileBasedOnMode(dbPath string, mode HistoryMode, historyName string) string {
14+
func GetHistoryFileBasedOnMode(dbPath string, mode enums.HistoryMode, historyName string) string {
2215
sharedHistoryFileName := getHistoryFileName(historyName)
2316

2417
switch mode {
25-
case LocalHistory:
18+
case enums.LocalHistory:
2619
return sharedHistoryFileName
27-
case PerDatabaseHistory:
20+
case enums.PerDatabaseHistory:
2821
if parsedName, err := parseNameFromDbPath(dbPath); err == nil && parsedName != "" {
2922
return getHistoryFileFullPath(historyName, getHistoryFileName(parsedName))
3023
}

pkg/shell/history_test.go renamed to internal/shell/history_test.go

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ import (
66
"testing"
77

88
qt "github.com/frankban/quicktest"
9-
"github.com/libsql/libsql-shell-go/pkg/shell"
9+
"github.com/libsql/libsql-shell-go/internal/shell"
10+
"github.com/libsql/libsql-shell-go/pkg/shell/enums"
1011
)
1112

1213
const historyName = "libsql"
@@ -22,7 +23,7 @@ func TestGetHistoryFileBasedOnMode_GivenLocalHistory_WhenPathIsEmpty_ExpectShare
2223

2324
dbPath := ""
2425
expectedPath := sharedHistoryFileName
25-
result := shell.GetHistoryFileBasedOnMode(dbPath, shell.LocalHistory, historyName)
26+
result := shell.GetHistoryFileBasedOnMode(dbPath, enums.LocalHistory, historyName)
2627

2728
c.Assert(result, qt.Equals, expectedPath)
2829
}
@@ -32,7 +33,7 @@ func TestGetHistoryFileBasedOnMode_GivenLocalHistory_WhenPathIsValid_ExpectShare
3233

3334
dbPath := "/path/to/my/db.sqlite"
3435
expectedPath := sharedHistoryFileName
35-
result := shell.GetHistoryFileBasedOnMode(dbPath, shell.LocalHistory, historyName)
36+
result := shell.GetHistoryFileBasedOnMode(dbPath, enums.LocalHistory, historyName)
3637

3738
c.Assert(result, qt.Equals, expectedPath)
3839
}
@@ -42,7 +43,7 @@ func TestGetHistoryFileBasedOnMode_GivenSingleHistory_WhenPathIsValid_ExpectShar
4243

4344
dbPath := "/path/to/my/db.sqlite"
4445
expectedPath := getExpectedHistoryFullPath(historyName)
45-
result := shell.GetHistoryFileBasedOnMode(dbPath, shell.SingleHistory, historyName)
46+
result := shell.GetHistoryFileBasedOnMode(dbPath, enums.SingleHistory, historyName)
4647

4748
c.Assert(result, qt.Equals, expectedPath)
4849
}
@@ -52,7 +53,7 @@ func TestGetHistoryFileBasedOnMode_GivenSingleHistory_WhenPathIsEmpty_ExpectShar
5253

5354
dbPath := ""
5455
expectedPath := getExpectedHistoryFullPath(historyName)
55-
result := shell.GetHistoryFileBasedOnMode(dbPath, shell.SingleHistory, historyName)
56+
result := shell.GetHistoryFileBasedOnMode(dbPath, enums.SingleHistory, historyName)
5657

5758
c.Assert(result, qt.Equals, expectedPath)
5859
}
@@ -62,7 +63,7 @@ func TestGetHistoryFileBasedOnMode_GivenPerDatabaseHistory_WhenPathIsValid_Expec
6263

6364
dbPath := "/path/to/my/db.sqlite"
6465
expectedPath := getExpectedHistoryFullPath("db")
65-
result := shell.GetHistoryFileBasedOnMode(dbPath, shell.PerDatabaseHistory, historyName)
66+
result := shell.GetHistoryFileBasedOnMode(dbPath, enums.PerDatabaseHistory, historyName)
6667

6768
c.Assert(result, qt.Equals, expectedPath)
6869
}
@@ -72,7 +73,7 @@ func TestGetHistoryFileBasedOnMode_GivenPerDatabaseHistory_WhenPathIsEmpty_Expec
7273

7374
dbPath := ""
7475
expectedPath := getExpectedHistoryFullPath(historyName)
75-
result := shell.GetHistoryFileBasedOnMode(dbPath, shell.PerDatabaseHistory, historyName)
76+
result := shell.GetHistoryFileBasedOnMode(dbPath, enums.PerDatabaseHistory, historyName)
7677

7778
c.Assert(result, qt.Equals, expectedPath)
7879
}
@@ -82,7 +83,7 @@ func TestGetHistoryFileBasedOnMode_GivenPerDatabaseHistory_WhenPathIsHttpUrl_Exp
8283

8384
dbPath := "https://username:password@company.turso.io"
8485
expectedPath := getExpectedHistoryFullPath("username:password")
85-
result := shell.GetHistoryFileBasedOnMode(dbPath, shell.PerDatabaseHistory, historyName)
86+
result := shell.GetHistoryFileBasedOnMode(dbPath, enums.PerDatabaseHistory, historyName)
8687

8788
c.Assert(result, qt.Equals, expectedPath)
8889
}
@@ -92,7 +93,7 @@ func TestGetHistoryFileBasedOnMode_GivenPerDatabaseHistory_WhenPathIsHttpUrlWith
9293

9394
dbPath := "https://company.turso.io"
9495
expectedPath := getExpectedHistoryFullPath(historyName)
95-
result := shell.GetHistoryFileBasedOnMode(dbPath, shell.PerDatabaseHistory, historyName)
96+
result := shell.GetHistoryFileBasedOnMode(dbPath, enums.PerDatabaseHistory, historyName)
9697

9798
c.Assert(result, qt.Equals, expectedPath)
9899
}

internal/shell/shell.go

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
package shell
2+
3+
import (
4+
"fmt"
5+
"io"
6+
"regexp"
7+
"strings"
8+
9+
"github.com/chzyer/readline"
10+
"github.com/fatih/color"
11+
"github.com/libsql/libsql-shell-go/internal/db"
12+
"github.com/libsql/libsql-shell-go/internal/shellcmd"
13+
"github.com/libsql/libsql-shell-go/pkg/shell/enums"
14+
"github.com/spf13/cobra"
15+
)
16+
17+
const QUIT_COMMAND = ".quit"
18+
const DEFAULT_WELCOME_MESSAGE = "Welcome to LibSQL shell!\n\nType \".quit\" to exit the shell, and \".help\" to show all commands\n\n"
19+
20+
const promptNewStatement = "→ "
21+
const promptContinueStatement = "... "
22+
23+
type ShellConfig struct {
24+
InF io.Reader
25+
OutF io.Writer
26+
ErrF io.Writer
27+
HistoryMode enums.HistoryMode
28+
HistoryName string
29+
QuietMode bool
30+
WelcomeMessage *string
31+
}
32+
33+
type Shell struct {
34+
config ShellConfig
35+
36+
db *db.Db
37+
promptFmt func(p ...interface{}) string
38+
39+
state shellState
40+
41+
databaseCmd *cobra.Command
42+
}
43+
44+
type shellState struct {
45+
readline *readline.Instance
46+
statementParts []string
47+
insideMultilineStatement bool
48+
interruptReadEvalPrintLoop bool
49+
printMode enums.PrintMode
50+
}
51+
52+
func NewShell(config ShellConfig, db *db.Db) (*Shell, error) {
53+
promptFmt := color.New(color.FgBlue, color.Bold).SprintFunc()
54+
55+
newShell := Shell{config: config, db: db, promptFmt: promptFmt}
56+
57+
dbCmdConfig := &shellcmd.DbCmdConfig{
58+
Db: db,
59+
OutF: config.OutF,
60+
ErrF: config.ErrF,
61+
SetInterruptShell: func() { newShell.state.interruptReadEvalPrintLoop = true },
62+
SetMode: func(mode enums.PrintMode) { newShell.state.printMode = mode },
63+
GetMode: func() enums.PrintMode {
64+
return newShell.state.printMode
65+
},
66+
}
67+
newShell.databaseCmd = shellcmd.CreateNewDatabaseRootCmd(dbCmdConfig)
68+
69+
err := newShell.resetState()
70+
if err != nil {
71+
return nil, err
72+
}
73+
74+
return &newShell, nil
75+
}
76+
77+
func (sh *Shell) Run() error {
78+
err := sh.resetState()
79+
if err != nil {
80+
return err
81+
}
82+
defer sh.state.readline.Close()
83+
84+
if !sh.config.QuietMode {
85+
fmt.Print(sh.getWelcomeMessage())
86+
}
87+
88+
for !sh.state.interruptReadEvalPrintLoop {
89+
line, err := sh.state.readline.Readline()
90+
91+
if err == readline.ErrInterrupt {
92+
if len(line) == 0 {
93+
return nil
94+
} else {
95+
continue
96+
}
97+
} else if err == io.EOF {
98+
break
99+
}
100+
101+
line = strings.TrimSpace(line)
102+
103+
switch {
104+
case len(line) == 0:
105+
continue
106+
case sh.state.insideMultilineStatement:
107+
sh.appendStatementPartAndExecuteIfFinished(line)
108+
case isCommand(line):
109+
err = sh.executeCommand(line)
110+
if err != nil {
111+
db.PrintError(err, sh.config.ErrF)
112+
}
113+
default:
114+
sh.appendStatementPartAndExecuteIfFinished(line)
115+
}
116+
117+
}
118+
return nil
119+
}
120+
121+
func (sh *Shell) resetState() error {
122+
var err error
123+
sh.state.readline, err = sh.newReadline()
124+
if err != nil {
125+
return err
126+
}
127+
sh.state.readline.CaptureExitSignal()
128+
129+
sh.state.insideMultilineStatement = false
130+
sh.state.statementParts = make([]string, 0)
131+
132+
sh.state.interruptReadEvalPrintLoop = false
133+
134+
sh.state.printMode = enums.TABLE_MODE
135+
136+
return nil
137+
}
138+
139+
func (sh *Shell) newReadline() (*readline.Instance, error) {
140+
historyFile := GetHistoryFileBasedOnMode(sh.db.Path, sh.config.HistoryMode, sh.config.HistoryName)
141+
142+
return readline.NewEx(&readline.Config{
143+
Prompt: sh.promptFmt(promptNewStatement),
144+
InterruptPrompt: "^C",
145+
HistoryFile: historyFile,
146+
EOFPrompt: QUIT_COMMAND,
147+
Stdin: io.NopCloser(sh.config.InF),
148+
Stdout: sh.config.OutF,
149+
Stderr: sh.config.ErrF,
150+
})
151+
}
152+
153+
func isCommand(line string) bool {
154+
return line[0] == '.'
155+
}
156+
157+
func (sh *Shell) executeCommand(command string) error {
158+
parts := strings.Fields(command)
159+
sh.databaseCmd.SetArgs(parts)
160+
161+
err := sh.databaseCmd.Execute()
162+
163+
if err != nil && strings.HasPrefix(err.Error(), "unknown command") {
164+
rx := regexp.MustCompile(`"[^"]*"`)
165+
command := rx.FindString(fmt.Sprint(err))
166+
return fmt.Errorf(`unknown command or invalid arguments: %s. Enter ".help" for help`, command)
167+
}
168+
return err
169+
}
170+
171+
func (sh *Shell) appendStatementPartAndExecuteIfFinished(statementPart string) {
172+
sh.state.statementParts = append(sh.state.statementParts, statementPart)
173+
if strings.HasSuffix(statementPart, ";") {
174+
completeStatement := strings.Join(sh.state.statementParts, "\n")
175+
sh.state.statementParts = make([]string, 0)
176+
sh.state.insideMultilineStatement = false
177+
sh.state.readline.SetPrompt(sh.promptFmt(promptNewStatement))
178+
err := sh.db.ExecuteAndPrintStatements(completeStatement, sh.config.OutF, false, sh.state.printMode)
179+
if err != nil {
180+
db.PrintError(err, sh.state.readline.Stderr())
181+
}
182+
} else {
183+
sh.state.readline.SetPrompt(sh.promptFmt(promptContinueStatement))
184+
sh.state.insideMultilineStatement = true
185+
}
186+
}
187+
188+
func (sh *Shell) ExecuteCommandOrStatements(commandOrStatements string) error {
189+
if isCommand(commandOrStatements) {
190+
return sh.executeCommand(commandOrStatements)
191+
}
192+
193+
return sh.db.ExecuteAndPrintStatements(commandOrStatements, sh.config.OutF, false, sh.state.printMode)
194+
}
195+
196+
func (sh *Shell) getWelcomeMessage() string {
197+
if sh.config.WelcomeMessage == nil {
198+
return DEFAULT_WELCOME_MESSAGE
199+
}
200+
return *sh.config.WelcomeMessage
201+
}

pkg/shell/enums/enums.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package enums
2+
3+
type PrintMode string
4+
5+
const (
6+
TABLE_MODE PrintMode = "table"
7+
CSV_MODE PrintMode = "csv"
8+
)
9+
10+
type HistoryMode int
11+
12+
const (
13+
SingleHistory HistoryMode = 0
14+
PerDatabaseHistory HistoryMode = 1
15+
LocalHistory HistoryMode = 2
16+
)

pkg/shell/enums/printMode.go

Lines changed: 0 additions & 8 deletions
This file was deleted.

0 commit comments

Comments
 (0)