Skip to content

Commit 5cb6fc6

Browse files
tburginolavloite
andauthored
types: support type aliases (#316)
* types: support type aliases * chore: remove unused code --------- Co-authored-by: Knut Olav Løite <koloite@gmail.com>
1 parent 0d4bfee commit 5cb6fc6

File tree

3 files changed

+16
-26
lines changed

3 files changed

+16
-26
lines changed

driver.go

Lines changed: 6 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,11 +1038,12 @@ func (c *conn) CheckNamedValue(value *driver.NamedValue) error {
10381038
if checkIsValidType(value.Value) {
10391039
return nil
10401040
}
1041-
if valuer, ok := value.Value.(driver.Valuer); ok {
1042-
v, err := callValuerValue(valuer)
1043-
if err != nil {
1044-
return err
1045-
}
1041+
1042+
// Convert the value using the default sql driver. This uses driver.Valuer,
1043+
// if implemented, and falls back to reflection. If the converted value is
1044+
// a supported spanner type, use it. Otherwise, ignore any errors and
1045+
// continue checking other supported spanner specific types.
1046+
if v, err := driver.DefaultParameterConverter.ConvertValue(value.Value); err == nil {
10461047
if checkIsValidType(v) {
10471048
value.Value = v
10481049
return nil
@@ -1252,27 +1253,6 @@ func (c *conn) createPartitionedDmlQueryOptions() spanner.QueryOptions {
12521253
return spanner.QueryOptions{ExcludeTxnFromChangeStreams: c.excludeTxnFromChangeStreams}
12531254
}
12541255

1255-
// callValuerValue is from Go's database/sql package to handle a special case,
1256-
// in the same way that database/sql does, for nil pointers to types that implement
1257-
// driver.Valuer with value receivers.
1258-
1259-
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
1260-
1261-
// callValuerValue returns vr.Value(), with one exception:
1262-
// If vr.Value is a value receiver method on a pointer type and the pointer is nil,
1263-
// it would panic at runtime. This treats it like nil instead.
1264-
//
1265-
// This is so people can implement driver.Value on value types and still use nil pointers
1266-
// to those types to mean nil/NULL, just like the Go database/sql package.
1267-
func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
1268-
if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr &&
1269-
rv.IsNil() &&
1270-
rv.Type().Elem().Implements(valuerReflectType) {
1271-
return nil, nil
1272-
}
1273-
return vr.Value()
1274-
}
1275-
12761256
/* The following is the same implementation as in google-cloud-go/spanner */
12771257

12781258
func isStructOrArrayOfStructValue(v interface{}) bool {

driver_test.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package spannerdriver
1717
import (
1818
"context"
1919
"database/sql/driver"
20+
"net"
2021
"testing"
2122
"time"
2223

@@ -494,6 +495,11 @@ func TestConn_CheckNamedValue(t *testing.T) {
494495
{in: &Person{Name: "hello", Age: 123}, want: &Person{Name: "hello", Age: 123}},
495496
// nil pointer of type that implements driver.Valuer via value receiver should use nil
496497
{in: testNil, want: nil},
498+
// net.IP reflects to []byte. Allow model structs to have fields with
499+
// types that reflect to types supported by spanner.
500+
{in: net.IPv6loopback, want: []byte(net.IPv6loopback)},
501+
// Similarly, time.Duration is just an int64.
502+
{in: time.Duration(1), want: int64(time.Duration(1))},
497503
}
498504

499505
for _, test := range tests {

integration_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"math"
2525
"math/big"
2626
"math/rand"
27+
"net"
2728
"os"
2829
"reflect"
2930
"strconv"
@@ -505,6 +506,9 @@ func TestTypeRoundtrip(t *testing.T) {
505506
// JSON variants
506507
{in: spanner.NullJSON{Valid: true, Value: map[string]any{"a": 13}}, skipeq: true},
507508
{in: []spanner.NullJSON{{Valid: true, Value: map[string]any{"a": 13}}}, skipeq: true},
509+
// Standard library type alias examples
510+
{in: net.IPv6loopback},
511+
{in: time.Duration(1)},
508512
}
509513

510514
for _, test := range tests {

0 commit comments

Comments
 (0)