From c818dd2f418415a845300e8e871301ec475cdebf Mon Sep 17 00:00:00 2001 From: JGalego Date: Tue, 6 Aug 2024 03:43:31 +0100 Subject: [PATCH 01/11] Added support for Amazon Bedrock --- Cargo.lock | 242 ++++++++++++++++++++++++++++++++++++++++++- Cargo.toml | 4 + README.md | 5 +- src/clients.rs | 244 ++++++++++++++++++++++++++++++++++++++++++++ src/clients_vtab.rs | 3 +- src/lib.rs | 16 ++- 6 files changed, 506 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ff31d5a..ad6dba7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,21 @@ dependencies = [ "memchr", ] +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "atty" version = "0.2.14" @@ -71,9 +86,24 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.5.0" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + +[[package]] +name = "bumpalo" +version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "byteorder" @@ -102,6 +132,20 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chrono" +version = "0.4.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "js-sys", + "num-traits", + "wasm-bindgen", + "windows-targets", +] + [[package]] name = "clang-sys" version = "1.8.2" @@ -137,6 +181,21 @@ dependencies = [ "os_str_bytes", ] +[[package]] +name = "core-foundation-sys" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" + +[[package]] +name = "cpufeatures" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" +dependencies = [ + "libc", +] + [[package]] name = "crc32fast" version = "1.4.2" @@ -146,6 +205,27 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", + "subtle", +] + [[package]] name = "either" version = "1.12.0" @@ -194,6 +274,16 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.2.15" @@ -226,6 +316,21 @@ dependencies = [ "libc", ] +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "home" version = "0.5.9" @@ -241,6 +346,29 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" +[[package]] +name = "iana-time-zone" +version = "0.1.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + [[package]] name = "idna" version = "0.5.0" @@ -267,6 +395,15 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +[[package]] +name = "js-sys" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +dependencies = [ + "wasm-bindgen", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -338,6 +475,15 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + [[package]] name = "once_cell" version = "1.19.0" @@ -436,7 +582,7 @@ version = "0.38.34" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "errno", "libc", "linux-raw-sys", @@ -511,6 +657,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha2" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "shlex" version = "1.3.0" @@ -551,7 +708,11 @@ dependencies = [ name = "sqlite-rembed" version = "0.0.1-alpha.9" dependencies = [ + "chrono", + "hex", + "hmac", "serde_json", + "sha2", "sqlite-loadable", "ureq", "zerocopy", @@ -631,6 +792,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + [[package]] name = "unicode-bidi" version = "0.3.15" @@ -688,12 +855,72 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasm-bindgen" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +dependencies = [ + "cfg-if", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn 2.0.66", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" + [[package]] name = "webpki-roots" version = "0.26.1" @@ -746,6 +973,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets", +] + [[package]] name = "windows-sys" version = "0.52.0" diff --git a/Cargo.toml b/Cargo.toml index 5d0bacb..63efe57 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,10 @@ serde_json = "1.0.117" sqlite-loadable = "0.0.6-alpha.6" ureq = {version="2.9.7", features=["json"]} zerocopy = "0.7.34" +chrono = "0.4.38" +hex = "0.4.3" +hmac = "0.12.1" +sha2 = "0.10.8" [lib] crate-type=["cdylib", "staticlib", "lib"] diff --git a/README.md b/README.md index d59a4fc..24f16be 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # `sqlite-rembed` -A SQLite extension for generating text embeddings from remote APIs (OpenAI, Nomic, Cohere, llamafile, Ollama, etc.). A sister project to [`sqlite-vec`](https://github.com/asg017/sqlite-vec) and [`sqlite-lembed`](https://github.com/asg017/sqlite-lembed). A work-in-progress! +A SQLite extension for generating text embeddings from remote APIs (OpenAI, Nomic, Cohere, llamafile, Ollama, Amazon Bedrock, etc.). A sister project to [`sqlite-vec`](https://github.com/asg017/sqlite-vec) and [`sqlite-lembed`](https://github.com/asg017/sqlite-lembed). A work-in-progress! ## Usage @@ -31,8 +31,9 @@ Other pre-defined clients include: | `mixedbread` | [MixedBread](https://www.mixedbread.ai/api-reference#quick-start-guide) | `https://api.mixedbread.ai/v1/embeddings/` | `MIXEDBREAD_API_KEY` | | `llamafile` | [llamafile](https://github.com/Mozilla-Ocho/llamafile) | `http://localhost:8080/embedding` | None | | `ollama` | [Ollama](https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings) | `http://localhost:11434/api/embeddings` | None | +| `bedrock` | [Amazon Bedrock](https://aws.amazon.com/bedrock/) | `https://bedrock-runtime.REGION.amazonaws.com` | Use [temporary AWS Credentials](https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp.html) | -Different client options can be specified with `remebed_client_options()`. For example, if you have a different OpenAI-compatible service you want to use, then you can use: +Different client options can be specified with `rembed_client_options()`. For example, if you have a different OpenAI-compatible service you want to use, then you can use: ```sql INSERT INTO temp.rembed_clients(name, options) VALUES diff --git a/src/clients.rs b/src/clients.rs index 5f83b9a..3e0f1f6 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -1,5 +1,10 @@ use sqlite_loadable::{Error, Result}; +use sha2::{Sha256, Digest}; +use hmac::{Hmac, Mac}; + +type HmacSha256 = Hmac; + pub(crate) fn try_env_var(key: &str) -> Result { std::env::var(key) .map_err(|_| Error::new_message(format!("{} environment variable not define. Alternatively, pass in an API key with rembed_client_options", DEFAULT_OPENAI_API_KEY_ENV))) @@ -504,6 +509,244 @@ impl LlamafileClient { } } +#[derive(Clone)] +pub struct AmazonBedrockClient { + model_id: String, + region: String, + aws_access_key_id: String, + aws_secret_access_key: String, + aws_session_token: String +} +const DEFAULT_AWS_REGION: &str = "us-east-1"; + +impl AmazonBedrockClient { + pub fn new>(model_id: S, region: Option, aws_access_key_id: Option, aws_secret_access_key: Option, aws_session_token: Option) -> Result { + Ok(Self { + model_id: model_id.into(), + region: region.unwrap_or(DEFAULT_AWS_REGION.to_owned()), + aws_access_key_id: aws_access_key_id.unwrap_or( + std::env::var("AWS_ACCESS_KEY_ID").unwrap() + ), + aws_secret_access_key: aws_secret_access_key.unwrap_or( + std::env::var("AWS_SECRET_ACCESS_KEY").unwrap() + ), + aws_session_token: aws_session_token.unwrap_or( + std::env::var("AWS_SESSION_TOKEN").unwrap_or_default() + ), + }) + } + + pub fn sign(&self, key: &[u8], msg: &[u8]) -> Vec { + let mut mac = HmacSha256::new_from_slice(key).unwrap(); + mac.update(msg); + let result = mac.finalize(); + result.into_bytes().to_vec() + } + + pub fn get_signing_key(&self, key: &str, date: &str, region: &str, service: &str) -> Vec { + let k_date = self.sign(format!("AWS4{key}").as_bytes(), date.as_bytes()); + let k_region = self.sign(&k_date, region.as_bytes()); + let k_service = self.sign(&k_region, service.as_bytes()); + self.sign(&k_service, "aws4_request".as_bytes()) + } + + pub fn get_signature(&self, signing_key: &[u8], string_to_sign: &str) -> String { + let signature = self.sign(signing_key, string_to_sign.as_bytes()); + hex::encode(signature) + } + + pub fn get_canonical_request(&self, http_verb: &str, canonical_uri: &str, canonical_query_string: &str, canonical_headers: &[String], signed_headers: &[&str], payload: &str) -> String { + let canonical_headers = canonical_headers.join("\n"); + let signed_headers = signed_headers.join(";"); + let mut hasher = Sha256::new(); + hasher.update(payload.as_bytes()); + let hashed_payload = hasher.finalize(); + format!("{http_verb}\n{canonical_uri}\n{canonical_query_string}\n{canonical_headers}\n\n{signed_headers}\n{hashed_payload:x}") + } + + pub fn get_string_to_sign(&self, algorithm: &str, timestamp: &str, credential_scope: &str, canonical_request: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(canonical_request.as_bytes()); + let canonical_request = hasher.finalize(); + format!("{algorithm}\n{timestamp}\n{credential_scope}\n{canonical_request:x}") + } + + pub fn get_authorization_header(&self, algorithm: &str, credential: &str, scope: &str, signed_headers: &[&str], signature: &str) -> String { + let signed_headers = signed_headers.join(";"); + format!("{algorithm} Credential={credential}/{scope}, SignedHeaders={signed_headers}, Signature={signature}") + } + + pub fn infer_single(&self, input: &str) -> Result> { + // Get model provider + let model_provider = self.model_id + .split('.') + .next() + .unwrap(); + + // Create payload + let body = match model_provider { + "amazon" => ureq::json!({ + "inputText": input.to_owned(), + }), + "cohere" => ureq::json!({ + "texts": [input.to_owned()], + "input_type": "search_document", + "truncate": "NONE" + }), + _ => ureq::json!({}) + }; + + // Get date and time + let current_time = chrono::Utc::now(); + let amazon_time = current_time.format("%Y%m%dT%H%M%SZ").to_string(); + let amazon_date = current_time.format("%Y%m%d").to_string(); + + // Step 1: create a canonical request + let canonical_uri = format!("/model/{}/invoke", self.model_id); + let canonical_query_string = ""; + let service_endpoint: String = format!("bedrock-runtime.{}.amazonaws.com", self.region); + let endpoint: String = format!("https://{service_endpoint}/model/{}/invoke", self.model_id); + let canonical_request = self.get_canonical_request( + "POST", + &canonical_uri, + canonical_query_string, + &[ + format!("host:{service_endpoint}"), + format!("x-amz-date:{amazon_time}"), + format!("x-amz-security-token:{}", self.aws_session_token) + ], + &[ + "host", + "x-amz-date", + "x-amz-security-token" + ], + &body.to_string() + ); + + // Step 2: create string to sign + + let service = "bedrock"; + let credential_scope = format!("{amazon_date}/{}/{service}/aws4_request", self.region); + + let string_to_sign = self.get_string_to_sign( + "AWS4-HMAC-SHA256", + &amazon_time, + &credential_scope, + &canonical_request + ); + + // Step 3: calculate signature + + let signing_key = self.get_signing_key( + &self.aws_secret_access_key, + &amazon_date, + &self.region, + service + ); + + let signature = self.get_signature( + &signing_key, + &string_to_sign + ); + + // Step 4: add the signature to the request + + let authorization = self.get_authorization_header( + "AWS4-HMAC-SHA256", + &self.aws_access_key_id, + &credential_scope, + &[ + "host", + "x-amz-date", + "x-amz-security-token" + ], + &signature + ); + + // Step 5: send the request + + let response = ureq::post(&endpoint) + .set("Accept", "application/json") + .set("X-Amz-Date", &amazon_time) + .set("X-Amz-Security-Token", &self.aws_session_token) + .set("Authorization", &authorization) + .send_bytes( + body.to_string().as_bytes() + ) + .unwrap() + .into_string() + .unwrap(); + + let data: serde_json::Value = serde_json::from_str(&response).unwrap(); + + AmazonBedrockClient::parse_single_response(self, &data) + } + + pub fn parse_single_response(&self, value: &serde_json::Value) -> Result> { + + let model_provider = self.model_id.split('.').next().unwrap().to_string(); + + let output: Result>; + if model_provider == "amazon" { + output = value + .get("embedding") + .ok_or_else(|| Error::new_message("expected 'embedding' key in response body")) + .and_then(|v| { + v.as_array().ok_or_else(|| { + Error::new_message("expected 'embedding' path to be an array") + }) + }) + .and_then(|arr| { + arr.iter() + .map(|v| { + v.as_f64() + .ok_or_else(|| { + Error::new_message( + "expected 'embedding' array to contain floats", + ) + }) + .map(|f| f as f32) + }) + .collect() + }); + } else if model_provider == "cohere" { + output = value + .get("embeddings") + .ok_or_else(|| Error::new_message("expected 'embeddings' key in response body")) + .and_then(|v: &serde_json::Value| { + v.as_array().ok_or_else(|| { + Error::new_message("expected 'embeddings' path to be an array") + }) + }) + .and_then(|v| { + v.get(0) + .ok_or_else(|| Error::new_message("expected 'embeddings.0' path in response body")) + }) + .and_then(|v| { + v.as_array().ok_or_else(|| { + Error::new_message("expected 'embeddings.0' path to be an array") + }) + }) + .and_then(|arr| { + arr.iter() + .map(|v| { + v.as_f64() + .ok_or_else(|| { + Error::new_message( + "expected 'embeddings.0' array to contain floats", + ) + }) + .map(|f| f as f32) + }) + .collect() + }); + } else { + todo!(); + } + output + } +} + #[derive(Clone)] pub enum Client { OpenAI(OpenAiClient), @@ -513,4 +756,5 @@ pub enum Client { Llamafile(LlamafileClient), Jina(JinaClient), Mixedbread(MixedbreadClient), + AmazonBedrock(AmazonBedrockClient), } diff --git a/src/clients_vtab.rs b/src/clients_vtab.rs index 101c95c..c9644a6 100644 --- a/src/clients_vtab.rs +++ b/src/clients_vtab.rs @@ -10,7 +10,7 @@ use std::{cell::RefCell, collections::HashMap, marker::PhantomData, mem, os::raw use crate::clients::MixedbreadClient; use crate::{ clients::{ - Client, CohereClient, JinaClient, LlamafileClient, NomicClient, OllamaClient, OpenAiClient, + Client, CohereClient, JinaClient, LlamafileClient, NomicClient, OllamaClient, OpenAiClient, AmazonBedrockClient, }, CLIENT_OPTIONS_POINTER_NAME, }; @@ -99,6 +99,7 @@ impl<'vtab> VTabWriteable<'vtab> for ClientsTable { "cohere" => Client::Cohere(CohereClient::new(name, None, None)?), "ollama" => Client::Ollama(OllamaClient::new(name, None)), "llamafile" => Client::Llamafile(LlamafileClient::new(None)), + "bedrock" => Client::AmazonBedrock(AmazonBedrockClient::new(name, None, None, None, None)?), text => { return Err(Error::new_message(format!( "'{text}' is not a valid rembed client." diff --git a/src/lib.rs b/src/lib.rs index 1924525..9fc4ac3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,7 +5,7 @@ use std::cell::RefCell; use std::collections::HashMap; use std::rc::Rc; -use clients::{Client, CohereClient, LlamafileClient, NomicClient, OllamaClient, OpenAiClient}; +use clients::{Client, CohereClient, LlamafileClient, NomicClient, OllamaClient, OpenAiClient, AmazonBedrockClient}; use clients_vtab::ClientsTable; use sqlite_loadable::{ api, define_scalar_function, define_scalar_function_with_aux, define_virtual_table_writeablex, @@ -90,7 +90,18 @@ pub fn rembed_client_options( .ok_or_else(|| Error::new_message("'model' option is required"))?, options.get("url").cloned(), )), - "llamafile" => Client::Llamafile(LlamafileClient::new(options.get("url").cloned())), + "llamafile" => Client::Llamafile(LlamafileClient::new( + options.get("url").cloned()) + ), + "bedrock" => Client::AmazonBedrock(AmazonBedrockClient::new( + options + .get("model_id") + .ok_or_else(|| Error::new_message("'model_id' option is required"))?, + options.get("region").cloned(), + options.get("aws_access_key_id").cloned(), + options.get("aws_secret_access_key").cloned(), + options.get("aws_session_token").cloned(), + )?), format => return Err(Error::new_message(format!("Unknown format '{format}'"))), }; @@ -118,6 +129,7 @@ pub fn rembed( Client::Mixedbread(client) => client.infer_single(input)?, Client::Ollama(client) => client.infer_single(input)?, Client::Llamafile(client) => client.infer_single(input)?, + Client::AmazonBedrock(client) => client.infer_single(input)?, Client::Nomic(client) => { let input_type = values.get(2).and_then(|v| api::value_text(v).ok()); client.infer_single(input, input_type)? From 3c4b7eb6a23ee84338eb86a8966137166f146534 Mon Sep 17 00:00:00 2001 From: JGalego Date: Tue, 6 Aug 2024 10:45:52 +0100 Subject: [PATCH 02/11] Refactored to allow non-temp AWS creds --- src/clients.rs | 68 +++++++++++++++++++++++++++++++++----------------- 1 file changed, 45 insertions(+), 23 deletions(-) diff --git a/src/clients.rs b/src/clients.rs index 3e0f1f6..aa7e190 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -577,13 +577,16 @@ impl AmazonBedrockClient { } pub fn infer_single(&self, input: &str) -> Result> { - // Get model provider + + // Step 0a. extract model provider + let model_provider = self.model_id .split('.') .next() .unwrap(); - // Create payload + // Step 0b. create payload + let body = match model_provider { "amazon" => ureq::json!({ "inputText": input.to_owned(), @@ -596,40 +599,56 @@ impl AmazonBedrockClient { _ => ureq::json!({}) }; - // Get date and time + // Step 0c. get date and time + let current_time = chrono::Utc::now(); let amazon_time = current_time.format("%Y%m%dT%H%M%SZ").to_string(); let amazon_date = current_time.format("%Y%m%d").to_string(); // Step 1: create a canonical request + let canonical_uri = format!("/model/{}/invoke", self.model_id); let canonical_query_string = ""; let service_endpoint: String = format!("bedrock-runtime.{}.amazonaws.com", self.region); let endpoint: String = format!("https://{service_endpoint}/model/{}/invoke", self.model_id); + + let mut signed_headers = vec![ + "host", + "x-amz-date" + ]; + + if !self.aws_session_token.is_empty() { + signed_headers.push("x-amz-security-token"); + } + + let mut canonical_headers = vec![ + format!("host:{service_endpoint}"), + format!("x-amz-date:{amazon_time}") + ]; + + if !self.aws_session_token.is_empty() { + canonical_headers.push( + format!("x-amz-security-token:{}", self.aws_session_token) + ); + } + let canonical_request = self.get_canonical_request( "POST", &canonical_uri, canonical_query_string, - &[ - format!("host:{service_endpoint}"), - format!("x-amz-date:{amazon_time}"), - format!("x-amz-security-token:{}", self.aws_session_token) - ], - &[ - "host", - "x-amz-date", - "x-amz-security-token" - ], + &canonical_headers, + &signed_headers, &body.to_string() ); // Step 2: create string to sign + let algorithm = "AWS4-HMAC-SHA256"; let service = "bedrock"; let credential_scope = format!("{amazon_date}/{}/{service}/aws4_request", self.region); let string_to_sign = self.get_string_to_sign( - "AWS4-HMAC-SHA256", + algorithm, &amazon_time, &credential_scope, &canonical_request @@ -655,21 +674,24 @@ impl AmazonBedrockClient { "AWS4-HMAC-SHA256", &self.aws_access_key_id, &credential_scope, - &[ - "host", - "x-amz-date", - "x-amz-security-token" - ], + &signed_headers, &signature ); // Step 5: send the request - let response = ureq::post(&endpoint) + let request = ureq::post(&endpoint) .set("Accept", "application/json") .set("X-Amz-Date", &amazon_time) - .set("X-Amz-Security-Token", &self.aws_session_token) - .set("Authorization", &authorization) + .set("Authorization", &authorization); + + let request = if !self.aws_session_token.is_empty() { + request.clone() + } else { + request.clone().set("X-Amz-Security-Token", &self.aws_session_token) + }; + + let response = request.clone() .send_bytes( body.to_string().as_bytes() ) @@ -719,7 +741,7 @@ impl AmazonBedrockClient { }) }) .and_then(|v| { - v.get(0) + v.first() .ok_or_else(|| Error::new_message("expected 'embeddings.0' path in response body")) }) .and_then(|v| { From c6132db268384c189ee17b8b2a480e68fb0964dc Mon Sep 17 00:00:00 2001 From: JGalego Date: Tue, 6 Aug 2024 12:06:23 +0100 Subject: [PATCH 03/11] Added cohere options (input_type + truncate) --- src/clients.rs | 8 ++++---- src/lib.rs | 7 ++++++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/clients.rs b/src/clients.rs index aa7e190..b04591d 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -576,7 +576,7 @@ impl AmazonBedrockClient { format!("{algorithm} Credential={credential}/{scope}, SignedHeaders={signed_headers}, Signature={signature}") } - pub fn infer_single(&self, input: &str) -> Result> { + pub fn infer_single(&self, input: &str, input_type: Option<&str>, truncate: Option<&str>) -> Result> { // Step 0a. extract model provider @@ -593,8 +593,8 @@ impl AmazonBedrockClient { }), "cohere" => ureq::json!({ "texts": [input.to_owned()], - "input_type": "search_document", - "truncate": "NONE" + "input_type": input_type.unwrap_or("search_document"), + "truncate": truncate.unwrap_or("NONE") }), _ => ureq::json!({}) }; @@ -685,7 +685,7 @@ impl AmazonBedrockClient { .set("X-Amz-Date", &amazon_time) .set("Authorization", &authorization); - let request = if !self.aws_session_token.is_empty() { + let request = if self.aws_session_token.is_empty() { request.clone() } else { request.clone().set("X-Amz-Security-Token", &self.aws_session_token) diff --git a/src/lib.rs b/src/lib.rs index 9fc4ac3..b688222 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -129,7 +129,11 @@ pub fn rembed( Client::Mixedbread(client) => client.infer_single(input)?, Client::Ollama(client) => client.infer_single(input)?, Client::Llamafile(client) => client.infer_single(input)?, - Client::AmazonBedrock(client) => client.infer_single(input)?, + Client::AmazonBedrock(client) => { + let input_type = values.get(2).and_then(|v| api::value_text(v).ok()); + let truncate = values.get(3).and_then(|v| api::value_text(v).ok()); + client.infer_single(input, input_type, truncate)? + } Client::Nomic(client) => { let input_type = values.get(2).and_then(|v| api::value_text(v).ok()); client.infer_single(input, input_type)? @@ -169,6 +173,7 @@ pub fn sqlite3_rembed_init(db: *mut sqlite3) -> Result<()> { )?; define_scalar_function_with_aux(db, "rembed", 2, rembed, flags, Rc::clone(&c))?; define_scalar_function_with_aux(db, "rembed", 3, rembed, flags, Rc::clone(&c))?; + define_scalar_function_with_aux(db, "rembed", 4, rembed, flags, Rc::clone(&c))?; define_scalar_function( db, "rembed_client_options", From b0aed8724d95c97060f7c5ab70462c1425e3d30e Mon Sep 17 00:00:00 2001 From: JGalego Date: Tue, 6 Aug 2024 12:09:35 +0100 Subject: [PATCH 04/11] Fixed formatting --- src/clients.rs | 322 ++++++++++++++++++++++---------------------- src/clients_vtab.rs | 2 +- src/lib.rs | 10 +- 3 files changed, 167 insertions(+), 167 deletions(-) diff --git a/src/clients.rs b/src/clients.rs index b04591d..8a56d5f 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -513,9 +513,9 @@ impl LlamafileClient { pub struct AmazonBedrockClient { model_id: String, region: String, - aws_access_key_id: String, - aws_secret_access_key: String, - aws_session_token: String + aws_access_key_id: String, + aws_secret_access_key: String, + aws_session_token: String } const DEFAULT_AWS_REGION: &str = "us-east-1"; @@ -524,57 +524,57 @@ impl AmazonBedrockClient { Ok(Self { model_id: model_id.into(), region: region.unwrap_or(DEFAULT_AWS_REGION.to_owned()), - aws_access_key_id: aws_access_key_id.unwrap_or( - std::env::var("AWS_ACCESS_KEY_ID").unwrap() - ), - aws_secret_access_key: aws_secret_access_key.unwrap_or( - std::env::var("AWS_SECRET_ACCESS_KEY").unwrap() - ), - aws_session_token: aws_session_token.unwrap_or( - std::env::var("AWS_SESSION_TOKEN").unwrap_or_default() - ), + aws_access_key_id: aws_access_key_id.unwrap_or( + std::env::var("AWS_ACCESS_KEY_ID").unwrap() + ), + aws_secret_access_key: aws_secret_access_key.unwrap_or( + std::env::var("AWS_SECRET_ACCESS_KEY").unwrap() + ), + aws_session_token: aws_session_token.unwrap_or( + std::env::var("AWS_SESSION_TOKEN").unwrap_or_default() + ), }) } - pub fn sign(&self, key: &[u8], msg: &[u8]) -> Vec { - let mut mac = HmacSha256::new_from_slice(key).unwrap(); - mac.update(msg); - let result = mac.finalize(); - result.into_bytes().to_vec() - } - - pub fn get_signing_key(&self, key: &str, date: &str, region: &str, service: &str) -> Vec { - let k_date = self.sign(format!("AWS4{key}").as_bytes(), date.as_bytes()); - let k_region = self.sign(&k_date, region.as_bytes()); - let k_service = self.sign(&k_region, service.as_bytes()); - self.sign(&k_service, "aws4_request".as_bytes()) - } - - pub fn get_signature(&self, signing_key: &[u8], string_to_sign: &str) -> String { - let signature = self.sign(signing_key, string_to_sign.as_bytes()); - hex::encode(signature) - } - - pub fn get_canonical_request(&self, http_verb: &str, canonical_uri: &str, canonical_query_string: &str, canonical_headers: &[String], signed_headers: &[&str], payload: &str) -> String { - let canonical_headers = canonical_headers.join("\n"); - let signed_headers = signed_headers.join(";"); - let mut hasher = Sha256::new(); - hasher.update(payload.as_bytes()); - let hashed_payload = hasher.finalize(); - format!("{http_verb}\n{canonical_uri}\n{canonical_query_string}\n{canonical_headers}\n\n{signed_headers}\n{hashed_payload:x}") - } - - pub fn get_string_to_sign(&self, algorithm: &str, timestamp: &str, credential_scope: &str, canonical_request: &str) -> String { - let mut hasher = Sha256::new(); - hasher.update(canonical_request.as_bytes()); - let canonical_request = hasher.finalize(); - format!("{algorithm}\n{timestamp}\n{credential_scope}\n{canonical_request:x}") - } - - pub fn get_authorization_header(&self, algorithm: &str, credential: &str, scope: &str, signed_headers: &[&str], signature: &str) -> String { - let signed_headers = signed_headers.join(";"); - format!("{algorithm} Credential={credential}/{scope}, SignedHeaders={signed_headers}, Signature={signature}") - } + pub fn sign(&self, key: &[u8], msg: &[u8]) -> Vec { + let mut mac = HmacSha256::new_from_slice(key).unwrap(); + mac.update(msg); + let result = mac.finalize(); + result.into_bytes().to_vec() + } + + pub fn get_signing_key(&self, key: &str, date: &str, region: &str, service: &str) -> Vec { + let k_date = self.sign(format!("AWS4{key}").as_bytes(), date.as_bytes()); + let k_region = self.sign(&k_date, region.as_bytes()); + let k_service = self.sign(&k_region, service.as_bytes()); + self.sign(&k_service, "aws4_request".as_bytes()) + } + + pub fn get_signature(&self, signing_key: &[u8], string_to_sign: &str) -> String { + let signature = self.sign(signing_key, string_to_sign.as_bytes()); + hex::encode(signature) + } + + pub fn get_canonical_request(&self, http_verb: &str, canonical_uri: &str, canonical_query_string: &str, canonical_headers: &[String], signed_headers: &[&str], payload: &str) -> String { + let canonical_headers = canonical_headers.join("\n"); + let signed_headers = signed_headers.join(";"); + let mut hasher = Sha256::new(); + hasher.update(payload.as_bytes()); + let hashed_payload = hasher.finalize(); + format!("{http_verb}\n{canonical_uri}\n{canonical_query_string}\n{canonical_headers}\n\n{signed_headers}\n{hashed_payload:x}") + } + + pub fn get_string_to_sign(&self, algorithm: &str, timestamp: &str, credential_scope: &str, canonical_request: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(canonical_request.as_bytes()); + let canonical_request = hasher.finalize(); + format!("{algorithm}\n{timestamp}\n{credential_scope}\n{canonical_request:x}") + } + + pub fn get_authorization_header(&self, algorithm: &str, credential: &str, scope: &str, signed_headers: &[&str], signature: &str) -> String { + let signed_headers = signed_headers.join(";"); + format!("{algorithm} Credential={credential}/{scope}, SignedHeaders={signed_headers}, Signature={signature}") + } pub fn infer_single(&self, input: &str, input_type: Option<&str>, truncate: Option<&str>) -> Result> { @@ -588,129 +588,129 @@ impl AmazonBedrockClient { // Step 0b. create payload let body = match model_provider { - "amazon" => ureq::json!({ - "inputText": input.to_owned(), - }), - "cohere" => ureq::json!({ - "texts": [input.to_owned()], - "input_type": input_type.unwrap_or("search_document"), - "truncate": truncate.unwrap_or("NONE") - }), - _ => ureq::json!({}) - }; - - // Step 0c. get date and time - - let current_time = chrono::Utc::now(); - let amazon_time = current_time.format("%Y%m%dT%H%M%SZ").to_string(); - let amazon_date = current_time.format("%Y%m%d").to_string(); - - // Step 1: create a canonical request - - let canonical_uri = format!("/model/{}/invoke", self.model_id); - let canonical_query_string = ""; - let service_endpoint: String = format!("bedrock-runtime.{}.amazonaws.com", self.region); - let endpoint: String = format!("https://{service_endpoint}/model/{}/invoke", self.model_id); - - let mut signed_headers = vec![ - "host", - "x-amz-date" - ]; - - if !self.aws_session_token.is_empty() { - signed_headers.push("x-amz-security-token"); - } - - let mut canonical_headers = vec![ - format!("host:{service_endpoint}"), - format!("x-amz-date:{amazon_time}") - ]; - - if !self.aws_session_token.is_empty() { - canonical_headers.push( - format!("x-amz-security-token:{}", self.aws_session_token) - ); - } - - let canonical_request = self.get_canonical_request( - "POST", - &canonical_uri, - canonical_query_string, - &canonical_headers, - &signed_headers, - &body.to_string() - ); - - // Step 2: create string to sign - - let algorithm = "AWS4-HMAC-SHA256"; - let service = "bedrock"; - let credential_scope = format!("{amazon_date}/{}/{service}/aws4_request", self.region); - - let string_to_sign = self.get_string_to_sign( - algorithm, - &amazon_time, - &credential_scope, - &canonical_request - ); - - // Step 3: calculate signature - - let signing_key = self.get_signing_key( - &self.aws_secret_access_key, - &amazon_date, - &self.region, - service - ); - - let signature = self.get_signature( - &signing_key, - &string_to_sign - ); - - // Step 4: add the signature to the request - - let authorization = self.get_authorization_header( - "AWS4-HMAC-SHA256", - &self.aws_access_key_id, - &credential_scope, - &signed_headers, - &signature - ); + "amazon" => ureq::json!({ + "inputText": input.to_owned(), + }), + "cohere" => ureq::json!({ + "texts": [input.to_owned()], + "input_type": input_type.unwrap_or("search_document"), + "truncate": truncate.unwrap_or("NONE") + }), + _ => ureq::json!({}) + }; + + // Step 0c. get date and time + + let current_time = chrono::Utc::now(); + let amazon_time = current_time.format("%Y%m%dT%H%M%SZ").to_string(); + let amazon_date = current_time.format("%Y%m%d").to_string(); + + // Step 1: create a canonical request + + let canonical_uri = format!("/model/{}/invoke", self.model_id); + let canonical_query_string = ""; + let service_endpoint: String = format!("bedrock-runtime.{}.amazonaws.com", self.region); + let endpoint: String = format!("https://{service_endpoint}/model/{}/invoke", self.model_id); + + let mut signed_headers = vec![ + "host", + "x-amz-date" + ]; + + if !self.aws_session_token.is_empty() { + signed_headers.push("x-amz-security-token"); + } + + let mut canonical_headers = vec![ + format!("host:{service_endpoint}"), + format!("x-amz-date:{amazon_time}") + ]; + + if !self.aws_session_token.is_empty() { + canonical_headers.push( + format!("x-amz-security-token:{}", self.aws_session_token) + ); + } + + let canonical_request = self.get_canonical_request( + "POST", + &canonical_uri, + canonical_query_string, + &canonical_headers, + &signed_headers, + &body.to_string() + ); + + // Step 2: create string to sign + + let algorithm = "AWS4-HMAC-SHA256"; + let service = "bedrock"; + let credential_scope = format!("{amazon_date}/{}/{service}/aws4_request", self.region); + + let string_to_sign = self.get_string_to_sign( + algorithm, + &amazon_time, + &credential_scope, + &canonical_request + ); + + // Step 3: calculate signature + + let signing_key = self.get_signing_key( + &self.aws_secret_access_key, + &amazon_date, + &self.region, + service + ); + + let signature = self.get_signature( + &signing_key, + &string_to_sign + ); + + // Step 4: add the signature to the request + + let authorization = self.get_authorization_header( + "AWS4-HMAC-SHA256", + &self.aws_access_key_id, + &credential_scope, + &signed_headers, + &signature + ); // Step 5: send the request - let request = ureq::post(&endpoint) - .set("Accept", "application/json") - .set("X-Amz-Date", &amazon_time) - .set("Authorization", &authorization); + let request = ureq::post(&endpoint) + .set("Accept", "application/json") + .set("X-Amz-Date", &amazon_time) + .set("Authorization", &authorization); - let request = if self.aws_session_token.is_empty() { - request.clone() - } else { - request.clone().set("X-Amz-Security-Token", &self.aws_session_token) - }; + let request = if self.aws_session_token.is_empty() { + request.clone() + } else { + request.clone().set("X-Amz-Security-Token", &self.aws_session_token) + }; - let response = request.clone() - .send_bytes( - body.to_string().as_bytes() - ) - .unwrap() - .into_string() - .unwrap(); + let response = request.clone() + .send_bytes( + body.to_string().as_bytes() + ) + .unwrap() + .into_string() + .unwrap(); - let data: serde_json::Value = serde_json::from_str(&response).unwrap(); + let data: serde_json::Value = serde_json::from_str(&response).unwrap(); AmazonBedrockClient::parse_single_response(self, &data) } pub fn parse_single_response(&self, value: &serde_json::Value) -> Result> { - let model_provider = self.model_id.split('.').next().unwrap().to_string(); + let model_provider = self.model_id.split('.').next().unwrap().to_string(); - let output: Result>; + let output: Result>; if model_provider == "amazon" { - output = value + output = value .get("embedding") .ok_or_else(|| Error::new_message("expected 'embedding' key in response body")) .and_then(|v| { @@ -765,7 +765,7 @@ impl AmazonBedrockClient { } else { todo!(); } - output + output } } diff --git a/src/clients_vtab.rs b/src/clients_vtab.rs index c9644a6..d38ea00 100644 --- a/src/clients_vtab.rs +++ b/src/clients_vtab.rs @@ -99,7 +99,7 @@ impl<'vtab> VTabWriteable<'vtab> for ClientsTable { "cohere" => Client::Cohere(CohereClient::new(name, None, None)?), "ollama" => Client::Ollama(OllamaClient::new(name, None)), "llamafile" => Client::Llamafile(LlamafileClient::new(None)), - "bedrock" => Client::AmazonBedrock(AmazonBedrockClient::new(name, None, None, None, None)?), + "bedrock" => Client::AmazonBedrock(AmazonBedrockClient::new(name, None, None, None, None)?), text => { return Err(Error::new_message(format!( "'{text}' is not a valid rembed client." diff --git a/src/lib.rs b/src/lib.rs index b688222..fc984dd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -130,10 +130,10 @@ pub fn rembed( Client::Ollama(client) => client.infer_single(input)?, Client::Llamafile(client) => client.infer_single(input)?, Client::AmazonBedrock(client) => { - let input_type = values.get(2).and_then(|v| api::value_text(v).ok()); - let truncate = values.get(3).and_then(|v| api::value_text(v).ok()); - client.infer_single(input, input_type, truncate)? - } + let input_type = values.get(2).and_then(|v| api::value_text(v).ok()); + let truncate = values.get(3).and_then(|v| api::value_text(v).ok()); + client.infer_single(input, input_type, truncate)? + } Client::Nomic(client) => { let input_type = values.get(2).and_then(|v| api::value_text(v).ok()); client.infer_single(input, input_type)? @@ -173,7 +173,7 @@ pub fn sqlite3_rembed_init(db: *mut sqlite3) -> Result<()> { )?; define_scalar_function_with_aux(db, "rembed", 2, rembed, flags, Rc::clone(&c))?; define_scalar_function_with_aux(db, "rembed", 3, rembed, flags, Rc::clone(&c))?; - define_scalar_function_with_aux(db, "rembed", 4, rembed, flags, Rc::clone(&c))?; + define_scalar_function_with_aux(db, "rembed", 4, rembed, flags, Rc::clone(&c))?; define_scalar_function( db, "rembed_client_options", From 5e068682ac75a89bc1d19bbc9765c5433855cfab Mon Sep 17 00:00:00 2001 From: JGalego Date: Tue, 6 Aug 2024 12:35:05 +0100 Subject: [PATCH 05/11] Added error handling when calling model --- src/clients.rs | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/clients.rs b/src/clients.rs index 8a56d5f..d34d798 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -695,9 +695,19 @@ impl AmazonBedrockClient { .send_bytes( body.to_string().as_bytes() ) - .unwrap() + .map_err( + |error| + Error::new_message( + format!("Error sending HTTP request: {error}") + ) + )? .into_string() - .unwrap(); + .map_err( + |error| + Error::new_message( + format!("Error parsing HTTP response: {error}") + ) + )?; let data: serde_json::Value = serde_json::from_str(&response).unwrap(); From d22ad304b829be1226795d1cc46e73fbdc71e495 Mon Sep 17 00:00:00 2001 From: JGalego Date: Tue, 6 Aug 2024 16:23:54 +0100 Subject: [PATCH 06/11] Refactored to move SigV4 fns outside Bedrock client --- src/clients.rs | 128 ++++++++++++++++++++++++++----------------------- 1 file changed, 69 insertions(+), 59 deletions(-) diff --git a/src/clients.rs b/src/clients.rs index d34d798..972e5f9 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -509,6 +509,51 @@ impl LlamafileClient { } } +/* AWS SigV4 helpe functions */ + +pub(crate) fn sign(key: &[u8], msg: &[u8]) -> Vec { + let mut mac = HmacSha256::new_from_slice(key).expect("Error when signing message"); + mac.update(msg); + let result = mac.finalize(); + result.into_bytes().to_vec() +} + +pub(crate) fn get_signing_key(key: &str, date: &str, region: &str, service: &str) -> Vec { + let k_date = sign(format!("AWS4{key}").as_bytes(), date.as_bytes()); + let k_region = sign(&k_date, region.as_bytes()); + let k_service = sign(&k_region, service.as_bytes()); + sign(&k_service, "aws4_request".as_bytes()) +} + +pub(crate) fn get_signature(signing_key: &[u8], string_to_sign: &str) -> String { + let signature = sign(signing_key, string_to_sign.as_bytes()); + hex::encode(signature) +} + + +pub(crate) fn get_canonical_request(http_verb: &str, canonical_uri: &str, canonical_query_string: &str, canonical_headers: &[String], signed_headers: &[&str], payload: &str) -> String { + let canonical_headers = canonical_headers.join("\n"); + let signed_headers = signed_headers.join(";"); + let mut hasher = Sha256::new(); + hasher.update(payload.as_bytes()); + let hashed_payload = hasher.finalize(); + format!("{http_verb}\n{canonical_uri}\n{canonical_query_string}\n{canonical_headers}\n\n{signed_headers}\n{hashed_payload:x}") +} + +pub(crate) fn get_string_to_sign(algorithm: &str, timestamp: &str, credential_scope: &str, canonical_request: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(canonical_request.as_bytes()); + let canonical_request = hasher.finalize(); + format!("{algorithm}\n{timestamp}\n{credential_scope}\n{canonical_request:x}") +} + +pub(crate) fn get_authorization_header(algorithm: &str, credential: &str, scope: &str, signed_headers: &[&str], signature: &str) -> String { + let signed_headers = signed_headers.join(";"); + format!("{algorithm} Credential={credential}/{scope}, SignedHeaders={signed_headers}, Signature={signature}") +} + +/* Amazon Bedrock Client */ + #[derive(Clone)] pub struct AmazonBedrockClient { model_id: String, @@ -525,10 +570,10 @@ impl AmazonBedrockClient { model_id: model_id.into(), region: region.unwrap_or(DEFAULT_AWS_REGION.to_owned()), aws_access_key_id: aws_access_key_id.unwrap_or( - std::env::var("AWS_ACCESS_KEY_ID").unwrap() + std::env::var("AWS_ACCESS_KEY_ID").unwrap_or_default() ), aws_secret_access_key: aws_secret_access_key.unwrap_or( - std::env::var("AWS_SECRET_ACCESS_KEY").unwrap() + std::env::var("AWS_SECRET_ACCESS_KEY").unwrap_or_default() ), aws_session_token: aws_session_token.unwrap_or( std::env::var("AWS_SESSION_TOKEN").unwrap_or_default() @@ -536,46 +581,6 @@ impl AmazonBedrockClient { }) } - pub fn sign(&self, key: &[u8], msg: &[u8]) -> Vec { - let mut mac = HmacSha256::new_from_slice(key).unwrap(); - mac.update(msg); - let result = mac.finalize(); - result.into_bytes().to_vec() - } - - pub fn get_signing_key(&self, key: &str, date: &str, region: &str, service: &str) -> Vec { - let k_date = self.sign(format!("AWS4{key}").as_bytes(), date.as_bytes()); - let k_region = self.sign(&k_date, region.as_bytes()); - let k_service = self.sign(&k_region, service.as_bytes()); - self.sign(&k_service, "aws4_request".as_bytes()) - } - - pub fn get_signature(&self, signing_key: &[u8], string_to_sign: &str) -> String { - let signature = self.sign(signing_key, string_to_sign.as_bytes()); - hex::encode(signature) - } - - pub fn get_canonical_request(&self, http_verb: &str, canonical_uri: &str, canonical_query_string: &str, canonical_headers: &[String], signed_headers: &[&str], payload: &str) -> String { - let canonical_headers = canonical_headers.join("\n"); - let signed_headers = signed_headers.join(";"); - let mut hasher = Sha256::new(); - hasher.update(payload.as_bytes()); - let hashed_payload = hasher.finalize(); - format!("{http_verb}\n{canonical_uri}\n{canonical_query_string}\n{canonical_headers}\n\n{signed_headers}\n{hashed_payload:x}") - } - - pub fn get_string_to_sign(&self, algorithm: &str, timestamp: &str, credential_scope: &str, canonical_request: &str) -> String { - let mut hasher = Sha256::new(); - hasher.update(canonical_request.as_bytes()); - let canonical_request = hasher.finalize(); - format!("{algorithm}\n{timestamp}\n{credential_scope}\n{canonical_request:x}") - } - - pub fn get_authorization_header(&self, algorithm: &str, credential: &str, scope: &str, signed_headers: &[&str], signature: &str) -> String { - let signed_headers = signed_headers.join(";"); - format!("{algorithm} Credential={credential}/{scope}, SignedHeaders={signed_headers}, Signature={signature}") - } - pub fn infer_single(&self, input: &str, input_type: Option<&str>, truncate: Option<&str>) -> Result> { // Step 0a. extract model provider @@ -583,7 +588,7 @@ impl AmazonBedrockClient { let model_provider = self.model_id .split('.') .next() - .unwrap(); + .expect("Error getting model provider"); // Step 0b. create payload @@ -632,7 +637,7 @@ impl AmazonBedrockClient { ); } - let canonical_request = self.get_canonical_request( + let canonical_request = get_canonical_request( "POST", &canonical_uri, canonical_query_string, @@ -647,7 +652,7 @@ impl AmazonBedrockClient { let service = "bedrock"; let credential_scope = format!("{amazon_date}/{}/{service}/aws4_request", self.region); - let string_to_sign = self.get_string_to_sign( + let string_to_sign = get_string_to_sign( algorithm, &amazon_time, &credential_scope, @@ -656,21 +661,21 @@ impl AmazonBedrockClient { // Step 3: calculate signature - let signing_key = self.get_signing_key( + let signing_key = get_signing_key( &self.aws_secret_access_key, &amazon_date, &self.region, service ); - let signature = self.get_signature( + let signature = get_signature( &signing_key, &string_to_sign ); // Step 4: add the signature to the request - let authorization = self.get_authorization_header( + let authorization = get_authorization_header( "AWS4-HMAC-SHA256", &self.aws_access_key_id, &credential_scope, @@ -686,12 +691,12 @@ impl AmazonBedrockClient { .set("Authorization", &authorization); let request = if self.aws_session_token.is_empty() { - request.clone() + request } else { - request.clone().set("X-Amz-Security-Token", &self.aws_session_token) + request.set("X-Amz-Security-Token", &self.aws_session_token) }; - let response = request.clone() + let response = request .send_bytes( body.to_string().as_bytes() ) @@ -716,11 +721,13 @@ impl AmazonBedrockClient { pub fn parse_single_response(&self, value: &serde_json::Value) -> Result> { - let model_provider = self.model_id.split('.').next().unwrap().to_string(); + let model_provider = self.model_id + .split('.') + .next() + .expect("Error getting model provider"); - let output: Result>; if model_provider == "amazon" { - output = value + value .get("embedding") .ok_or_else(|| Error::new_message("expected 'embedding' key in response body")) .and_then(|v| { @@ -740,9 +747,9 @@ impl AmazonBedrockClient { .map(|f| f as f32) }) .collect() - }); + }) } else if model_provider == "cohere" { - output = value + value .get("embeddings") .ok_or_else(|| Error::new_message("expected 'embeddings' key in response body")) .and_then(|v: &serde_json::Value| { @@ -771,11 +778,14 @@ impl AmazonBedrockClient { .map(|f| f as f32) }) .collect() - }); + }) } else { - todo!(); + Err( + Error::new_message( + format!("Model provider '{model_provider}' is not supported!") + ) + ) } - output } } From bb740dd5bca53695370cb8bafe5b6d44c87c4f7d Mon Sep 17 00:00:00 2001 From: JGalego Date: Tue, 6 Aug 2024 18:01:37 +0100 Subject: [PATCH 07/11] Refactored infer_single to pass inference params as JSON --- src/clients.rs | 37 ++++++++++++++++++++++++++++--------- src/lib.rs | 6 ++---- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/src/clients.rs b/src/clients.rs index 972e5f9..8b988d4 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -564,6 +564,18 @@ pub struct AmazonBedrockClient { } const DEFAULT_AWS_REGION: &str = "us-east-1"; +fn merge(a: &mut serde_json::Value, b: serde_json::Value) { + match (a, b) { + (a @ &mut serde_json::Value::Object(_), serde_json::Value::Object(b)) => { + let a = a.as_object_mut().unwrap(); + for (k, v) in b { + merge(a.entry(k).or_insert(serde_json::Value::Null), v); + } + } + (a, b) => *a = b, + } +} + impl AmazonBedrockClient { pub fn new>(model_id: S, region: Option, aws_access_key_id: Option, aws_secret_access_key: Option, aws_session_token: Option) -> Result { Ok(Self { @@ -581,30 +593,37 @@ impl AmazonBedrockClient { }) } - pub fn infer_single(&self, input: &str, input_type: Option<&str>, truncate: Option<&str>) -> Result> { - - // Step 0a. extract model provider + pub fn infer_single(&self, input: &str, inference_options: Option) -> Result> { + + // Step 0a: extract model provider let model_provider = self.model_id .split('.') .next() .expect("Error getting model provider"); - // Step 0b. create payload + // Step 0b: create payload - let body = match model_provider { + let mut body = match model_provider { "amazon" => ureq::json!({ "inputText": input.to_owned(), }), "cohere" => ureq::json!({ - "texts": [input.to_owned()], - "input_type": input_type.unwrap_or("search_document"), - "truncate": truncate.unwrap_or("NONE") + "texts": [ + input.to_owned() + ], }), _ => ureq::json!({}) }; - // Step 0c. get date and time + let inference_options = match inference_options { + Some(v) => v, + None => ureq::json!({}) + }; + + merge(&mut body, inference_options); + + // Step 0c: get date and time let current_time = chrono::Utc::now(); let amazon_time = current_time.format("%Y%m%dT%H%M%SZ").to_string(); diff --git a/src/lib.rs b/src/lib.rs index fc984dd..7d2265b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -130,9 +130,8 @@ pub fn rembed( Client::Ollama(client) => client.infer_single(input)?, Client::Llamafile(client) => client.infer_single(input)?, Client::AmazonBedrock(client) => { - let input_type = values.get(2).and_then(|v| api::value_text(v).ok()); - let truncate = values.get(3).and_then(|v| api::value_text(v).ok()); - client.infer_single(input, input_type, truncate)? + let inference_options = values.get(2).and_then(|v| api::value_json(v).ok()); + client.infer_single(input, inference_options)? } Client::Nomic(client) => { let input_type = values.get(2).and_then(|v| api::value_text(v).ok()); @@ -173,7 +172,6 @@ pub fn sqlite3_rembed_init(db: *mut sqlite3) -> Result<()> { )?; define_scalar_function_with_aux(db, "rembed", 2, rembed, flags, Rc::clone(&c))?; define_scalar_function_with_aux(db, "rembed", 3, rembed, flags, Rc::clone(&c))?; - define_scalar_function_with_aux(db, "rembed", 4, rembed, flags, Rc::clone(&c))?; define_scalar_function( db, "rembed_client_options", From 8a0f14053b051bdcd01625bf6436810d322e63b3 Mon Sep 17 00:00:00 2001 From: JGalego Date: Tue, 6 Aug 2024 18:08:43 +0100 Subject: [PATCH 08/11] Minor bugfix (set default input_type for Cohere Embed) --- src/clients.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/clients.rs b/src/clients.rs index 8b988d4..b5e1887 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -612,6 +612,7 @@ impl AmazonBedrockClient { "texts": [ input.to_owned() ], + "input_type": "search_document" }), _ => ureq::json!({}) }; From eff315e76ca8f3a10eb102b33e266b7ef20d17bc Mon Sep 17 00:00:00 2001 From: JGalego Date: Wed, 7 Aug 2024 21:36:15 +0100 Subject: [PATCH 09/11] Added fix for escaping 'bad' chars in model_id --- Cargo.lock | 5 +++-- Cargo.toml | 1 + src/clients.rs | 9 +++++++-- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ad6dba7..a405d35 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -715,6 +715,7 @@ dependencies = [ "sha2", "sqlite-loadable", "ureq", + "url", "zerocopy", ] @@ -846,9 +847,9 @@ dependencies = [ [[package]] name = "url" -version = "2.5.0" +version = "2.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" +checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" dependencies = [ "form_urlencoded", "idna", diff --git a/Cargo.toml b/Cargo.toml index 63efe57..f3a94b4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ chrono = "0.4.38" hex = "0.4.3" hmac = "0.12.1" sha2 = "0.10.8" +url = "2.5.2" [lib] crate-type=["cdylib", "staticlib", "lib"] diff --git a/src/clients.rs b/src/clients.rs index b5e1887..21ecfec 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -632,10 +632,15 @@ impl AmazonBedrockClient { // Step 1: create a canonical request - let canonical_uri = format!("/model/{}/invoke", self.model_id); + // Fix: escape 'bad' characters (*, /, :) in model_id + // https://github.com/curl/curl/issues/11794 + let model_id: String = url::form_urlencoded::byte_serialize(self.model_id.as_bytes()).collect(); + let escaped_model_id: String = url::form_urlencoded::byte_serialize(model_id.as_bytes()).collect(); + + let canonical_uri = format!("/model/{}/invoke", escaped_model_id); let canonical_query_string = ""; let service_endpoint: String = format!("bedrock-runtime.{}.amazonaws.com", self.region); - let endpoint: String = format!("https://{service_endpoint}/model/{}/invoke", self.model_id); + let endpoint: String = format!("https://{service_endpoint}/model/{}/invoke", model_id); let mut signed_headers = vec![ "host", From c97b94fedd9ac83ea99c56c8ce24952f12576c8a Mon Sep 17 00:00:00 2001 From: JGalego Date: Thu, 8 Aug 2024 15:40:42 +0100 Subject: [PATCH 10/11] Added missing fn docs --- src/clients.rs | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/src/clients.rs b/src/clients.rs index 21ecfec..344ec14 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -511,27 +511,31 @@ impl LlamafileClient { /* AWS SigV4 helpe functions */ +/// Computes HMAC on a message by using the SHA256 algorithm with the provided signing key. pub(crate) fn sign(key: &[u8], msg: &[u8]) -> Vec { let mut mac = HmacSha256::new_from_slice(key).expect("Error when signing message"); mac.update(msg); let result = mac.finalize(); result.into_bytes().to_vec() } - -pub(crate) fn get_signing_key(key: &str, date: &str, region: &str, service: &str) -> Vec { +/// Derives a signing key by performing a succession of keyed hash operations (HMAC operations) +/// on the request date, Region, and service, with the AWS secret access key as the key for the +/// initial hashing operation. +pub(crate) fn derive_signing_key(key: &str, date: &str, region: &str, service: &str) -> Vec { let k_date = sign(format!("AWS4{key}").as_bytes(), date.as_bytes()); let k_region = sign(&k_date, region.as_bytes()); let k_service = sign(&k_region, service.as_bytes()); sign(&k_service, "aws4_request".as_bytes()) } -pub(crate) fn get_signature(signing_key: &[u8], string_to_sign: &str) -> String { +/// Calculates the signature by performing a keyed hash operation on the string to sign. +pub(crate) fn calculate_signature(signing_key: &[u8], string_to_sign: &str) -> String { let signature = sign(signing_key, string_to_sign.as_bytes()); hex::encode(signature) } - -pub(crate) fn get_canonical_request(http_verb: &str, canonical_uri: &str, canonical_query_string: &str, canonical_headers: &[String], signed_headers: &[&str], payload: &str) -> String { +/// Arranges the contents of the request (host, action, headers, &c.) into a standard canonical format. +pub(crate) fn create_canonical_request(http_verb: &str, canonical_uri: &str, canonical_query_string: &str, canonical_headers: &[String], signed_headers: &[&str], payload: &str) -> String { let canonical_headers = canonical_headers.join("\n"); let signed_headers = signed_headers.join(";"); let mut hasher = Sha256::new(); @@ -540,14 +544,17 @@ pub(crate) fn get_canonical_request(http_verb: &str, canonical_uri: &str, canoni format!("{http_verb}\n{canonical_uri}\n{canonical_query_string}\n{canonical_headers}\n\n{signed_headers}\n{hashed_payload:x}") } -pub(crate) fn get_string_to_sign(algorithm: &str, timestamp: &str, credential_scope: &str, canonical_request: &str) -> String { +/// Creates a string to sign with the canonical request and extra information such as the algorithm, +/// request date, credential scope, and the hash of the canonical request. +pub(crate) fn create_string_to_sign(algorithm: &str, timestamp: &str, credential_scope: &str, canonical_request: &str) -> String { let mut hasher = Sha256::new(); hasher.update(canonical_request.as_bytes()); let canonical_request = hasher.finalize(); format!("{algorithm}\n{timestamp}\n{credential_scope}\n{canonical_request:x}") } -pub(crate) fn get_authorization_header(algorithm: &str, credential: &str, scope: &str, signed_headers: &[&str], signature: &str) -> String { +/// Creates the Authorization header for the request. +pub(crate) fn create_authorization_header(algorithm: &str, credential: &str, scope: &str, signed_headers: &[&str], signature: &str) -> String { let signed_headers = signed_headers.join(";"); format!("{algorithm} Credential={credential}/{scope}, SignedHeaders={signed_headers}, Signature={signature}") } @@ -562,6 +569,7 @@ pub struct AmazonBedrockClient { aws_secret_access_key: String, aws_session_token: String } +const HASH_ALGORITHM: &str = "AWS4-HMAC-SHA256"; const DEFAULT_AWS_REGION: &str = "us-east-1"; fn merge(a: &mut serde_json::Value, b: serde_json::Value) { @@ -662,7 +670,7 @@ impl AmazonBedrockClient { ); } - let canonical_request = get_canonical_request( + let canonical_request = create_canonical_request( "POST", &canonical_uri, canonical_query_string, @@ -673,11 +681,11 @@ impl AmazonBedrockClient { // Step 2: create string to sign - let algorithm = "AWS4-HMAC-SHA256"; + let algorithm = HASH_ALGORITHM; let service = "bedrock"; let credential_scope = format!("{amazon_date}/{}/{service}/aws4_request", self.region); - let string_to_sign = get_string_to_sign( + let string_to_sign = create_string_to_sign( algorithm, &amazon_time, &credential_scope, @@ -686,22 +694,22 @@ impl AmazonBedrockClient { // Step 3: calculate signature - let signing_key = get_signing_key( + let signing_key = derive_signing_key( &self.aws_secret_access_key, &amazon_date, &self.region, service ); - let signature = get_signature( + let signature = calculate_signature( &signing_key, &string_to_sign ); // Step 4: add the signature to the request - let authorization = get_authorization_header( - "AWS4-HMAC-SHA256", + let authorization = create_authorization_header( + HASH_ALGORITHM, &self.aws_access_key_id, &credential_scope, &signed_headers, From 3e18391acc35ea7eef01f98afad6ef1b3f0371ce Mon Sep 17 00:00:00 2001 From: JGalego Date: Thu, 8 Aug 2024 15:42:40 +0100 Subject: [PATCH 11/11] Minor changes --- src/clients.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/clients.rs b/src/clients.rs index 344ec14..ab30aa1 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -509,7 +509,7 @@ impl LlamafileClient { } } -/* AWS SigV4 helpe functions */ +/* AWS SigV4 */ /// Computes HMAC on a message by using the SHA256 algorithm with the provided signing key. pub(crate) fn sign(key: &[u8], msg: &[u8]) -> Vec { @@ -559,7 +559,7 @@ pub(crate) fn create_authorization_header(algorithm: &str, credential: &str, sco format!("{algorithm} Credential={credential}/{scope}, SignedHeaders={signed_headers}, Signature={signature}") } -/* Amazon Bedrock Client */ +/* Amazon Bedrock */ #[derive(Clone)] pub struct AmazonBedrockClient {