diff --git a/Cargo.lock b/Cargo.lock index ff31d5a..a405d35 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,21 @@ dependencies = [ "memchr", ] +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "atty" version = "0.2.14" @@ -71,9 +86,24 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.5.0" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + +[[package]] +name = "bumpalo" +version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "byteorder" @@ -102,6 +132,20 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chrono" +version = "0.4.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "js-sys", + "num-traits", + "wasm-bindgen", + 
"windows-targets", +] + [[package]] name = "clang-sys" version = "1.8.2" @@ -137,6 +181,21 @@ dependencies = [ "os_str_bytes", ] +[[package]] +name = "core-foundation-sys" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" + +[[package]] +name = "cpufeatures" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" +dependencies = [ + "libc", +] + [[package]] name = "crc32fast" version = "1.4.2" @@ -146,6 +205,27 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", + "subtle", +] + [[package]] name = "either" version = "1.12.0" @@ -194,6 +274,16 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.2.15" @@ -226,6 +316,21 @@ dependencies = [ "libc", ] +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "home" version = "0.5.9" @@ -241,6 +346,29 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" +[[package]] +name = "iana-time-zone" +version = "0.1.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + [[package]] name = "idna" version = "0.5.0" @@ -267,6 +395,15 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +[[package]] +name = "js-sys" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +dependencies = [ + "wasm-bindgen", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -338,6 +475,15 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + [[package]] name = "once_cell" version = "1.19.0" @@ -436,7 +582,7 @@ version = "0.38.34" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" dependencies = [ - "bitflags 
2.5.0", + "bitflags 2.6.0", "errno", "libc", "linux-raw-sys", @@ -511,6 +657,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha2" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "shlex" version = "1.3.0" @@ -551,9 +708,14 @@ dependencies = [ name = "sqlite-rembed" version = "0.0.1-alpha.9" dependencies = [ + "chrono", + "hex", + "hmac", "serde_json", + "sha2", "sqlite-loadable", "ureq", + "url", "zerocopy", ] @@ -631,6 +793,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + [[package]] name = "unicode-bidi" version = "0.3.15" @@ -679,21 +847,81 @@ dependencies = [ [[package]] name = "url" -version = "2.5.0" +version = "2.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" +checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" dependencies = [ "form_urlencoded", "idna", "percent-encoding", ] +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasm-bindgen" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +dependencies = [ + "cfg-if", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn 2.0.66", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" + [[package]] name = "webpki-roots" version = "0.26.1" @@ -746,6 +974,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets", +] + [[package]] name = "windows-sys" version = "0.52.0" diff --git a/Cargo.toml b/Cargo.toml index 5d0bacb..f3a94b4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,11 @@ serde_json = "1.0.117" sqlite-loadable = "0.0.6-alpha.6" ureq = 
{version="2.9.7", features=["json"]} zerocopy = "0.7.34" +chrono = "0.4.38" +hex = "0.4.3" +hmac = "0.12.1" +sha2 = "0.10.8" +url = "2.5.2" [lib] crate-type=["cdylib", "staticlib", "lib"] diff --git a/README.md b/README.md index d59a4fc..24f16be 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # `sqlite-rembed` -A SQLite extension for generating text embeddings from remote APIs (OpenAI, Nomic, Cohere, llamafile, Ollama, etc.). A sister project to [`sqlite-vec`](https://github.com/asg017/sqlite-vec) and [`sqlite-lembed`](https://github.com/asg017/sqlite-lembed). A work-in-progress! +A SQLite extension for generating text embeddings from remote APIs (OpenAI, Nomic, Cohere, llamafile, Ollama, Amazon Bedrock, etc.). A sister project to [`sqlite-vec`](https://github.com/asg017/sqlite-vec) and [`sqlite-lembed`](https://github.com/asg017/sqlite-lembed). A work-in-progress! ## Usage @@ -31,8 +31,9 @@ Other pre-defined clients include: | `mixedbread` | [MixedBread](https://www.mixedbread.ai/api-reference#quick-start-guide) | `https://api.mixedbread.ai/v1/embeddings/` | `MIXEDBREAD_API_KEY` | | `llamafile` | [llamafile](https://github.com/Mozilla-Ocho/llamafile) | `http://localhost:8080/embedding` | None | | `ollama` | [Ollama](https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings) | `http://localhost:11434/api/embeddings` | None | +| `bedrock` | [Amazon Bedrock](https://aws.amazon.com/bedrock/) | `https://bedrock-runtime.REGION.amazonaws.com` | Use [temporary AWS Credentials](https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp.html) | -Different client options can be specified with `remebed_client_options()`. For example, if you have a different OpenAI-compatible service you want to use, then you can use: +Different client options can be specified with `rembed_client_options()`. 
For example, if you have a different OpenAI-compatible service you want to use, then you can use:
 
 ```sql
 INSERT INTO temp.rembed_clients(name, options) VALUES
diff --git a/src/clients.rs b/src/clients.rs
index 5f83b9a..ab30aa1 100644
--- a/src/clients.rs
+++ b/src/clients.rs
@@ -1,5 +1,10 @@
 use sqlite_loadable::{Error, Result};
 
+use sha2::{Sha256, Digest};
+use hmac::{Hmac, Mac};
+
+type HmacSha256 = Hmac<Sha256>;
+
 pub(crate) fn try_env_var(key: &str) -> Result<String> {
     std::env::var(key)
         .map_err(|_| Error::new_message(format!("{} environment variable not define. Alternatively, pass in an API key with rembed_client_options", DEFAULT_OPENAI_API_KEY_ENV)))
@@ -504,6 +509,319 @@ impl LlamafileClient {
     }
 }
 
+/* AWS SigV4 */
+
+/// Computes the HMAC-SHA256 of `msg` keyed with `key` and returns the raw MAC bytes.
+pub(crate) fn sign(key: &[u8], msg: &[u8]) -> Vec<u8> {
+    // HMAC is defined for keys of any length, so key setup is not expected to fail.
+    let mut mac = HmacSha256::new_from_slice(key).expect("Error when signing message");
+    mac.update(msg);
+    mac.finalize().into_bytes().to_vec()
+}
+
+/// Derives a signing key by performing a succession of keyed hash operations (HMAC operations)
+/// on the request date, Region, and service, with the AWS secret access key as the key for the
+/// initial hashing operation.
+///
+/// Each step feeds the previous MAC in as the key of the next one:
+/// `"AWS4" + key` -> `date` -> `region` -> `service` -> `"aws4_request"`.
+pub(crate) fn derive_signing_key(key: &str, date: &str, region: &str, service: &str) -> Vec<u8> {
+    let k_date = sign(format!("AWS4{key}").as_bytes(), date.as_bytes());
+    let k_region = sign(&k_date, region.as_bytes());
+    let k_service = sign(&k_region, service.as_bytes());
+    sign(&k_service, b"aws4_request")
+}
+
+/// Calculates the signature by performing a keyed hash operation on the string to sign.
+///
+/// Returns the MAC as a lowercase hexadecimal string.
+pub(crate) fn calculate_signature(signing_key: &[u8], string_to_sign: &str) -> String {
+    hex::encode(sign(signing_key, string_to_sign.as_bytes()))
+}
+
+/// Arranges the contents of the request (host, action, headers, &c.) into a standard canonical format.
+pub(crate) fn create_canonical_request(
+    http_verb: &str,
+    canonical_uri: &str,
+    canonical_query_string: &str,
+    canonical_headers: &[String],
+    signed_headers: &[&str],
+    payload: &str,
+) -> String {
+    // Lowercase hex SHA-256 digest of the request body.
+    let payload_digest = Sha256::digest(payload.as_bytes());
+
+    // One component per line; the empty entry produces the blank line that
+    // separates the header block from the signed-header list.
+    [
+        http_verb.to_owned(),
+        canonical_uri.to_owned(),
+        canonical_query_string.to_owned(),
+        canonical_headers.join("\n"),
+        String::new(),
+        signed_headers.join(";"),
+        format!("{payload_digest:x}"),
+    ]
+    .join("\n")
+}
+
+/// Creates a string to sign with the canonical request and extra information such as the algorithm,
+/// request date, credential scope, and the hash of the canonical request.
+pub(crate) fn create_string_to_sign(
+    algorithm: &str,
+    timestamp: &str,
+    credential_scope: &str,
+    canonical_request: &str,
+) -> String {
+    // The canonical request itself is not embedded, only its lowercase hex SHA-256 digest.
+    let request_digest = Sha256::digest(canonical_request.as_bytes());
+    let request_digest_hex = format!("{request_digest:x}");
+
+    [algorithm, timestamp, credential_scope, request_digest_hex.as_str()].join("\n")
+}
+
+/// Creates the Authorization header for the request.
+pub(crate) fn create_authorization_header(algorithm: &str, credential: &str, scope: &str, signed_headers: &[&str], signature: &str) -> String {
+    let signed_headers = signed_headers.join(";");
+    format!("{algorithm} Credential={credential}/{scope}, SignedHeaders={signed_headers}, Signature={signature}")
+}
+
+/* Amazon Bedrock */
+
+/// Embeddings client for Amazon Bedrock; every request is signed with AWS SigV4.
+#[derive(Clone)]
+pub struct AmazonBedrockClient {
+    /// Bedrock model id. The text before the first `.` selects the provider-specific
+    /// request and response format.
+    model_id: String,
+    /// AWS region used for the endpoint host and the credential scope.
+    region: String,
+    aws_access_key_id: String,
+    aws_secret_access_key: String,
+    /// Session token sent as `X-Amz-Security-Token`; empty when no token is used.
+    aws_session_token: String,
+}
+
+const HASH_ALGORITHM: &str = "AWS4-HMAC-SHA256";
+const DEFAULT_AWS_REGION: &str = "us-east-1";
+
+/// Recursively merges `b` into `a`.
+///
+/// When both sides are JSON objects the keys of `b` are merged into `a` one by one;
+/// in every other case `a` is overwritten with `b`.
+fn merge(a: &mut serde_json::Value, b: serde_json::Value) {
+    match (a, b) {
+        (a @ &mut serde_json::Value::Object(_), serde_json::Value::Object(b)) => {
+            let a = a.as_object_mut().unwrap();
+            for (k, v) in b {
+                merge(a.entry(k).or_insert(serde_json::Value::Null), v);
+            }
+        }
+        (a, b) => *a = b,
+    }
+}
+
+impl AmazonBedrockClient {
+    /// Creates a client for `model_id`.
+    ///
+    /// `region` defaults to `us-east-1`. Each credential that is `None` is read from the
+    /// matching environment variable (`AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`,
+    /// `AWS_SESSION_TOKEN`) and is left empty when that variable is not set.
+    pub fn new<S: Into<String>>(
+        model_id: S,
+        region: Option<String>,
+        aws_access_key_id: Option<String>,
+        aws_secret_access_key: Option<String>,
+        aws_session_token: Option<String>,
+    ) -> Result<Self> {
+        // Only consult the environment when the option was not passed explicitly.
+        let from_env = |key: &str| std::env::var(key).unwrap_or_default();
+
+        Ok(Self {
+            model_id: model_id.into(),
+            region: region.unwrap_or_else(|| DEFAULT_AWS_REGION.to_owned()),
+            aws_access_key_id: aws_access_key_id
+                .unwrap_or_else(|| from_env("AWS_ACCESS_KEY_ID")),
+            aws_secret_access_key: aws_secret_access_key
+                .unwrap_or_else(|| from_env("AWS_SECRET_ACCESS_KEY")),
+            aws_session_token: aws_session_token
+                .unwrap_or_else(|| from_env("AWS_SESSION_TOKEN")),
+        })
+    }
+
+    /// Returns the provider prefix of the model id, i.e. the text before the first `.`.
+    fn model_provider(&self) -> &str {
+        // `split` always yields at least one item, so the fallback is never used.
+        self.model_id.split('.').next().unwrap_or("")
+    }
+
+    /// Converts a JSON array of numbers into a vector of `f32`.
+    ///
+    /// `path` is only used to name the offending location in error messages.
+    fn parse_float_array(value: &serde_json::Value, path: &str) -> Result<Vec<f32>> {
+        value
+            .as_array()
+            .ok_or_else(|| Error::new_message(format!("expected '{path}' path to be an array")))?
+            .iter()
+            .map(|v| {
+                v.as_f64()
+                    // Narrowing is intentional: the function returns `f32` values.
+                    .map(|f| f as f32)
+                    .ok_or_else(|| {
+                        Error::new_message(format!("expected '{path}' array to contain floats"))
+                    })
+            })
+            .collect()
+    }
+
+    /// Requests the embedding of `input` from Bedrock and returns it as a vector of floats.
+    ///
+    /// `inference_options`, when given, is merged into the provider-specific request body.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error when the model provider is not supported, when the access key id or
+    /// secret access key is empty, when the HTTP request fails, or when the response is not
+    /// JSON of the expected shape.
+    pub fn infer_single(&self, input: &str, inference_options: Option<serde_json::Value>) -> Result<Vec<f32>> {
+
+        // Step 0a: create the provider-specific payload. Unsupported providers are rejected
+        // here so that no signed request with an empty body is sent for them.
+
+        let mut body = match self.model_provider() {
+            "amazon" => ureq::json!({
+                "inputText": input,
+            }),
+            "cohere" => ureq::json!({
+                "texts": [
+                    input
+                ],
+                "input_type": "search_document"
+            }),
+            other => {
+                return Err(Error::new_message(
+                    format!("Model provider '{other}' is not supported!")
+                ))
+            }
+        };
+
+        if let Some(options) = inference_options {
+            merge(&mut body, options);
+        }
+
+        // Serialize once so the bytes that are hashed are exactly the bytes that are sent.
+        let payload = body.to_string();
+
+        // Step 0b: fail early with a clear message instead of an opaque HTTP error.
+
+        if self.aws_access_key_id.is_empty() || self.aws_secret_access_key.is_empty() {
+            return Err(Error::new_message(
+                "AWS credentials not defined. Set the AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables or pass aws_access_key_id and aws_secret_access_key with rembed_client_options"
+            ));
+        }
+
+        // Step 0c: get date and time
+
+        let current_time = chrono::Utc::now();
+        let amazon_time = current_time.format("%Y%m%dT%H%M%SZ").to_string();
+        let amazon_date = current_time.format("%Y%m%d").to_string();
+
+        // Step 1: create a canonical request
+
+        // Fix: escape 'bad' characters (*, /, :) in model_id
+        // https://github.com/curl/curl/issues/11794
+        // The request URL carries the model id encoded once; the canonical URI that is
+        // signed carries it encoded a second time.
+        let model_id: String = url::form_urlencoded::byte_serialize(self.model_id.as_bytes()).collect();
+        let escaped_model_id: String = url::form_urlencoded::byte_serialize(model_id.as_bytes()).collect();
+
+        let canonical_uri = format!("/model/{}/invoke", escaped_model_id);
+        let canonical_query_string = "";
+        let service_endpoint = format!("bedrock-runtime.{}.amazonaws.com", self.region);
+        let endpoint = format!("https://{service_endpoint}/model/{}/invoke", model_id);
+
+        let has_session_token = !self.aws_session_token.is_empty();
+
+        let mut signed_headers = vec![
+            "host",
+            "x-amz-date"
+        ];
+
+        let mut canonical_headers = vec![
+            format!("host:{service_endpoint}"),
+            format!("x-amz-date:{amazon_time}")
+        ];
+
+        if has_session_token {
+            signed_headers.push("x-amz-security-token");
+            canonical_headers.push(
+                format!("x-amz-security-token:{}", self.aws_session_token)
+            );
+        }
+
+        let canonical_request = create_canonical_request(
+            "POST",
+            &canonical_uri,
+            canonical_query_string,
+            &canonical_headers,
+            &signed_headers,
+            &payload
+        );
+
+        // Step 2: create string to sign
+
+        let service = "bedrock";
+        let credential_scope = format!("{amazon_date}/{}/{service}/aws4_request", self.region);
+
+        let string_to_sign = create_string_to_sign(
+            HASH_ALGORITHM,
+            &amazon_time,
+            &credential_scope,
+            &canonical_request
+        );
+
+        // Step 3: calculate signature
+
+        let signing_key = derive_signing_key(
+            &self.aws_secret_access_key,
+            &amazon_date,
+            &self.region,
+            service
+        );
+
+        let signature = calculate_signature(
+            &signing_key,
+            &string_to_sign
+        );
+
+        // Step 4: add the signature to the request
+
+        let authorization = create_authorization_header(
+            HASH_ALGORITHM,
+            &self.aws_access_key_id,
+            &credential_scope,
+            &signed_headers,
+            &signature
+        );
+
+        // Step 5: send the request
+
+        let mut request = ureq::post(&endpoint)
+            .set("Accept", "application/json")
+            .set("X-Amz-Date", &amazon_time)
+            .set("Authorization", &authorization);
+
+        if has_session_token {
+            request = request.set("X-Amz-Security-Token", &self.aws_session_token);
+        }
+
+        let response = request
+            .send_bytes(payload.as_bytes())
+            .map_err(|error| {
+                Error::new_message(format!("Error sending HTTP request: {error}"))
+            })?
+            .into_string()
+            .map_err(|error| {
+                Error::new_message(format!("Error parsing HTTP response: {error}"))
+            })?;
+
+        // A non-JSON response must surface as an error, not as a panic inside the extension.
+        let data: serde_json::Value = serde_json::from_str(&response)
+            .map_err(|error| {
+                Error::new_message(format!("Error parsing HTTP response: {error}"))
+            })?;
+
+        self.parse_single_response(&data)
+    }
+
+    /// Extracts the embedding from a Bedrock response body.
+    ///
+    /// `amazon` models are read from the `embedding` array, `cohere` models from the first
+    /// entry of the `embeddings` array. Any other provider yields an error.
+    pub fn parse_single_response(&self, value: &serde_json::Value) -> Result<Vec<f32>> {
+        match self.model_provider() {
+            "amazon" => {
+                let embedding = value
+                    .get("embedding")
+                    .ok_or_else(|| Error::new_message("expected 'embedding' key in response body"))?;
+                Self::parse_float_array(embedding, "embedding")
+            }
+            "cohere" => {
+                let first = value
+                    .get("embeddings")
+                    .ok_or_else(|| Error::new_message("expected 'embeddings' key in response body"))?
+                    .as_array()
+                    .ok_or_else(|| Error::new_message("expected 'embeddings' path to be an array"))?
+                    .first()
+                    .ok_or_else(|| Error::new_message("expected 'embeddings.0' path in response body"))?;
+                Self::parse_float_array(first, "embeddings.0")
+            }
+            other => Err(
+                Error::new_message(
+                    format!("Model provider '{other}' is not supported!")
+                )
+            ),
+        }
+    }
+}
+
 #[derive(Clone)]
 pub enum Client {
     OpenAI(OpenAiClient),
@@ -513,4 +831,5 @@ pub enum Client {
     Llamafile(LlamafileClient),
     Jina(JinaClient),
     Mixedbread(MixedbreadClient),
+    AmazonBedrock(AmazonBedrockClient),
 }
diff --git a/src/clients_vtab.rs b/src/clients_vtab.rs
index 101c95c..d38ea00 100644
--- a/src/clients_vtab.rs
+++ b/src/clients_vtab.rs
@@ -10,7 +10,7 @@ use std::{cell::RefCell, collections::HashMap, marker::PhantomData, mem, os::raw
 use crate::clients::MixedbreadClient;
 use crate::{
     clients::{
-        Client, CohereClient, JinaClient, LlamafileClient, NomicClient, OllamaClient, OpenAiClient,
+        Client, CohereClient, JinaClient, LlamafileClient, NomicClient, OllamaClient, OpenAiClient, AmazonBedrockClient,
     },
     CLIENT_OPTIONS_POINTER_NAME,
 };
@@ -99,6 +99,7 @@ impl<'vtab> VTabWriteable<'vtab> for ClientsTable {
                     "cohere" => Client::Cohere(CohereClient::new(name, None, None)?),
                     "ollama" => Client::Ollama(OllamaClient::new(name, None)),
                     "llamafile" => Client::Llamafile(LlamafileClient::new(None)),
+                    "bedrock" => Client::AmazonBedrock(AmazonBedrockClient::new(name, None, None, None, None)?),
                     text => {
                         return Err(Error::new_message(format!(
                             "'{text}' is not a valid rembed client."
diff --git a/src/lib.rs b/src/lib.rs index 1924525..7d2265b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,7 +5,7 @@ use std::cell::RefCell; use std::collections::HashMap; use std::rc::Rc; -use clients::{Client, CohereClient, LlamafileClient, NomicClient, OllamaClient, OpenAiClient}; +use clients::{Client, CohereClient, LlamafileClient, NomicClient, OllamaClient, OpenAiClient, AmazonBedrockClient}; use clients_vtab::ClientsTable; use sqlite_loadable::{ api, define_scalar_function, define_scalar_function_with_aux, define_virtual_table_writeablex, @@ -90,7 +90,18 @@ pub fn rembed_client_options( .ok_or_else(|| Error::new_message("'model' option is required"))?, options.get("url").cloned(), )), - "llamafile" => Client::Llamafile(LlamafileClient::new(options.get("url").cloned())), + "llamafile" => Client::Llamafile(LlamafileClient::new( + options.get("url").cloned()) + ), + "bedrock" => Client::AmazonBedrock(AmazonBedrockClient::new( + options + .get("model_id") + .ok_or_else(|| Error::new_message("'model_id' option is required"))?, + options.get("region").cloned(), + options.get("aws_access_key_id").cloned(), + options.get("aws_secret_access_key").cloned(), + options.get("aws_session_token").cloned(), + )?), format => return Err(Error::new_message(format!("Unknown format '{format}'"))), }; @@ -118,6 +129,10 @@ pub fn rembed( Client::Mixedbread(client) => client.infer_single(input)?, Client::Ollama(client) => client.infer_single(input)?, Client::Llamafile(client) => client.infer_single(input)?, + Client::AmazonBedrock(client) => { + let inference_options = values.get(2).and_then(|v| api::value_json(v).ok()); + client.infer_single(input, inference_options)? + } Client::Nomic(client) => { let input_type = values.get(2).and_then(|v| api::value_text(v).ok()); client.infer_single(input, input_type)?