Skip to content

Commit f01d3fa

Browse files
committed
Implement auth interceptor
1 parent 9658b97 commit f01d3fa

File tree

2 files changed

+49
-29
lines changed

2 files changed

+49
-29
lines changed
Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,46 @@
1-
use crate::config;
2-
use std::fs::metadata;
3-
use tonic::{GrpcMethod, Request, Response, Status, service::Interceptor};
4-
use tracing::{Level, Metadata, event};
1+
use tonic::{GrpcMethod, Request, Status, service::Interceptor};
2+
use tracing::{Level, event};
53

64
pub fn logging_interceptor(req: Request<()>) -> Result<Request<()>, Status> {
75
if let Some(method) = req.extensions().get::<GrpcMethod>() {
8-
event!(
9-
Level::INFO,
10-
"RPC call: {}:{} | from: {}",
6+
println!(
7+
"RPC to {}:{} : From {:?}",
118
method.service(),
129
method.method(),
13-
req.remote_addr().unwrap()
10+
req.remote_addr()
1411
);
1512
}
1613
Ok(req)
1714
}
1815

19-
//TODO: implement auth
16+
#[derive(Clone)]
17+
pub struct AuthInterceptor {
18+
root_password: String,
19+
}
2020

21-
// struct AuthInterceptor {
22-
// config: &'static config::GRPCServerConfig,
23-
// }
21+
impl Interceptor for AuthInterceptor {
22+
fn call(&mut self, req: tonic::Request<()>) -> Result<tonic::Request<()>, Status> {
23+
let auth_token = match req.metadata().get("authorization") {
24+
Some(t) => t,
25+
None => return Err(Status::unauthenticated("Invalid credentials")),
26+
};
27+
if auth_token
28+
.to_str()
29+
.unwrap_or_default()
30+
.strip_prefix("Bearer ")
31+
.unwrap_or_default()
32+
== self.root_password
33+
{
34+
Ok(req)
35+
} else {
36+
event!(Level::WARN, "Unauthorized Request");
37+
Err(Status::unauthenticated("Invalid credentials"))
38+
}
39+
}
40+
}
2441

25-
// impl Interceptor for AuthInterceptor {
26-
// fn call(&mut self, req: tonic::Request<()>) -> Result<tonic::Request<()>, Status> {
27-
// if let req.metadata().unwrap() == self.config.root_password {
28-
// };
29-
// Ok(req)
30-
// }
31-
// }
42+
impl AuthInterceptor {
43+
pub fn new(root_password: String) -> AuthInterceptor {
44+
AuthInterceptor { root_password }
45+
}
46+
}

crates/grpc_server/src/main.rs

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
use std::panic;
2-
31
use api;
42
use defs;
5-
use grpc_server::config;
63
use grpc_server::constants::SIMILARITY_PROTOBUF_MAP;
74
use grpc_server::interceptors;
5+
use grpc_server::{config::GRPCServerConfig, interceptors::logging_interceptor};
6+
use std::panic;
87
use tonic::{Request, Response, Status, service::InterceptorLayer, transport::Server};
98
use tracing_subscriber;
109

@@ -66,7 +65,7 @@ impl VectorDb for VectorDBService {
6665

6766
let point = point_opt.unwrap();
6867

69-
// TODO: handle payload properly once defined
68+
// TODO: handle payload properly once we decide to define it
7069
Ok(Response::new(Point {
7170
id: Some(PointId { id: point.id }),
7271
vector: Some(DenseVector {
@@ -82,6 +81,8 @@ impl VectorDb for VectorDBService {
8281
) -> Result<Response<SearchResponse>, Status> {
8382
let search_request = request.into_inner();
8483

84+
//TODO: distance function panics if different dimensions; no dimension enforcement
85+
8586
// extract request arguments
8687
let query_vect = search_request.query_vector.unwrap();
8788
let similarity = SIMILARITY_PROTOBUF_MAP[search_request.similarity as usize];
@@ -106,9 +107,10 @@ impl VectorDb for VectorDBService {
106107
async fn delete_point(&self, request: Request<PointId>) -> Result<Response<()>, Status> {
107108
let point_id = request.into_inner().id;
108109

109-
if self.vector_db.delete(point_id).is_err() {
110-
return Err(Status::not_found("Point not found"));
111-
};
110+
// TODO: delete call needs to return a boolean indicating if point is present or not
111+
// if self.vector_db.delete(point_id)? {
112+
// return Err(Status::not_found("Point not found"));
113+
// };
112114

113115
Ok(Response::new(()))
114116
}
@@ -117,7 +119,7 @@ impl VectorDb for VectorDBService {
117119
#[tokio::main]
118120
async fn main() -> Result<(), Box<dyn std::error::Error>> {
119121
tracing_subscriber::fmt::init();
120-
let config = config::GRPCServerConfig::load_config()
122+
let config = GRPCServerConfig::load_config()
121123
.inspect_err(|err| panic!("Failed to load config: {}", err))
122124
.unwrap();
123125

@@ -135,8 +137,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
135137
config.addr.to_string()
136138
);
137139

138-
Server::builder()
139-
.layer(InterceptorLayer::new(interceptors::logging_interceptor))
140+
let auth_interceptor = interceptors::AuthInterceptor::new(config.root_password);
141+
142+
let _ = Server::builder()
143+
.layer(InterceptorLayer::new(logging_interceptor))
144+
.layer(InterceptorLayer::new(auth_interceptor))
140145
.add_service(VectorDbServer::new(vector_db_service))
141146
.serve(config.addr)
142147
.await

0 commit comments

Comments
 (0)