Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion neo4j/dbtype/vector.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (

// VectorElement represents the supported element types for Vector.
type VectorElement interface {
~float64 | ~float32 | ~int8 | ~int16 | ~int32 | ~int64
float64 | float32 | int8 | int16 | int32 | int64
}

// Vector represents a fixed-length array of numeric values.
Expand Down
68 changes: 68 additions & 0 deletions neo4j/internal/bolt/hydratedehydrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,71 @@ func TestDehydrateHydrate(ot *testing.T) {
})
}
}

func TestVectorHandling(ot *testing.T) {
ot.Parallel()

type myVecType = dbtype.Vector[int8]
type myVecType2 dbtype.Vector[float64]
type myVecType3 dbtype.Vector[int32]
type myCustomType int32

testCases := []struct {
name string
data any
}{
{"Vector[int8]", dbtype.Vector[int8]{1, 2, 3}},
{"Vector[int16]", dbtype.Vector[int16]{1, 2, 3}},
{"Vector[int32]", dbtype.Vector[int32]{1, 2, 3}},
{"Vector[int64]", dbtype.Vector[int64]{1, 2, 3}},
{"Vector[float32]", dbtype.Vector[float32]{1.0, 2.0, 3.0}},
{"Vector[float64]", dbtype.Vector[float64]{1.0, 2.0, 3.0}},

{"type alias", myVecType{1, 2, 3}},

{"new type int32", myVecType3{1, 2, 3}},
{"new type float64", myVecType2{1.0, 2.0, 3.0}},

{"*Vector[int8]", &dbtype.Vector[int8]{1, 2, 3}},
{"*Vector[float64]", &dbtype.Vector[float64]{1.0, 2.0, 3.0}},
{"*type alias", &myVecType{1, 2, 3}},

{"*new type", &myVecType2{1.0, 2.0, 3.0}},

{"nil *Vector", (*dbtype.Vector[int8])(nil)},

{"[]int8", []int8{1, 2, 3}},
{"[]float64", []float64{1.0, 2.0, 3.0}},
{"*[]int8", &[]int8{1, 2, 3}},

{"empty Vector", dbtype.Vector[int8]{}},
{"empty *Vector", &dbtype.Vector[int8]{}},

{"custom type", myCustomType(42)},
{"[]custom type", []myCustomType{1, 2, 3}},
}

for _, tc := range testCases {
ot.Run(tc.name, func(t *testing.T) {
t.Parallel()

var packErr error
out := &outgoing{
chunker: newChunker(),
packer: packstream.Packer{},
onPackErr: func(err error) {
packErr = err
},
onIoErr: func(_ context.Context, err error) {
t.Errorf("Unexpected io error: %s", err)
},
}

out.packX(tc.data)

if packErr != nil {
t.Errorf("Unexpected error for %s: %v", tc.name, packErr)
}
})
}
}
36 changes: 35 additions & 1 deletion neo4j/internal/bolt/outgoing.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,10 +386,34 @@ func (o *outgoing) packX(x any) {
return
}
// Inspect what the pointer points to
i := reflect.Indirect(v)
i := v.Elem()
switch i.Kind() {
case reflect.Struct:
o.packStruct(x)
case reflect.Slice:
t := i.Type()
// Pack exact vector types
switch t {
case vecInt8T:
o.packer.VectorInt8(i.Interface().(dbtype.Vector[int8]))
return
case vecInt16T:
o.packer.VectorInt16(i.Interface().(dbtype.Vector[int16]))
return
case vecInt32T:
o.packer.VectorInt32(i.Interface().(dbtype.Vector[int32]))
return
case vecInt64T:
o.packer.VectorInt64(i.Interface().(dbtype.Vector[int64]))
return
case vecFloat32T:
o.packer.VectorFloat32(i.Interface().(dbtype.Vector[float32]))
return
case vecFloat64T:
o.packer.VectorFloat64(i.Interface().(dbtype.Vector[float64]))
return
}
o.packV(i)
default:
o.packV(i)
}
Expand Down Expand Up @@ -478,6 +502,16 @@ func typeForPrimitive[T any]() reflect.Type {
return reflect.TypeOf(v)
}

// Supported vector types
var (
vecInt8T = reflect.TypeOf(dbtype.Vector[int8]{})
vecInt16T = reflect.TypeOf(dbtype.Vector[int16]{})
vecInt32T = reflect.TypeOf(dbtype.Vector[int32]{})
vecInt64T = reflect.TypeOf(dbtype.Vector[int64]{})
vecFloat32T = reflect.TypeOf(dbtype.Vector[float32]{})
vecFloat64T = reflect.TypeOf(dbtype.Vector[float64]{})
)

var intT = typeForPrimitive[int]()
var int64T = typeForPrimitive[int64]()
var stringT = typeForPrimitive[string]()
Expand Down