Skip to content

Commit 86209a7

Browse files
committed
Add check in api crate for dimension mismatch while insertion; Add tests for the same
1 parent 048ed39 commit 86209a7

File tree

4 files changed

+42
-12
lines changed

4 files changed

+42
-12
lines changed

crates/api/src/lib.rs

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,27 @@ fn generate_point_id() -> u64 {
1919
pub struct VectorDb {
2020
storage: Arc<dyn StorageEngine>,
2121
index: Arc<RwLock<dyn VectorIndex>>, // Using a RwLock instead of Mutex to improve concurrency
22+
dimension: usize,
2223
}
2324

2425
impl VectorDb {
25-
fn _new(storage: Arc<dyn StorageEngine>, index: Arc<RwLock<dyn VectorIndex>>) -> Self {
26-
Self { storage, index }
26+
fn _new(
27+
storage: Arc<dyn StorageEngine>,
28+
index: Arc<RwLock<dyn VectorIndex>>,
29+
dimension: usize,
30+
) -> Self {
31+
Self {
32+
storage,
33+
index,
34+
dimension,
35+
}
2736
}
2837

2938
//TODO: Make this an atomic operation
3039
pub fn insert(&self, vector: DenseVector, payload: Payload) -> Result<PointId, DbError> {
40+
if vector.len() != self.dimension {
41+
return Err(DbError::DimensionMismatch);
42+
}
3143
// Generate a new point id
3244
let point_id = generate_point_id();
3345
self.storage
@@ -105,7 +117,7 @@ pub fn init_api(config: DbConfig) -> Result<VectorDb, DbError> {
105117
};
106118

107119
// Init the db
108-
let db = VectorDb::_new(storage, index);
120+
let db = VectorDb::_new(storage, index, config.dimension);
109121

110122
Ok(db)
111123
}
@@ -147,6 +159,22 @@ mod tests {
147159
assert_eq!(point.payload.as_ref().unwrap(), &payload);
148160
}
149161

162+
#[test]
163+
fn test_dimension_mismatch() {
164+
let db = create_test_db();
165+
let v1 = vec![1.0, 2.0, 3.0];
166+
let v2 = vec![1.0, 2.0];
167+
let payload = Payload {};
168+
169+
let res1 = db.insert(v1, payload);
170+
assert!(res1.is_ok());
171+
172+
// Insert vector of dimension 2 != 3
173+
let res2 = db.insert(v2, payload);
174+
assert!(res2.is_err());
175+
assert_eq!(res2.unwrap_err(), DbError::DimensionMismatch);
176+
}
177+
150178
#[test]
151179
fn test_delete() {
152180
let db = create_test_db();

crates/defs/src/error.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
#[derive(Debug)]
1+
#[derive(Debug, PartialEq, Eq)]
22
pub enum DbError {
33
ParseError,
44
StorageError(String),
55
SerializationError(String),
66
DeserializationError,
77
IndexError(String),
88
LockError,
9+
DimensionMismatch,
910
}

crates/grpc_server/.sample.env

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
GRPC_SERVER_ROOT_PASSWORD=123 // required
2-
GRPC_SERVER_DIMENSION=3 // required
1+
GRPC_SERVER_ROOT_PASSWORD=123 # required
2+
GRPC_SERVER_DIMENSION=3 # required
33

4-
GRPC_SERVER_HOST=localhost // defaults to 127.0.0.1 aka localhost
5-
GRPC_SERVER_PORT=8080 // defaults to 8080
6-
GRPC_SERVER_STORAGE_TYPE=inmemory // (inmemory/rocksdb) defaults to 'inmemory'
7-
GRPC_SERVER_INDEX_TYPE=flat // defaults to flat
8-
GRPC_SERVER_DATA_PATH=data // defaults to a temporary directory
9-
GRPC_SERVER_LOGGING=true // defaults to true
4+
GRPC_SERVER_HOST=localhost # defaults to 127.0.0.1 aka localhost
5+
GRPC_SERVER_PORT=8080 # defaults to 8080
6+
GRPC_SERVER_STORAGE_TYPE=inmemory # (inmemory/rocksdb) defaults to 'inmemory'
7+
GRPC_SERVER_INDEX_TYPE=flat # defaults to flat
8+
GRPC_SERVER_DATA_PATH=data # defaults to a temporary directory
9+
GRPC_SERVER_LOGGING=true # defaults to true

crates/grpc_server/src/tests.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

0 commit comments

Comments
 (0)