diff --git a/Cargo.lock b/Cargo.lock index 1b3e627..4b28e53 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -21,6 +21,12 @@ dependencies = [ "tempfile", ] +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + [[package]] name = "bincode" version = "1.3.3" @@ -181,6 +187,7 @@ name = "index" version = "0.1.0" dependencies = [ "defs", + "ordered-float", ] [[package]] @@ -295,12 +302,30 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + [[package]] name = "once_cell" version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "ordered-float" +version = "5.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2c1f9f56e534ac6a9b8a4600bdf0f530fb393b5f393e7b4d03489c3cf0c3f01" +dependencies = [ + "num-traits", +] + [[package]] name = "peeking_take_while" version = "0.1.2" diff --git a/Cargo.toml b/Cargo.toml index 670d962..1ecb74c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ members = [ "crates/server", ] + # You can define shared dependencies for all crates here [workspace.dependencies] # tokio = { version = "1.37.0", features = ["full"] } diff --git a/crates/defs/src/error.rs b/crates/defs/src/error.rs index 1079c5e..ff15895 100644 --- a/crates/defs/src/error.rs +++ b/crates/defs/src/error.rs @@ -6,4 +6,5 @@ pub enum DbError { DeserializationError, IndexError(String), LockError, + IndexInitError, //TODO: Change this } diff --git a/crates/index/Cargo.toml b/crates/index/Cargo.toml index b54e3c8..c81c18f 100644 --- a/crates/index/Cargo.toml +++ b/crates/index/Cargo.toml @@ -6,5 +6,6 @@ version = "0.1.0" edition = "2021" [dependencies] +ordered-float = "5.0.0" defs = { path = "../defs" } diff --git a/crates/index/src/flat.rs b/crates/index/src/flat.rs index 53dc361..682a4d3 100644 --- a/crates/index/src/flat.rs +++ b/crates/index/src/flat.rs @@ -47,7 +47,7 @@ impl VectorIndex for FlatIndex { .index .iter() .map(|point| DistanceOrderedVector { - distance: distance(point.vector.clone(), query_vector.clone(), similarity), + distance: distance(&point.vector, &query_vector, similarity), query_vector: &query_vector, point_id: Some(point.id), }) diff --git a/crates/index/src/kd_tree.rs b/crates/index/src/kd_tree.rs index 444f578..310ee4f 100644 --- a/crates/index/src/kd_tree.rs +++ b/crates/index/src/kd_tree.rs @@ -1,252 +1,362 @@ -use std::cmp::Ordering; -use std::cmp::Ordering::Less; - -use serde_derive::{Deserialize, Serialize}; - -#[derive(Serialize, Deserialize)] -pub struct KDTreeInternals { - pub kd_tree_allow_update: bool, - pub current_number_of_kd_tree_nodes: usize, - pub rebuild_threshold: f32, - pub previous_tree_size: usize, - pub rebuild_counter: usize, +use defs::{DbError, DenseVector, IndexedVector, PointId, Similarity}; +use std::{ + cmp::Ordering, + collections::{BinaryHeap, HashMap}, + vec, +}; + +use crate::{distance, VectorIndex}; + +pub struct KDTree { + dim: usize, + root: Option>, + // An in memory point map for lookup during delete + point_map: HashMap, } -#[derive(Serialize, Deserialize)] +// the node which will be the part of the KD Tree pub struct KDTreeNode { - pub left: Option>, - pub right: Option>, - pub key: String, - pub vector: Vec, - pub dim: usize, + indexed_vector: IndexedVector, + split_dim: usize, + left: Option>, + right: Option>, + is_deleted: bool, } -impl KDTreeNode { - // Add the logic here to create a new db and insert the tree into the database - fn new(data: (String, Vec), dim: usize) -> KDTreeNode { - KDTreeNode { - left: None, - right: None, - key: data.0, - vector: data.1, - dim, - } +#[derive(Debug, Clone, PartialEq)] +struct Neighbor { + id: PointId, + distance: f32, +} + +impl Eq for Neighbor {} + +// Custom Ord implementation for the max-heap +impl Ord for Neighbor { + fn cmp(&self, other: &Self) -> Ordering { + self.distance + .partial_cmp(&other.distance) + .unwrap_or(Ordering::Equal) } } -pub struct KDTree { - pub _root: Option>, - pub _internals: KDTreeInternals, - pub is_debug_run: bool, - pub dim: usize, +impl PartialOrd for Neighbor { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } } impl KDTree { - // Create an empty tree with default values - pub fn new() -> KDTree { + pub fn mock() { + //here is the mock code + } + + // Build an empty index with no points + pub fn build_empty(dim: usize) -> Self { KDTree { - _root: None, - _internals: KDTreeInternals { - kd_tree_allow_update: true, - current_number_of_kd_tree_nodes: 0, - rebuild_threshold: 2.0f32, - previous_tree_size: 0, - rebuild_counter: 0, - }, - is_debug_run: true, - dim: 0, + dim, + root: None, + point_map: HashMap::new(), } } - // Add a node - // If the dimension of the tree is zero, then it becomes equal to the input data - pub fn add_node(&mut self, data: (String, Vec), depth: usize) { - if self._root.is_none() { - self.dim = data.1.len(); - self._root = Some(Box::new(KDTreeNode::new(data, 0))); - self._internals.current_number_of_kd_tree_nodes += 1; - return; - } + // Builds the vector index from provided vectors, there should atleast be single vector for dim calculation + pub fn build(mut vectors: Vec) -> Result { + if vectors.is_empty() { + Err(DbError::IndexInitError) + } else { + let dim = vectors[0].vector.len(); - assert_eq!(self.dim, data.1.len()); + let mut point_map = HashMap::with_capacity(vectors.len()); + for iv in vectors.iter() { + point_map.insert(iv.id, iv.vector.clone()); + } + let root_node = Self::build_recursive(&mut vectors, 0, dim); + Ok(KDTree { + dim, + root: Some(root_node), + point_map, + }) + } + } - if !self._internals.kd_tree_allow_update { - println!("KDTree is locked for rebuild"); - return; + // Builds the tree recursively with given vectors and returns the pointer of the root node + pub fn build_recursive( + vectors: &mut [IndexedVector], + depth: usize, + dim: usize, + ) -> Box { + if vectors.is_empty() { + panic!("Cannot build from an empty slice recursively"); } - if self._internals.previous_tree_size != 0 { - let current_ratio: f32 = self._internals.current_number_of_kd_tree_nodes as f32 - / self._internals.previous_tree_size as f32; - if current_ratio > self._internals.rebuild_threshold { - self._internals.previous_tree_size = - self._internals.current_number_of_kd_tree_nodes; - self.rebuild(); - } + let axis = depth % dim; + let mid_idx = vectors.len() / 2; + + vectors.select_nth_unstable_by(mid_idx, |a, b| { + let a_at_axis = a.vector[axis]; + let b_at_axis = b.vector[axis]; + a_at_axis.partial_cmp(&b_at_axis).unwrap_or(Ordering::Equal) + }); + + // Using swap so that we don't need to clone the whole vector + let mut median_vec = IndexedVector { + id: 0, + vector: vec![], + }; // dummy + std::mem::swap(&mut vectors[mid_idx], &mut median_vec); + + let (left_points, right_points_with_median) = vectors.split_at_mut(mid_idx); + let right_points = &mut right_points_with_median[1..]; // Exclude the swapped-out median + + let left = if left_points.is_empty() { + None } else { - self._internals.previous_tree_size = self._internals.current_number_of_kd_tree_nodes; + Some(Self::build_recursive(left_points, depth + 1, dim)) + }; + + let right = if right_points.is_empty() { + None + } else { + Some(Self::build_recursive(right_points, depth + 1, dim)) + }; + + Box::new(KDTreeNode { + indexed_vector: median_vec, + split_dim: axis, + left, + right, + is_deleted: false, + }) + } + + pub fn insert_point(&mut self, new_vector: IndexedVector) { + // use a traverse function to get the final leaf where this belongs + if self.root.is_none() { + self.root = Some(Box::new(KDTreeNode { + indexed_vector: new_vector, + split_dim: 0, + left: None, + right: None, + is_deleted: false, + })); + return; } - self._internals.current_number_of_kd_tree_nodes += 1; - - let mut current_node = self._root.as_deref_mut().unwrap(); - let mut current_depth = depth; - loop { - let current_dimension = current_depth % self.dim; - if data.1[current_dimension] < current_node.vector[current_dimension] { - if current_node.left.is_none() { - current_node.left = Some(Box::new(KDTreeNode::new(data, current_dimension))); - break; - } else { - current_node = current_node.left.as_deref_mut().unwrap(); - current_depth += 1; - } + let mut current_link = &mut self.root; + let mut depth = 0; + let dim = self.dim; + + while let Some(ref mut node_box) = current_link { + let axis = depth % dim; + let current_node = node_box.as_mut(); + + let va = new_vector.vector[axis]; + let vb = current_node.indexed_vector.vector[axis]; + + if va <= vb { + current_link = &mut current_node.left; } else { - if current_node.right.is_none() { - current_node.right = Some(Box::new(KDTreeNode::new(data, current_dimension))); - break; - } else { - current_node = current_node.right.as_deref_mut().unwrap(); - current_depth += 1; - } + current_link = &mut current_node.right; } + depth += 1; } + + // Assign the new node to current link which is &mut Option> + let axis = depth % dim; + *current_link = Some(Box::new(KDTreeNode { + indexed_vector: new_vector, + split_dim: axis, + left: None, + right: None, + is_deleted: false, + })) } - // rebuild tree - fn rebuild(&mut self) { - self._internals.kd_tree_allow_update = false; - self._internals.rebuild_counter += 1; - if self.is_debug_run { - println!( - "Rebuilding tree..., Rebuild counter: {:?}", - self._internals.rebuild_counter + // Deletes the point by first finding the corresponding node using DFS and then deleting + // Returns true if point found and deleted, else false + // First make a lookup of vector from map, then traverse the tree to obtain the point and mark it as deleted + pub fn delete_point(&mut self, point_id: PointId) -> bool { + if let Some(vector_to_delete) = self.point_map.get(&point_id) { + let found_and_deleted = Self::find_and_mark_recursive( + &mut self.root, + vector_to_delete, + point_id, + 0, + self.dim, ); + + if found_and_deleted { + self.point_map.remove(&point_id); + } + + return found_and_deleted; } - let mut points = Vec::into_boxed_slice(self.traversal(0)); - self._root = Some(Box::new(create_tree_helper(points.as_mut(), 0))); - self._internals.kd_tree_allow_update = true; + false } - // traversal - pub fn traversal(&self, k_value: usize) -> Vec<(String, Vec)> { - let mut result: Vec<(String, Vec)> = Vec::new(); - inorder_traversal_helper(self._root.as_deref(), &mut result, k_value); - result - } + // Recursively finds and marks a node as deleted, + fn find_and_mark_recursive( + node_opt: &mut Option>, + target_vector: &DenseVector, + target_id: PointId, + depth: usize, + dim: usize, + ) -> bool { + if let Some(node) = node_opt { + if node.indexed_vector.id == target_id { + node.is_deleted = true; + return true; + } - // delete a node - pub fn delete_node(&mut self, data: String) { - self._internals.kd_tree_allow_update = false; - let mut points = self.traversal(0); - let index = points.iter().position(|x| *x.0 == data).unwrap(); - points.remove(index); - let mut points = Vec::into_boxed_slice(points); - self._root = Some(Box::new(create_tree_helper(points.as_mut(), 0))); - self._internals.kd_tree_allow_update = true; - } + let axis = depth % dim; + let target_val = target_vector[axis]; + let node_val = node.indexed_vector.vector[axis]; - // print data for debug - pub fn print_tree_for_debug(&self) { - let iterated: Vec<(String, Vec)> = self.traversal(0); - for iter in iterated { - println!("{}", iter.0); + if target_val < node_val { + Self::find_and_mark_recursive( + &mut node.left, + target_vector, + target_id, + depth + 1, + dim, + ) + } else if target_val > node_val { + Self::find_and_mark_recursive( + &mut node.right, + target_vector, + target_id, + depth + 1, + dim, + ) + } else { + // Need to check both right and left nodes in this case + let left_found = Self::find_and_mark_recursive( + &mut node.left, + target_vector, + target_id, + depth + 1, + dim, + ); + let right_found = Self::find_and_mark_recursive( + &mut node.right, + target_vector, + target_id, + depth + 1, + dim, + ); + left_found || right_found + } + } else { + false } } - // different methods of knn -} + pub fn search_top_k( + &self, + query_vector: DenseVector, + k: usize, + dist_type: Similarity, + ) -> Vec<(PointId, f32)> { + //Searches for top k closest vectors according to specified metric -// Traversal helper function -fn inorder_traversal_helper( - node: Option<&KDTreeNode>, - result: &mut Vec<(String, Vec)>, - k_value: usize, -) -> Option { - if node.is_none() { - return None; - } - if k_value != 0 && k_value <= result.len() { - return None; - } - let current_node = node.unwrap(); - inorder_traversal_helper(current_node.to_owned().left.as_deref(), result, k_value); - result.push((current_node.key.clone(), current_node.vector.clone())); - inorder_traversal_helper(current_node.to_owned().right.as_deref(), result, k_value); + if self.root.is_none() || k == 0 { + return Vec::new(); + } - Some(true) -} + let mut best_neighbours = BinaryHeap::with_capacity(k); -// Rebuild tree helper functions -fn create_tree_helper(points: &mut [(String, Vec)], dim: usize) -> KDTreeNode { - let points_len = points.len(); - if points_len == 1 { - return KDTreeNode { - key: points[0].0.clone(), - vector: points[0].1.clone(), - left: None, - right: None, - dim, - }; + self.search_recursive( + &self.root, + &query_vector, + k, + &mut best_neighbours, + 0, + dist_type, + ); + + best_neighbours + .into_sorted_vec() + .iter() + .map(|neighbor| (neighbor.id, neighbor.distance)) + .collect() } - // Split around the median - let pivot = quickselect_by(points, points_len / 2, &|a, b| { - a.1[dim].partial_cmp(&b.1[dim]).unwrap() - }); - - let left = Some(Box::new(create_tree_helper( - &mut points[0..points_len / 2], - (dim + 1) % pivot.1.len(), - ))); - let right = if points.len() >= 3 { - Some(Box::new(create_tree_helper( - &mut points[points_len / 2 + 1..points_len], - (dim + 1) % pivot.1.len(), - ))) - } else { - None - }; - - KDTreeNode { - key: pivot.0, - vector: pivot.1, - left, - right, - dim, + fn search_recursive( + &self, + node_opt: &Option>, + query_vector: &DenseVector, + k: usize, + heap: &mut BinaryHeap, + depth: usize, + dist_type: Similarity, + ) { + // Base case is that we hit a leaf node don't do anything + if let Some(node) = node_opt { + let axis = depth % self.dim; + + let (near_side, far_side) = if query_vector[axis] <= node.indexed_vector.vector[axis] { + (&node.left, &node.right) + } else { + (&node.right, &node.left) + }; + + // Recurse on near side first + self.search_recursive(&near_side, query_vector, k, heap, depth + 1, dist_type); + + // Process the current node + if !node.is_deleted { + //TODO: Use square distance in distance, why is there overhead of square + let distance = distance(query_vector, &node.indexed_vector.vector, dist_type); + if heap.len() < k { + heap.push(Neighbor { + id: node.indexed_vector.id, + distance, + }); + } else if distance < heap.peek().unwrap().distance { + heap.pop(); + heap.push(Neighbor { + id: node.indexed_vector.id, + distance, + }); + } + } + + // Pruning on the farther side to check if there are better candidates + //TODO: Change this when implementing square distance + let dist_to_plane = match dist_type { + Similarity::Euclidean => query_vector[axis] - node.indexed_vector.vector[axis], + Similarity::Manhattan => 1.0, + _ => unreachable!(), + }; + + if heap.len() < k || dist_to_plane < heap.peek().unwrap().distance { + self.search_recursive(far_side, query_vector, k, heap, depth + 1, dist_type); + } + } } } -fn quickselect_by(arr: &mut [T], position: usize, cmp: &dyn Fn(&T, &T) -> Ordering) -> T -where - T: Clone, -{ - let mut pivot_index = 0; - // Need to wrap in another closure or we get ownership complaints. - // Tried using an unboxed closure to get around this but couldn't get it to work. - pivot_index = partition_by(arr, pivot_index, &|a: &T, b: &T| cmp(a, b)); - let array_len = arr.len(); - match position.cmp(&pivot_index) { - Ordering::Equal => arr[position].clone(), - Ordering::Less => quickselect_by(&mut arr[0..pivot_index], position, cmp), - Ordering::Greater => quickselect_by( - &mut arr[pivot_index + 1..array_len], - position - pivot_index - 1, - cmp, - ), +impl VectorIndex for KDTree { + fn insert(&mut self, vector: IndexedVector) -> Result<(), DbError> { + self.insert_point(vector); + Ok(()) + } + + fn delete(&mut self, point_id: PointId) -> Result { + Ok(self.delete_point(point_id)) } -} -fn partition_by(arr: &mut [T], pivot_index: usize, cmp: &dyn Fn(&T, &T) -> Ordering) -> usize { - let array_len = arr.len(); - arr.swap(pivot_index, array_len - 1); - let mut store_index = 0; - for i in 0..array_len - 1 { - if cmp(&arr[i], &arr[array_len - 1]) == Less { - arr.swap(i, store_index); - store_index += 1; + fn search( + &self, + query_vector: DenseVector, + similarity: Similarity, + k: usize, + ) -> Result, DbError> { + if matches!(similarity, Similarity::Cosine | Similarity::Hamming) { + panic!("Cosine and hamming are not suitable similariyt metric when using a KDTree") } + + Ok(vec![]) } - arr.swap(array_len - 1, store_index); - store_index } diff --git a/crates/index/src/lib.rs b/crates/index/src/lib.rs index ef93755..eac20f0 100644 --- a/crates/index/src/lib.rs +++ b/crates/index/src/lib.rs @@ -1,6 +1,7 @@ use defs::{DbError, DenseVector, IndexedVector, PointId, Similarity}; pub mod flat; +pub mod kd_tree; pub trait VectorIndex { fn insert(&mut self, vector: IndexedVector) -> Result<(), DbError>; @@ -19,7 +20,7 @@ pub trait VectorIndex { } /// Distance function to get the distance between two vectors (taken from old version) -pub fn distance(a: DenseVector, b: DenseVector, dist_type: Similarity) -> f32 { +pub fn distance(a: &DenseVector, b: &DenseVector, dist_type: Similarity) -> f32 { assert_eq!(a.len(), b.len()); match dist_type { Similarity::Euclidean => {