Skip to content

Commit 5cc9ad0

Browse files
committed
feat(statement-result): populate query field on success
1 parent 675c70d commit 5cc9ad0

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

internal/db/db.go

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ type StatementResult struct {
4545
Query string
4646
}
4747

48-
func newStatementResult(columnNames []string, rowCh chan rowResult) *StatementResult {
49-
return &StatementResult{ColumnNames: columnNames, RowCh: rowCh}
48+
func newStatementResult(columnNames []string, rowCh chan rowResult, query string) *StatementResult {
49+
return &StatementResult{ColumnNames: columnNames, RowCh: rowCh, Query: query}
5050
}
5151

5252
func newStatementResultWithError(err error) *StatementResult {
@@ -173,7 +173,7 @@ func (db *Db) executeQuery(query string, statementResultCh chan StatementResult)
173173

174174
defer rows.Close()
175175

176-
return readQueryResults(rows, statementResultCh)
176+
return readQueryResults(rows, statementResultCh, query)
177177
}
178178

179179
func (db *Db) prepareStatementsIntoQueries(statementsString string) []string {
@@ -221,14 +221,17 @@ func getColumnTypes(rows *sql.Rows) ([]reflect.Type, error) {
221221
return types, nil
222222
}
223223

224-
func readQueryResults(queryRows *sql.Rows, statementResultCh chan StatementResult) (shouldContinue bool) {
224+
func readQueryResults(queryRows *sql.Rows, statementResultCh chan StatementResult, query string) (shouldContinue bool) {
225+
queries, _ := sqliteparserutils.SplitStatement(query)
225226
hasResultSetToRead := true
226227
for hasResultSetToRead {
227-
if shouldContinue := readQueryResultSet(queryRows, statementResultCh); !shouldContinue {
228-
return false
229-
}
228+
for _, query := range queries {
229+
if shouldContinue := readQueryResultSet(queryRows, statementResultCh, query); !shouldContinue {
230+
return false
231+
}
230232

231-
hasResultSetToRead = queryRows.NextResultSet()
233+
hasResultSetToRead = queryRows.NextResultSet()
234+
}
232235
}
233236

234237
if err := queryRows.Err(); err != nil {
@@ -239,7 +242,7 @@ func readQueryResults(queryRows *sql.Rows, statementResultCh chan StatementResul
239242
return true
240243
}
241244

242-
func readQueryResultSet(queryRows *sql.Rows, statementResultCh chan StatementResult) (shouldContinue bool) {
245+
func readQueryResultSet(queryRows *sql.Rows, statementResultCh chan StatementResult, query string) (shouldContinue bool) {
243246
columnNames, err := getColumnNames(queryRows)
244247
if err != nil {
245248
statementResultCh <- *newStatementResultWithError(err)
@@ -265,7 +268,7 @@ func readQueryResultSet(queryRows *sql.Rows, statementResultCh chan StatementRes
265268
rowCh := make(chan rowResult)
266269
defer close(rowCh)
267270

268-
statementResultCh <- *newStatementResult(columnNames, rowCh)
271+
statementResultCh <- *newStatementResult(columnNames, rowCh, query)
269272

270273
for queryRows.Next() {
271274
err = queryRows.Scan(columnPointers...)

0 commit comments

Comments
 (0)