Skip to content

Commit b22394d

Browse files
committed
Implement remaining rpc tests; Minor bug fixes
1 parent b4f2eb0 commit b22394d

File tree

7 files changed

+154
-30
lines changed

7 files changed

+154
-30
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/defs/src/error.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,11 @@ pub enum DbError {
88
LockError,
99
DimensionMismatch,
1010
}
11+
12+
impl std::fmt::Display for DbError {
13+
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
14+
write!(f, "{:?}", self)
15+
}
16+
}
17+
18+
impl std::error::Error for DbError {}

crates/grpc_server/README.md

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,16 @@ cargo run --bin grpc_server
1313
Use the [.sample.env](.sample.env) shown below as a reference to set your environment variables in a `.env` file.
1414

1515
```bash
16-
GRPC_SERVER_ROOT_PASSWORD=123 // required
17-
GRPC_SERVER_DIMENSION=3 // required
18-
19-
GRPC_SERVER_HOST=localhost // defaults to 127.0.0.1 aka localhost
20-
GRPC_SERVER_PORT=8080 // defaults to 8080
21-
GRPC_SERVER_STORAGE_TYPE=inmemory // (inmemory/rocksdb) defaults to inmemory
22-
GRPC_SERVER_INDEX_TYPE=flat // defaults to flat
23-
GRPC_SERVER_DATA_PATH=data // defaults to a temporary directory
24-
GRPC_SERVER_LOGGING=true // defaults to true
16+
GRPC_SERVER_ROOT_PASSWORD=123 # required
17+
GRPC_SERVER_DIMENSION=3 # required
18+
19+
GRPC_SERVER_HOST=localhost # defaults to 127.0.0.1 aka localhost
20+
GRPC_SERVER_PORT=8080 # defaults to 8080
21+
GRPC_SERVER_STORAGE_TYPE=inmemory # (inmemory/rocksdb) defaults to 'inmemory'
22+
GRPC_SERVER_INDEX_TYPE=flat # defaults to flat
23+
GRPC_SERVER_DATA_PATH=data # defaults to a temporary directory
24+
GRPC_SERVER_LOGGING=true # defaults to true
25+
2526

2627
```
2728

crates/grpc_server/src/config.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ impl GRPCServerConfig {
110110
let addr: SocketAddr = format!("{}:{}", host, port).parse()?;
111111

112112
// check if logging is enabled
113-
let mut logging: bool = true; // default to no logging
113+
let mut logging: bool = true; // default to logging enabled
114114
if let Ok(logging_str) = env::var(ENV_LOGGING) {
115115
logging = logging_str.parse().unwrap_or(true);
116116
}

crates/grpc_server/src/main.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ use std::panic;
55
#[tokio::main]
66
async fn main() -> Result<(), Box<dyn std::error::Error>> {
77
tracing_subscriber::fmt::init();
8+
9+
// load config from environment from environment variables
810
let config = GRPCServerConfig::load_config()
911
.inspect_err(|err| panic!("Failed to load config: {}", err))
1012
.unwrap();

crates/grpc_server/src/service.rs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::constants::SIMILARITY_PROTOBUFF_MAP;
22
use crate::interceptors;
33
use crate::utils::log_rpc;
4-
use std::{net::SocketAddr, panic};
4+
use std::net::SocketAddr;
55
use tonic::{Request, Response, Status, service::InterceptorLayer, transport::Server};
66
use tracing::{Level, event};
77
use vectordb::{
@@ -82,12 +82,14 @@ impl VectorDb for VectorDBService {
8282
let query_vect = search_request
8383
.query_vector
8484
.ok_or(Status::invalid_argument("Invalid query_vector"))?;
85-
let similarity = SIMILARITY_PROTOBUFF_MAP[search_request.similarity as usize];
85+
let similarity = SIMILARITY_PROTOBUFF_MAP
86+
.get(search_request.similarity as usize)
87+
.ok_or(Status::internal("Invalid similarity"))?;
8688
let limit = search_request.limit;
8789

8890
let result_point_ids = self
8991
.vector_db
90-
.search(query_vect.values, similarity, limit as usize)
92+
.search(query_vect.values, *similarity, limit as usize)
9193
.map_err(|_| Status::internal("Internal server error"))?;
9294

9395
// create a mapped vector of PointIds
@@ -114,7 +116,7 @@ impl VectorDb for VectorDBService {
114116
Err(Status::not_found("Point not found"))
115117
}
116118
}
117-
Err(_) => Err(Status::not_found("Error deleting point")),
119+
Err(_) => Err(Status::internal("Error deleting point")),
118120
}
119121
}
120122
}
@@ -128,11 +130,14 @@ pub async fn run_server(
128130

129131
let auth_interceptor = interceptors::AuthInterceptor::new(root_password);
130132

131-
let _ = Server::builder()
133+
Server::builder()
132134
.layer(InterceptorLayer::new(auth_interceptor))
133135
.add_service(VectorDbServer::new(vector_db_service))
134136
.serve(addr)
135137
.await
136-
.inspect_err(|err| panic!("Failed to start gRPC server: {}", err));
138+
.map_err(|e| {
139+
event!(Level::ERROR, "Failed to start gRPC server: {}", e);
140+
Box::new(e)
141+
})?;
137142
Ok(())
138143
}

