From 792d8a1484e9dfbbd4b0b57d84d964d735fea7b1 Mon Sep 17 00:00:00 2001 From: Stephen Cathcart Date: Wed, 29 Oct 2025 16:23:37 +0000 Subject: [PATCH 1/3] Fix and clarify vector type acceptance rules Valid vectors: - dbtype.Vector[T] where T is int8/16/32/64, float32/64, or their aliases - Type aliases: type MyVec = dbtype.Vector[int8] - Element aliases: type MyInt = int8; dbtype.Vector[MyInt] - Pointers: *dbtype.Vector[T] Invalid vectors: - New types wrapping vectors: type MyVec dbtype.Vector[int8] - New element types: type MyInt int8; dbtype.Vector[MyInt] Check regular slices serialize as lists, not vectors --- neo4j/dbtype/vector.go | 2 +- neo4j/internal/bolt/hydratedehydrate_test.go | 89 ++++++++++++++++++++ neo4j/internal/bolt/outgoing.go | 77 ++++++++++++++++- 3 files changed, 166 insertions(+), 2 deletions(-) diff --git a/neo4j/dbtype/vector.go b/neo4j/dbtype/vector.go index c5ca4e3e..591e0ce6 100644 --- a/neo4j/dbtype/vector.go +++ b/neo4j/dbtype/vector.go @@ -26,7 +26,7 @@ import ( // VectorElement represents the supported element types for Vector. type VectorElement interface { - ~float64 | ~float32 | ~int8 | ~int16 | ~int32 | ~int64 + float64 | float32 | int8 | int16 | int32 | int64 } // Vector represents a fixed-length array of numeric values. diff --git a/neo4j/internal/bolt/hydratedehydrate_test.go b/neo4j/internal/bolt/hydratedehydrate_test.go index df872c3a..b6cc849d 100644 --- a/neo4j/internal/bolt/hydratedehydrate_test.go +++ b/neo4j/internal/bolt/hydratedehydrate_test.go @@ -20,6 +20,7 @@ package bolt import ( "context" "net" + "reflect" "testing" "time" @@ -167,3 +168,91 @@ func TestDehydrateHydrate(ot *testing.T) { }) } } + +func TestVectorHandling(ot *testing.T) { + ot.Parallel() + + type myVecType = dbtype.Vector[int8] + type myVecType2 dbtype.Vector[float64] + type myVecType3 dbtype.Vector[int32] + type myCustomType int32 + + testCases := []struct { + name string + data any + shouldError bool + usePackV bool + }{ + {"Vector[int8]", dbtype.Vector[int8]{1, 2, 3}, false, false}, + {"Vector[int16]", dbtype.Vector[int16]{1, 2, 3}, false, false}, + {"Vector[int32]", dbtype.Vector[int32]{1, 2, 3}, false, false}, + {"Vector[int64]", dbtype.Vector[int64]{1, 2, 3}, false, false}, + {"Vector[float32]", dbtype.Vector[float32]{1.0, 2.0, 3.0}, false, false}, + {"Vector[float64]", dbtype.Vector[float64]{1.0, 2.0, 3.0}, false, false}, + + {"type alias", myVecType{1, 2, 3}, false, false}, + + {"new type int32", myVecType3{1, 2, 3}, true, false}, + {"new type float64", myVecType2{1.0, 2.0, 3.0}, true, false}, + + {"*Vector[int8]", &dbtype.Vector[int8]{1, 2, 3}, false, false}, + {"*Vector[float64]", &dbtype.Vector[float64]{1.0, 2.0, 3.0}, false, false}, + {"*type alias", &myVecType{1, 2, 3}, false, false}, + + {"*new type", &myVecType2{1.0, 2.0, 3.0}, true, false}, + + {"nil *Vector", (*dbtype.Vector[int8])(nil), false, false}, + + {"[]int8", []int8{1, 2, 3}, false, false}, + {"[]float64", []float64{1.0, 2.0, 3.0}, false, false}, + {"*[]int8", &[]int8{1, 2, 3}, false, false}, + + {"empty Vector", dbtype.Vector[int8]{}, false, false}, + {"empty *Vector", &dbtype.Vector[int8]{}, false, false}, + + {"[]int8 via packV", []int8{1, 2, 3}, false, true}, + {"[]float64 via packV", []float64{1.0, 2.0, 3.0}, false, true}, + {"[]string via packV", []string{"a", "b", "c"}, false, true}, + {"[]any via packV", []any{1, "hello", 3.14}, false, true}, + {"[]byte via packV", []byte{1, 2, 3}, false, true}, + + {"custom type", myCustomType(42), false, false}, + {"[]custom type", []myCustomType{1, 2, 3}, false, false}, + } + + for _, tc := range testCases { + ot.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var packErr error + out := &outgoing{ + chunker: newChunker(), + packer: packstream.Packer{}, + onPackErr: func(err error) { + packErr = err + }, + onIoErr: func(_ context.Context, err error) { + t.Errorf("Unexpected io error: %s", err) + }, + } + + if tc.usePackV { + out.packV(reflect.ValueOf(tc.data)) + } else { + out.packX(tc.data) + } + + if tc.shouldError { + if packErr == nil { + t.Errorf("Expected error for %s, but got none", tc.name) + } else if _, ok := packErr.(*db.UnsupportedTypeError); !ok { + t.Errorf("Expected UnsupportedTypeError for %s, but got: %T", tc.name, packErr) + } + } else { + if packErr != nil { + t.Errorf("Unexpected error for %s: %v", tc.name, packErr) + } + } + }) + } +} diff --git a/neo4j/internal/bolt/outgoing.go b/neo4j/internal/bolt/outgoing.go index 003a9fba..a7959fd7 100644 --- a/neo4j/internal/bolt/outgoing.go +++ b/neo4j/internal/bolt/outgoing.go @@ -386,12 +386,47 @@ func (o *outgoing) packX(x any) { return } // Inspect what the pointer points to - i := reflect.Indirect(v) + i := v.Elem() + switch i.Kind() { case reflect.Struct: o.packStruct(x) + return + + case reflect.Slice: + t := i.Type() + + // Pack exact vector types + if isExactVector(t) { + switch t { + case vecInt8T: + o.packer.VectorInt8(i.Interface().(dbtype.Vector[int8])) + case vecInt16T: + o.packer.VectorInt16(i.Interface().(dbtype.Vector[int16])) + case vecInt32T: + o.packer.VectorInt32(i.Interface().(dbtype.Vector[int32])) + case vecInt64T: + o.packer.VectorInt64(i.Interface().(dbtype.Vector[int64])) + case vecFloat32T: + o.packer.VectorFloat32(i.Interface().(dbtype.Vector[float32])) + case vecFloat64T: + o.packer.VectorFloat64(i.Interface().(dbtype.Vector[float64])) + } + return + } + + // Reject user-defined vector types + if convertibleToAnyVector(t) && !isExactVector(t) { + o.onPackErr(&db.UnsupportedTypeError{Type: t}) + return + } + + o.packV(i) + return + default: o.packV(i) + return } case reflect.Struct: o.packStruct(x) @@ -426,6 +461,12 @@ func (o *outgoing) packX(x any) { o.packX(e) } default: + // Reject user-defined vector types + if convertibleToAnyVector(v.Type()) && !isExactVector(v.Type()) { + o.onPackErr(&db.UnsupportedTypeError{Type: v.Type()}) + return + } + num := v.Len() o.packer.ArrayHeader(num) for i := 0; i < num; i++ { @@ -478,6 +519,40 @@ func typeForPrimitive[T any]() reflect.Type { return reflect.TypeOf(v) } +// Supported vector types +var ( + vecInt8T = reflect.TypeOf(dbtype.Vector[int8]{}) + vecInt16T = reflect.TypeOf(dbtype.Vector[int16]{}) + vecInt32T = reflect.TypeOf(dbtype.Vector[int32]{}) + vecInt64T = reflect.TypeOf(dbtype.Vector[int64]{}) + vecFloat32T = reflect.TypeOf(dbtype.Vector[float32]{}) + vecFloat64T = reflect.TypeOf(dbtype.Vector[float64]{}) +) + +// isExactVector checks if t is exactly a dbtype.Vector[T] +func isExactVector(t reflect.Type) bool { + switch t { + case vecInt8T, vecInt16T, vecInt32T, vecInt64T, vecFloat32T, vecFloat64T: + return true + } + return false +} + +// convertibleToAnyVector checks if t is a user-defined type convertible to a vector +func convertibleToAnyVector(t reflect.Type) bool { + // Only user-defined types are considered convertible + if t.PkgPath() == "" { + return false + } + + return t.ConvertibleTo(vecInt8T) || + t.ConvertibleTo(vecInt16T) || + t.ConvertibleTo(vecInt32T) || + t.ConvertibleTo(vecInt64T) || + t.ConvertibleTo(vecFloat32T) || + t.ConvertibleTo(vecFloat64T) +} + var intT = typeForPrimitive[int]() var int64T = typeForPrimitive[int64]() var stringT = typeForPrimitive[string]() From 526b02fea91e6a52c028cd7b225e0d1789aca1cf Mon Sep 17 00:00:00 2001 From: Stephen Cathcart Date: Wed, 29 Oct 2025 21:51:56 +0000 Subject: [PATCH 2/3] Removed redundant checks for user-defined vector types. --- neo4j/internal/bolt/outgoing.go | 75 ++++++++------------------------- 1 file changed, 17 insertions(+), 58 deletions(-) diff --git a/neo4j/internal/bolt/outgoing.go b/neo4j/internal/bolt/outgoing.go index a7959fd7..4588de22 100644 --- a/neo4j/internal/bolt/outgoing.go +++ b/neo4j/internal/bolt/outgoing.go @@ -387,46 +387,35 @@ func (o *outgoing) packX(x any) { } // Inspect what the pointer points to i := v.Elem() - switch i.Kind() { case reflect.Struct: o.packStruct(x) - return - case reflect.Slice: t := i.Type() - // Pack exact vector types - if isExactVector(t) { - switch t { - case vecInt8T: - o.packer.VectorInt8(i.Interface().(dbtype.Vector[int8])) - case vecInt16T: - o.packer.VectorInt16(i.Interface().(dbtype.Vector[int16])) - case vecInt32T: - o.packer.VectorInt32(i.Interface().(dbtype.Vector[int32])) - case vecInt64T: - o.packer.VectorInt64(i.Interface().(dbtype.Vector[int64])) - case vecFloat32T: - o.packer.VectorFloat32(i.Interface().(dbtype.Vector[float32])) - case vecFloat64T: - o.packer.VectorFloat64(i.Interface().(dbtype.Vector[float64])) - } + switch t { + case vecInt8T: + o.packer.VectorInt8(i.Interface().(dbtype.Vector[int8])) return - } - - // Reject user-defined vector types - if convertibleToAnyVector(t) && !isExactVector(t) { - o.onPackErr(&db.UnsupportedTypeError{Type: t}) + case vecInt16T: + o.packer.VectorInt16(i.Interface().(dbtype.Vector[int16])) + return + case vecInt32T: + o.packer.VectorInt32(i.Interface().(dbtype.Vector[int32])) + return + case vecInt64T: + o.packer.VectorInt64(i.Interface().(dbtype.Vector[int64])) + return + case vecFloat32T: + o.packer.VectorFloat32(i.Interface().(dbtype.Vector[float32])) + return + case vecFloat64T: + o.packer.VectorFloat64(i.Interface().(dbtype.Vector[float64])) return } - o.packV(i) - return - default: o.packV(i) - return } case reflect.Struct: o.packStruct(x) @@ -461,12 +450,6 @@ func (o *outgoing) packX(x any) { o.packX(e) } default: - // Reject user-defined vector types - if convertibleToAnyVector(v.Type()) && !isExactVector(v.Type()) { - o.onPackErr(&db.UnsupportedTypeError{Type: v.Type()}) - return - } - num := v.Len() o.packer.ArrayHeader(num) for i := 0; i < num; i++ { @@ -529,30 +512,6 @@ var ( vecFloat64T = reflect.TypeOf(dbtype.Vector[float64]{}) ) -// isExactVector checks if t is exactly a dbtype.Vector[T] -func isExactVector(t reflect.Type) bool { - switch t { - case vecInt8T, vecInt16T, vecInt32T, vecInt64T, vecFloat32T, vecFloat64T: - return true - } - return false -} - -// convertibleToAnyVector checks if t is a user-defined type convertible to a vector -func convertibleToAnyVector(t reflect.Type) bool { - // Only user-defined types are considered convertible - if t.PkgPath() == "" { - return false - } - - return t.ConvertibleTo(vecInt8T) || - t.ConvertibleTo(vecInt16T) || - t.ConvertibleTo(vecInt32T) || - t.ConvertibleTo(vecInt64T) || - t.ConvertibleTo(vecFloat32T) || - t.ConvertibleTo(vecFloat64T) -} - var intT = typeForPrimitive[int]() var int64T = typeForPrimitive[int64]() var stringT = typeForPrimitive[string]() From a1c78286cf1b5432eba9b1975d974254b2d8f091 Mon Sep 17 00:00:00 2001 From: Stephen Cathcart Date: Wed, 29 Oct 2025 22:13:09 +0000 Subject: [PATCH 3/3] Simplified test cases --- neo4j/internal/bolt/hydratedehydrate_test.go | 73 +++++++------------- 1 file changed, 26 insertions(+), 47 deletions(-) diff --git a/neo4j/internal/bolt/hydratedehydrate_test.go b/neo4j/internal/bolt/hydratedehydrate_test.go index b6cc849d..b695079e 100644 --- a/neo4j/internal/bolt/hydratedehydrate_test.go +++ b/neo4j/internal/bolt/hydratedehydrate_test.go @@ -20,7 +20,6 @@ package bolt import ( "context" "net" - "reflect" "testing" "time" @@ -178,46 +177,38 @@ func TestVectorHandling(ot *testing.T) { type myCustomType int32 testCases := []struct { - name string - data any - shouldError bool - usePackV bool + name string + data any }{ - {"Vector[int8]", dbtype.Vector[int8]{1, 2, 3}, false, false}, - {"Vector[int16]", dbtype.Vector[int16]{1, 2, 3}, false, false}, - {"Vector[int32]", dbtype.Vector[int32]{1, 2, 3}, false, false}, - {"Vector[int64]", dbtype.Vector[int64]{1, 2, 3}, false, false}, - {"Vector[float32]", dbtype.Vector[float32]{1.0, 2.0, 3.0}, false, false}, - {"Vector[float64]", dbtype.Vector[float64]{1.0, 2.0, 3.0}, false, false}, + {"Vector[int8]", dbtype.Vector[int8]{1, 2, 3}}, + {"Vector[int16]", dbtype.Vector[int16]{1, 2, 3}}, + {"Vector[int32]", dbtype.Vector[int32]{1, 2, 3}}, + {"Vector[int64]", dbtype.Vector[int64]{1, 2, 3}}, + {"Vector[float32]", dbtype.Vector[float32]{1.0, 2.0, 3.0}}, + {"Vector[float64]", dbtype.Vector[float64]{1.0, 2.0, 3.0}}, - {"type alias", myVecType{1, 2, 3}, false, false}, + {"type alias", myVecType{1, 2, 3}}, - {"new type int32", myVecType3{1, 2, 3}, true, false}, - {"new type float64", myVecType2{1.0, 2.0, 3.0}, true, false}, + {"new type int32", myVecType3{1, 2, 3}}, + {"new type float64", myVecType2{1.0, 2.0, 3.0}}, - {"*Vector[int8]", &dbtype.Vector[int8]{1, 2, 3}, false, false}, - {"*Vector[float64]", &dbtype.Vector[float64]{1.0, 2.0, 3.0}, false, false}, - {"*type alias", &myVecType{1, 2, 3}, false, false}, + {"*Vector[int8]", &dbtype.Vector[int8]{1, 2, 3}}, + {"*Vector[float64]", &dbtype.Vector[float64]{1.0, 2.0, 3.0}}, + {"*type alias", &myVecType{1, 2, 3}}, - {"*new type", &myVecType2{1.0, 2.0, 3.0}, true, false}, + {"*new type", &myVecType2{1.0, 2.0, 3.0}}, - {"nil *Vector", (*dbtype.Vector[int8])(nil), false, false}, + {"nil *Vector", (*dbtype.Vector[int8])(nil)}, - {"[]int8", []int8{1, 2, 3}, false, false}, - {"[]float64", []float64{1.0, 2.0, 3.0}, false, false}, - {"*[]int8", &[]int8{1, 2, 3}, false, false}, + {"[]int8", []int8{1, 2, 3}}, + {"[]float64", []float64{1.0, 2.0, 3.0}}, + {"*[]int8", &[]int8{1, 2, 3}}, - {"empty Vector", dbtype.Vector[int8]{}, false, false}, - {"empty *Vector", &dbtype.Vector[int8]{}, false, false}, + {"empty Vector", dbtype.Vector[int8]{}}, + {"empty *Vector", &dbtype.Vector[int8]{}}, - {"[]int8 via packV", []int8{1, 2, 3}, false, true}, - {"[]float64 via packV", []float64{1.0, 2.0, 3.0}, false, true}, - {"[]string via packV", []string{"a", "b", "c"}, false, true}, - {"[]any via packV", []any{1, "hello", 3.14}, false, true}, - {"[]byte via packV", []byte{1, 2, 3}, false, true}, - - {"custom type", myCustomType(42), false, false}, - {"[]custom type", []myCustomType{1, 2, 3}, false, false}, + {"custom type", myCustomType(42)}, + {"[]custom type", []myCustomType{1, 2, 3}}, } for _, tc := range testCases { @@ -236,22 +227,10 @@ func TestVectorHandling(ot *testing.T) { }, } - if tc.usePackV { - out.packV(reflect.ValueOf(tc.data)) - } else { - out.packX(tc.data) - } + out.packX(tc.data) - if tc.shouldError { - if packErr == nil { - t.Errorf("Expected error for %s, but got none", tc.name) - } else if _, ok := packErr.(*db.UnsupportedTypeError); !ok { - t.Errorf("Expected UnsupportedTypeError for %s, but got: %T", tc.name, packErr) - } - } else { - if packErr != nil { - t.Errorf("Unexpected error for %s: %v", tc.name, packErr) - } + if packErr != nil { + t.Errorf("Unexpected error for %s: %v", tc.name, packErr) } }) }