From bc8d4c4523cd85bd19ecb0b923d101eafda1788e Mon Sep 17 00:00:00 2001 From: Rody Davis Date: Thu, 29 Aug 2024 14:53:22 -0700 Subject: [PATCH] Adding Google AI support --- README.md | 1 + src/clients.rs | 86 ++++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 10 +++++- 3 files changed, 96 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index d59a4fc..bed524f 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ Other pre-defined clients include: | `mixedbread` | [MixedBread](https://www.mixedbread.ai/api-reference#quick-start-guide) | `https://api.mixedbread.ai/v1/embeddings/` | `MIXEDBREAD_API_KEY` | | `llamafile` | [llamafile](https://github.com/Mozilla-Ocho/llamafile) | `http://localhost:8080/embedding` | None | | `ollama` | [Ollama](https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings) | `http://localhost:11434/api/embeddings` | None | +| `googleai` | [GoogleAI](https://ai.google.dev/gemini-api) | `https://generativelanguage.googleapis.com/v1beta/models` | `GOOGLE_AI_API_KEY` | Different client options can be specified with `remebed_client_options()`. 
For example, if you have a different OpenAI-compatible service you want to use, then you can use: diff --git a/src/clients.rs b/src/clients.rs index 5f83b9a..6b2fe07 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -504,6 +504,91 @@ impl LlamafileClient { } } +#[derive(Clone)] +pub struct GoogleAiClient { + model: String, + url: String, + key: String, +} +const DEFAULT_GOOGLE_AI_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models"; +const DEFAULT_GOOGLE_AI_API_KEY_ENV: &str = "GOOGLE_AI_API_KEY"; + +impl GoogleAiClient { + pub fn new<S: Into<String>>( + model: S, + url: Option<String>, + key: Option<String>, + ) -> Result<Self> { + Ok(Self { + model: model.into(), + url: url.unwrap_or(DEFAULT_GOOGLE_AI_URL.to_owned()), + key: match key { + Some(key) => key, + None => try_env_var(DEFAULT_GOOGLE_AI_API_KEY_ENV)?, + }, + }) + } + pub fn infer_single(&self, input: &str) -> Result<Vec<f32>> { + let body = serde_json::json!({ + "model": format!("models/{}", self.model).clone(), + "content": { + "parts": [ + { + "text": input, + } + ] + } + }); + + let target_url = format!("{}/{}:embedContent?key={}", self.url, self.model, self.key); + + let data: serde_json::Value = ureq::post(&target_url) + .set("Content-Type", "application/json") + .send_bytes( + serde_json::to_vec(&body) + .map_err(|error| { + Error::new_message(format!("Error serializing body to JSON: {error}")) + })? + .as_ref(), + ) + .map_err(|error| Error::new_message(format!("Error sending HTTP request: {error}")))?
+ .into_json() + .map_err(|error| { + Error::new_message(format!("Error parsing HTTP response as JSON: {error}")) + })?; + GoogleAiClient::parse_single_response(data) + } + + pub fn parse_single_response(value: serde_json::Value) -> Result<Vec<f32>> { + value + .get("embedding") + .ok_or_else(|| Error::new_message("expected 'embedding' key in response body")) + .and_then(|v| { + v.get("values").ok_or_else(|| { + Error::new_message("expected 'embedding.values' path in response body") + }) + }) + .and_then(|v| { + v.as_array().ok_or_else(|| { + Error::new_message("expected 'embedding.values' path to be an array") + }) + }) + .and_then(|arr| { + arr.iter() + .map(|v| { + v.as_f64() + .ok_or_else(|| { + Error::new_message( + "expected 'embedding.values' array to contain floats", + ) + }) + .map(|f| f as f32) + }) + .collect() + }) + } +} + #[derive(Clone)] pub enum Client { OpenAI(OpenAiClient), @@ -513,4 +598,5 @@ pub enum Client { Llamafile(LlamafileClient), Jina(JinaClient), Mixedbread(MixedbreadClient), + GoogleAI(GoogleAiClient), } diff --git a/src/lib.rs b/src/lib.rs index 1924525..3bf9575 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,7 +5,7 @@ use std::cell::RefCell; use std::collections::HashMap; use std::rc::Rc; -use clients::{Client, CohereClient, LlamafileClient, NomicClient, OllamaClient, OpenAiClient}; +use clients::{Client, CohereClient, LlamafileClient, NomicClient, OllamaClient, OpenAiClient, GoogleAiClient}; use clients_vtab::ClientsTable; use sqlite_loadable::{ api, define_scalar_function, define_scalar_function_with_aux, define_virtual_table_writeablex, @@ -91,6 +91,13 @@ pub fn rembed_client_options( options.get("url").cloned(), )), "llamafile" => Client::Llamafile(LlamafileClient::new(options.get("url").cloned())), + "googleai" => Client::GoogleAI(GoogleAiClient::new( + options + .get("model") + .ok_or_else(|| Error::new_message("'model' option is required"))?, + options.get("url").cloned(), + options.get("key").cloned(), + )?), format => return
Err(Error::new_message(format!("Unknown format '{format}'"))), }; @@ -126,6 +133,7 @@ pub fn rembed( let input_type = values.get(2).and_then(|v| api::value_text(v).ok()); client.infer_single(input, input_type)? } + Client::GoogleAI(client) => client.infer_single(input)?, }; api::result_blob(context, embedding.as_bytes());