Skip to content

Commit dedb7df

Browse files
authored
Merge pull request #29 from sdslabs/adesh-refactor
2 parents a2083fb + caacef0 commit dedb7df

File tree

4 files changed

+324
-12
lines changed

4 files changed

+324
-12
lines changed

crates/core/src/error.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ pub enum DbError {
44
StorageError(String),
55
SerializationError(String),
66
DeserializationError,
7+
IndexError(String),
78
}

crates/core/src/types.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,21 @@ pub struct Point {
2727
pub payload: Option<Payload>,
2828
}
2929

30+
/// Struct which will be stored in the vector index
31+
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
32+
pub struct IndexedVector {
33+
pub id: PointId,
34+
pub vector: DenseVector,
35+
}
36+
37+
#[derive(Copy, Clone)]
38+
pub enum Similarity {
39+
Euclidean,
40+
Manhattan,
41+
Hamming,
42+
Cosine,
43+
}
44+
3045
// Query Vector. Basically the type of query results that can be generated. Not implementing this but referencing here for furture reference
3146
// #[derive(Debug, Clone)]
3247
// pub enum QueryVector {

crates/index/src/flat.rs

Lines changed: 252 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,257 @@
1-
use core::DenseVector;
1+
use core::{DbError, DenseVector, IndexedVector, PointId, Similarity};
22