crates/grpc_server/src/tests.rs

Lines changed: 121 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use crate::interceptors;
33
use crate::service::VectorDBService;
44
use crate::service::vectordb::vector_db_client::VectorDbClient;
55
use crate::service::vectordb::vector_db_server::VectorDbServer;
6-
use crate::service::vectordb::{DenseVector, InsertVectorRequest, PointId};
6+
use crate::service::vectordb::{DenseVector, InsertVectorRequest, PointId, SearchRequest};
77
use api;
88
use api::DbConfig;
99
use index::IndexType;
@@ -22,7 +22,7 @@ use tonic::transport::{Channel, Server};
2222
// - use a shared instance of the server
2323
// currently tests must be run with --test-threads=1
2424

25-
pub async fn run_test_server_with_stream(
25+
pub async fn run_server_with_stream(
2626
vector_db_service: VectorDBService,
2727
listener: TcpListener,
2828
root_password: String,
@@ -39,6 +39,7 @@ pub async fn run_test_server_with_stream(
3939
}
4040

4141
async fn start_test_server() -> Result<(), Box<dyn std::error::Error>> {
42+
// using a temporary directory for db datapath
4243
let temp_dir = tempdir().unwrap();
4344

4445
let db_config = DbConfig {
@@ -65,8 +66,7 @@ async fn start_test_server() -> Result<(), Box<dyn std::error::Error>> {
6566
let listener = tokio::net::TcpListener::bind(config.addr).await?;
6667

6768
tokio::spawn(async move {
68-
let _ =
69-
run_test_server_with_stream(vector_db_service, listener, config.root_password).await;
69+
let _ = run_server_with_stream(vector_db_service, listener, config.root_password).await;
7070
});
7171

7272
Ok(())
@@ -82,12 +82,14 @@ async fn create_test_client() -> Result<VectorDbClient<Channel>, Box<dyn std::er
8282
#[tokio::test]
8383
async fn test_grpc_server_start() {
8484
start_test_server().await.unwrap();
85-
8685
let mut client = create_test_client().await.unwrap();
8786

87+
// insert a test vector
88+
let test_vec = vec![1.0, 2.0, 3.0];
89+
8890
let mut request = tonic::Request::new(InsertVectorRequest {
8991
vector: Some(DenseVector {
90-
values: vec![1.0, 2.0, 3.0],
92+
values: test_vec.clone(),
9193
}),
9294
payload: Some(Struct::default()),
9395
});
@@ -103,6 +105,7 @@ async fn test_insert_vector_rpc() {
103105
start_test_server().await.unwrap();
104106
let mut client = create_test_client().await.unwrap();
105107

108+
// insert a test vector
106109
let test_vec = vec![1.0, 2.0, 3.0];
107110

108111
let mut request = tonic::Request::new(InsertVectorRequest {
@@ -116,10 +119,10 @@ async fn test_insert_vector_rpc() {
116119
.insert("authorization", "Bearer 123".parse().unwrap());
117120
let resp = client.insert_vector(request).await;
118121

119-
// Check if request is successful
122+
// check if request is successful
120123
assert!(resp.is_ok());
121124

122-
// Check if the vector is actually present in the database
125+
// check if the vector is actually present in the database
123126
let mut request = tonic::Request::new(PointId {
124127
id: resp.unwrap().into_inner().id,
125128
});
@@ -128,12 +131,12 @@ async fn test_insert_vector_rpc() {
128131
.insert("authorization", "Bearer 123".parse().unwrap());
129132
let resp = client.get_point(request).await;
130133

131-
// Check if request is successful
134+
// check if request is successful
132135
assert!(resp.is_ok());
133136
let point = resp.unwrap().into_inner();
134137
assert_eq!(point.vector.unwrap().values, test_vec);
135138

136-
// Insert a new vector with mismatched dimensions
139+
// insert a new vector with mismatched dimensions
137140
let mut request = tonic::Request::new(InsertVectorRequest {
138141
vector: Some(DenseVector {
139142
values: vec![1.0, 2.0],
@@ -145,15 +148,119 @@ async fn test_insert_vector_rpc() {
145148
.insert("authorization", "Bearer 123".parse().unwrap());
146149
let resp = client.insert_vector(request).await;
147150

148-
// Request must have failed
151+
// request must fail
149152
assert!(resp.is_err());
150153
}
151154

152155
#[tokio::test]
153-
async fn test_delete_vector_rpc() {}
156+
async fn test_delete_vector_rpc() {
157+
start_test_server().await.unwrap();
158+
let mut client = create_test_client().await.unwrap();
159+
160+
// insert a test vector
161+
let test_vec = vec![1.0, 2.0, 3.0];
162+
let mut request = tonic::Request::new(InsertVectorRequest {
163+
vector: Some(DenseVector {
164+
values: test_vec.clone(),
165+
}),
166+
payload: Some(Struct::default()),
167+
});
168+
request
169+
.metadata_mut()
170+
.insert("authorization", "Bearer 123".parse().unwrap());
171+
let resp = client.insert_vector(request).await;
172+
173+
// check if request is successful
174+
assert!(resp.is_ok());
175+
let point = resp.unwrap().into_inner();
176+
177+
// delete the vector
178+
let mut request = tonic::Request::new(PointId { id: point.id });
179+
request
180+
.metadata_mut()
181+
.insert("authorization", "Bearer 123".parse().unwrap());
182+
let resp = client.delete_point(request).await;
183+
184+
// check if request is successful
185+
assert!(resp.is_ok());
186+
187+
// verify that the vector is deleted
188+
let mut request = tonic::Request::new(PointId { id: point.id });
189+
request
190+
.metadata_mut()
191+
.insert("authorization", "Bearer 123".parse().unwrap());
192+
let resp = client.get_point(request).await;
193+
194+
// request must fail since the vector is deleted
195+
assert!(resp.is_err());
196+
}
154197

155198
#[tokio::test]
156-
async fn test_search_vector_rpc() {}
199+
async fn test_search_vector_rpc() {
200+
start_test_server().await.unwrap();
201+
let mut client = create_test_client().await.unwrap();
202+
203+
// insert a test vector
204+
let test_vec = vec![1.0, 2.0, 3.0];
205+
let mut request = tonic::Request::new(InsertVectorRequest {
206+
vector: Some(DenseVector {
207+
values: test_vec.clone(),
208+
}),
209+
payload: Some(Struct::default()),
210+
});
211+
request
212+
.metadata_mut()
213+
.insert("authorization", "Bearer 123".parse().unwrap());
214+
let resp = client.insert_vector(request).await;
215+
216+
// check if request is successful
217+
assert!(resp.is_ok());
218+
let point = resp.unwrap().into_inner();
219+
220+
let query_vec = vec![2.0, 2.0, 2.0];
221+
222+
// search for the vector
223+
let mut request = tonic::Request::new(SearchRequest {
224+
query_vector: Some(DenseVector {
225+
values: query_vec.clone(),
226+
}),
227+
similarity: 0, // euclidean distance
228+
limit: 1,
229+
});
230+
request
231+
.metadata_mut()
232+
.insert("authorization", "Bearer 123".parse().unwrap());
233+
let resp = client.search_points(request).await;
234+
235+
// check if request is successful
236+
assert!(resp.is_ok());
237+
let result = resp.unwrap().into_inner();
238+
239+
// 1 vector has to be returned
240+
assert_eq!(result.result_point_ids.len(), 1);
241+
242+
// check if the returned point id matches the inserted point id
243+
assert_eq!(result.result_point_ids[0], PointId { id: point.id });
244+
}
157245

158246
#[tokio::test]
159-
async fn test_unauthorized_rpc() {}
247+
async fn test_unauthorized_rpc() {
248+
start_test_server().await.unwrap();
249+
let mut client = create_test_client().await.unwrap();
250+
251+
// insert a test vector
252+
let test_vec = vec![1.0, 2.0, 3.0];
253+
let mut request = tonic::Request::new(InsertVectorRequest {
254+
vector: Some(DenseVector {
255+
values: test_vec.clone(),
256+
}),
257+
payload: Some(Struct::default()),
258+
});
259+
request
260+
.metadata_mut()
261+
.insert("authorization", "Bearer 43121".parse().unwrap()); // wrong auth token
262+
let resp = client.insert_vector(request).await;
263+
264+
// request must fail
265+
assert!(resp.is_err());
266+
}

0 commit comments

Comments
 (0)