Skip to content

Commit e2620e6

Browse files
authored
feat: add DirectExecuteQuery option (#455)
Adds an option to directly execute the query when QueryContext is called. Without this option, the query execution is delayed until the first call to Rows.Next. This again also delays any query errors until the first call to Rows.Next, which can be confusing in some cases.
1 parent 3083c93 commit e2620e6

File tree

3 files changed

+65
-1
lines changed

3 files changed

+65
-1
lines changed

conn.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"context"
1919
"database/sql"
2020
"database/sql/driver"
21+
"errors"
2122
"log/slog"
2223
"slices"
2324
"time"
@@ -797,7 +798,16 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions ExecO
797798
return nil, err
798799
}
799800
}
800-
return &rows{it: iter, decodeOption: execOptions.DecodeOption, decodeToNativeArrays: execOptions.DecodeToNativeArrays}, nil
801+
res := &rows{it: iter, decodeOption: execOptions.DecodeOption, decodeToNativeArrays: execOptions.DecodeToNativeArrays}
802+
if execOptions.DirectExecuteQuery {
803+
// This call to res.getColumns() triggers the execution of the statement, as it needs to fetch the metadata.
804+
res.getColumns()
805+
if res.dirtyErr != nil && !errors.Is(res.dirtyErr, iterator.Done) {
806+
_ = res.Close()
807+
return nil, res.dirtyErr
808+
}
809+
}
810+
return res, nil
801811
}
802812

803813
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {

driver.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,13 @@ type ExecOptions struct {
175175
// AutoCommitDMLMode determines the type of transaction that DML statements
176176
// that are executed outside explicit transactions use.
177177
AutocommitDMLMode AutocommitDMLMode
178+
179+
// DirectExecute determines whether a query is executed directly when the
180+
// [sql.DB.QueryContext] method is called, or whether the actual query execution
181+
// is delayed until the first call to [sql.Rows.Next]. The default is to delay
182+
// the execution. Set this flag to true to execute the query directly when
183+
// [sql.DB.QueryContext] is called.
184+
DirectExecuteQuery bool
178185
}
179186

180187
type DecodeOption int

driver_with_mockserver_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,53 @@ func TestSimpleQuery(t *testing.T) {
235235
}
236236
}
237237

238+
func TestDirectExecuteQuery(t *testing.T) {
239+
t.Parallel()
240+
241+
db, server, teardown := setupTestDBConnection(t)
242+
defer teardown()
243+
244+
// This does not use DirectExecuteQuery. The query is only sent to Spanner when
245+
// rows.Next is called.
246+
rows, err := db.QueryContext(context.Background(), testutil.SelectFooFromBar)
247+
if err != nil {
248+
t.Fatal(err)
249+
}
250+
// There should be no request on the server.
251+
requests := drainRequestsFromServer(server.TestSpanner)
252+
sqlRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{}))
253+
if g, w := len(sqlRequests), 0; g != w {
254+
t.Fatalf("sql requests count mismatch\n Got: %v\nWant: %v", g, w)
255+
}
256+
if !rows.Next() {
257+
t.Fatal("no rows")
258+
}
259+
// The request should now be present on the server.
260+
requests = drainRequestsFromServer(server.TestSpanner)
261+
sqlRequests = requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{}))
262+
if g, w := len(sqlRequests), 1; g != w {
263+
t.Fatalf("sql requests count mismatch\n Got: %v\nWant: %v", g, w)
264+
}
265+
_ = rows.Close()
266+
267+
// Now repeat the same with the DirectExecuteQuery option.
268+
rows, err = db.QueryContext(context.Background(), testutil.SelectFooFromBar, ExecOptions{DirectExecuteQuery: true})
269+
if err != nil {
270+
t.Fatal(err)
271+
}
272+
// The request should be present on the server.
273+
requests = drainRequestsFromServer(server.TestSpanner)
274+
sqlRequests = requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{}))
275+
if g, w := len(sqlRequests), 1; g != w {
276+
t.Fatalf("sql requests count mismatch\n Got: %v\nWant: %v", g, w)
277+
}
278+
// Verify that we can get the row that we selected.
279+
if !rows.Next() {
280+
t.Fatal("no rows")
281+
}
282+
_ = rows.Close()
283+
}
284+
238285
func TestConcurrentScanAndClose(t *testing.T) {
239286
t.Parallel()
240287

0 commit comments

Comments
 (0)