3-
struct FlatIndex {
4-
index: Vec<DenseVector>
3+
use crate::{distance, VectorIndex};
4+
5+
pub struct FlatIndex {
6+
index: Vec<IndexedVector>,
7+
}
8+
9+
impl FlatIndex {
10+
pub fn new() -> Self {
11+
Self { index: Vec::new() }
12+
}
13+
14+
pub fn build(vectors: Vec<IndexedVector>) -> Self {
15+
FlatIndex { index: vectors }
16+
}
17+
}
18+
19+
impl Default for FlatIndex {
20+
fn default() -> Self {
21+
Self::new()
22+
}
523
}
624

725
impl VectorIndex for FlatIndex {
8-
fn insert() {
9-
26+
fn insert(&mut self, vector: IndexedVector) -> Result<(), DbError> {
27+
self.index.push(vector);
28+
Ok(())
1029
}
11-
}
30+
31+
fn delete(&mut self, point_id: PointId) -> Result<bool, DbError> {
32+
if let Some(pos) = self.index.iter().position(|vector| vector.id == point_id) {
33+
self.index.remove(pos);
34+
Ok(true)
35+
} else {
36+
Ok(false)
37+
}
38+
}
39+
40+
fn search(
41+
&self,
42+
query_vector: DenseVector,
43+
similarity: Similarity,
44+
k: usize,
45+
) -> Result<Vec<PointId>, DbError> {
46+
let mut scores = self
47+
.index
48+
.iter()
49+
.map(|point| {
50+
(
51+
point.id,
52+
distance(point.vector.clone(), query_vector.clone(), similarity),
53+
)
54+
})
55+
.collect::<Vec<_>>();
56+
57+
// Sorting logic according to type of metric used
58+
match similarity {
59+
Similarity::Euclidean | Similarity::Manhattan | Similarity::Hamming => {
60+
scores.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
61+
}
62+
Similarity::Cosine => {
63+
scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
64+
}
65+
}
66+
67+
Ok(scores
68+
.into_iter()
69+
.take(k)
70+
.map(|(id, _)| id)
71+
.collect::<Vec<_>>())
72+
}
73+
}
74+
75+
#[cfg(test)]
76+
mod tests {
77+
use super::*;
78+
79+
#[test]
80+
fn test_flat_index_new() {
81+
let index = FlatIndex::new();
82+
assert_eq!(index.index.len(), 0);
83+
}
84+
85+
#[test]
86+
fn test_flat_index_build() {
87+
let vectors = vec![
88+
IndexedVector {
89+
id: 1,
90+
vector: vec![1.0, 2.0, 3.0],
91+
},
92+
IndexedVector {
93+
id: 2,
94+
vector: vec![4.0, 5.0, 6.0],
95+
},
96+
];
97+
let index = FlatIndex::build(vectors.clone());
98+
assert_eq!(index.index, vectors);
99+
}
100+
101+
#[test]
102+
fn test_insert() {
103+
let mut index = FlatIndex::new();
104+
let vector = IndexedVector {
105+
id: 1,
106+
vector: vec![1.0, 2.0, 3.0],
107+
};
108+
109+
assert!(index.insert(vector.clone()).is_ok());
110+
assert_eq!(index.index.len(), 1);
111+
assert_eq!(index.index[0], vector);
112+
}
113+
114+
#[test]
115+
fn test_delete_existing() {
116+
let mut index = FlatIndex::new();
117+
let vector = IndexedVector {
118+
id: 1,
119+
vector: vec![1.0, 2.0, 3.0],
120+
};
121+
index.insert(vector).unwrap();
122+
123+
let result = index.delete(1).unwrap();
124+
assert!(result);
125+
assert_eq!(index.index.len(), 0);
126+
}
127+
128+
#[test]
129+
fn test_delete_non_existing() {
130+
let mut index = FlatIndex::new();
131+
let vector = IndexedVector {
132+
id: 1,
133+
vector: vec![1.0, 2.0, 3.0],
134+
};
135+
index.insert(vector).unwrap();
136+
137+
let result = index.delete(999).unwrap();
138+
assert!(!result);
139+
assert_eq!(index.index.len(), 1);
140+
}
141+
142+
#[test]
143+
fn test_search_euclidean() {
144+
let mut index = FlatIndex::new();
145+
index
146+
.insert(IndexedVector {
147+
id: 1,
148+
vector: vec![1.0, 1.0],
149+
})
150+
.unwrap();
151+
index
152+
.insert(IndexedVector {
153+
id: 2,
154+
vector: vec![2.0, 2.0],
155+
})
156+
.unwrap();
157+
index
158+
.insert(IndexedVector {
159+
id: 3,
160+
vector: vec![10.0, 10.0],
161+
})
162+
.unwrap();
163+
164+
let results = index
165+
.search(vec![0.0, 0.0], Similarity::Euclidean, 2)
166+
.unwrap();
167+
assert_eq!(results, vec![1, 2]);
168+
}
169+
170+
#[test]
171+
fn test_search_cosine() {
172+
let mut index = FlatIndex::new();
173+
index
174+
.insert(IndexedVector {
175+
id: 1,
176+
vector: vec![1.0, 0.0],
177+
})
178+
.unwrap();
179+
index
180+
.insert(IndexedVector {
181+
id: 2,
182+
vector: vec![0.5, 0.5],
183+
})
184+
.unwrap();
185+
index
186+
.insert(IndexedVector {
187+
id: 3,
188+
vector: vec![0.0, 1.0],
189+
})
190+
.unwrap();
191+
192+
let results = index.search(vec![1.0, 1.0], Similarity::Cosine, 2).unwrap();
193+
assert_eq!(results, vec![2, 1]);
194+
}
195+
196+
#[test]
197+
fn test_search_manhattan() {
198+
let mut index = FlatIndex::new();
199+
index
200+
.insert(IndexedVector {
201+
id: 1,
202+
vector: vec![1.0, 1.0],
203+
})
204+
.unwrap();
205+
index
206+
.insert(IndexedVector {
207+
id: 2,
208+
vector: vec![2.0, 2.0],
209+
})
210+
.unwrap();
211+
index
212+
.insert(IndexedVector {
213+
id: 3,
214+
vector: vec![5.0, 5.0],
215+
})
216+
.unwrap();
217+
218+
let results = index
219+
.search(vec![0.0, 0.0], Similarity::Manhattan, 2)
220+
.unwrap();
221+
assert_eq!(results, vec![1, 2]);
222+
}
223+
224+
#[test]
225+
fn test_search_hamming() {
226+
let mut index = FlatIndex::new();
227+
index
228+
.insert(IndexedVector {
229+
id: 1,
230+
vector: vec![1.0, 0.0, 1.0, 0.0],
231+
})
232+
.unwrap();
233+
index
234+
.insert(IndexedVector {
235+
id: 2,
236+
vector: vec![1.0, 0.0, 0.0, 0.0],
237+
})
238+
.unwrap();
239+
index
240+
.insert(IndexedVector {
241+
id: 3,
242+
vector: vec![0.0, 0.0, 0.0, 0.0],
243+
})
244+
.unwrap();
245+
246+
let results = index
247+
.search(vec![1.0, 0.0, 0.0, 0.0], Similarity::Hamming, 2)
248+
.unwrap();
249+
assert_eq!(results, vec![2, 3]);
250+
}
251+
252+
#[test]
253+
fn test_default() {
254+
let index = FlatIndex::default();
255+
assert_eq!(index.index.len(), 0);
256+
}
257+
}

crates/index/src/lib.rs

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,61 @@
1-
use core::{DenseVector, PointId};
2-
use std::fmt::Error;
1+
use core::{DbError, DenseVector, IndexedVector, PointId, Similarity};
2+
3+
pub mod flat;
34

45
pub trait VectorIndex {
5-
fn insert(&self, vector: DenseVector) -> Result<(), Error>;
6-
fn delete(&self, point_id: PointId) -> Result<(), Error>;
7-
fn search(&self, query_vector: DenseVector) -> Result<DenseVector, Error>;
8-
// fn build() -> Result<(), Error>; move this to impl for dyn compatibility
6+
fn insert(&mut self, vector: IndexedVector) -> Result<(), DbError>;
7+
8+
// Returns true if point id existed and is deleted, else returns false
9+
fn delete(&mut self, point_id: PointId) -> Result<bool, DbError>;
10+
11+
fn search(
12+
&self,
13+
query_vector: DenseVector,
14+
similarity: Similarity,
15+
k: usize,
16+
) -> Result<Vec<PointId>, DbError>; // Return a Vec of ids of closest vectors (length max k)
17+
18+
// fn build() -> Result<(), DbError>; move this to impl for dyn compatibility
19+
}
20+
21+
/// Distance function to get the distance between two vectors (taken from old version)
22+
pub fn distance(a: DenseVector, b: DenseVector, dist_type: Similarity) -> f32 {
23+
assert_eq!(a.len(), b.len());
24+
match dist_type {
25+
Similarity::Euclidean => {
26+
let score: Vec<f32> = a
27+
.iter()
28+
.zip(b.iter())
29+
.map(|(&x, &y)| (x - y) * (x - y))
30+
.collect();
31+
score.iter().sum::<f32>().sqrt()
32+
}
33+
Similarity::Manhattan => {
34+
let score: Vec<f32> = a
35+
.iter()
36+
.zip(b.iter())
37+
.map(|(&x, &y)| (x - y).abs())
38+
.collect();
39+
score.iter().sum::<f32>()
40+
}
41+
Similarity::Hamming => {
42+
let score: Vec<f32> = a
43+
.iter()
44+
.zip(b.iter())
45+
.map(|(&x, &y)| (if (x - y) > 1e-8 { 1f32 } else { 0f32 }))
46+
.collect();
47+
score.iter().sum::<f32>()
48+
}
49+
Similarity::Cosine => {
50+
let p_score: Vec<f32> = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).collect();
51+
let p = p_score.iter().sum::<f32>();
52+
let q_score: Vec<f32> = a.iter().map(|&n| n * n).collect();
53+
let q = q_score.iter().sum::<f32>().sqrt();
54+
let r_score: Vec<f32> = b.iter().map(|&n| n * n).collect();
55+
let r = r_score.iter().sum::<f32>().sqrt();
56+
p / (q * r)
57+
}
58+
}
959
}
1060

1161
pub enum IndexType {

0 commit comments

Comments
 (0)