Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Other pre-defined clients include:
| `mixedbread` | [MixedBread](https://www.mixedbread.ai/api-reference#quick-start-guide) | `https://api.mixedbread.ai/v1/embeddings/` | `MIXEDBREAD_API_KEY` |
| `llamafile` | [llamafile](https://github.com/Mozilla-Ocho/llamafile) | `http://localhost:8080/embedding` | None |
| `ollama` | [Ollama](https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings) | `http://localhost:11434/api/embeddings` | None |
| `googleai` | [GoogleAI](https://ai.google.dev/gemini-api) | `https://generativelanguage.googleapis.com/v1beta/models` | `GOOGLE_AI_API_KEY` |

Different client options can be specified with `remebed_client_options()`. For example, if you have a different OpenAI-compatible service you want to use, then you can use:

Expand Down
86 changes: 86 additions & 0 deletions src/clients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,91 @@ impl LlamafileClient {
}
}

/// Client for the Google AI (Gemini) embeddings endpoint.
///
/// Issues `embedContent` requests against the Generative Language API
/// and parses the returned embedding vector.
#[derive(Clone)]
pub struct GoogleAiClient {
    // Model name without the "models/" prefix (it is prepended on request).
    model: String,
    // Base URL of the API; defaults to DEFAULT_GOOGLE_AI_URL.
    url: String,
    // API key, appended to the request URL as the `key` query parameter.
    key: String,
}
/// Default base URL of the Generative Language API (v1beta models collection).
const DEFAULT_GOOGLE_AI_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";
/// Environment variable consulted when no API key is passed explicitly.
const DEFAULT_GOOGLE_AI_API_KEY_ENV: &str = "GOOGLE_AI_API_KEY";

impl GoogleAiClient {
pub fn new<S: Into<String>>(
model: S,
url: Option<String>,
key: Option<String>,
) -> Result<Self> {
Ok(Self {
model: model.into(),
url: url.unwrap_or(DEFAULT_GOOGLE_AI_URL.to_owned()),
key: match key {
Some(key) => key,
None => try_env_var(DEFAULT_GOOGLE_AI_API_KEY_ENV)?,
},
})
}
pub fn infer_single(&self, input: &str) -> Result<Vec<f32>> {
let body = serde_json::json!({
"model": format!("models/{}", self.model).clone(),
"content": {
"parts": [
{
"text": input,
}
]
}
});

let target_url = format!("{}/{}:embedContent?key={}", self.url, self.model, self.key);

let data: serde_json::Value = ureq::post(&target_url)
.set("Content-Type", "application/json")
.send_bytes(
serde_json::to_vec(&body)
.map_err(|error| {
Error::new_message(format!("Error serializing body to JSON: {error}"))
})?
.as_ref(),
)
.map_err(|error| Error::new_message(format!("Error sending HTTP request: {error}")))?
.into_json()
.map_err(|error| {
Error::new_message(format!("Error parsing HTTP response as JSON: {error}"))
})?;
GoogleAiClient::parse_single_response(data)
}

pub fn parse_single_response(value: serde_json::Value) -> Result<Vec<f32>> {
value
.get("embedding")
.ok_or_else(|| Error::new_message("expected 'embedding' key in response body"))
.and_then(|v| {
v.get("values").ok_or_else(|| {
Error::new_message("expected 'embedding.values' path in response body")
})
})
.and_then(|v| {
v.as_array().ok_or_else(|| {
Error::new_message("expected 'embedding.values' path to be an array")
})
})
.and_then(|arr| {
arr.iter()
.map(|v| {
v.as_f64()
.ok_or_else(|| {
Error::new_message(
"expected 'embedding.values' array to contain floats",
)
})
.map(|f| f as f32)
})
.collect()
})
}
}

#[derive(Clone)]
pub enum Client {
OpenAI(OpenAiClient),
Expand All @@ -513,4 +598,5 @@ pub enum Client {
Llamafile(LlamafileClient),
Jina(JinaClient),
Mixedbread(MixedbreadClient),
GoogleAI(GoogleAiClient),
}
10 changes: 9 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;

use clients::{Client, CohereClient, LlamafileClient, NomicClient, OllamaClient, OpenAiClient};
use clients::{Client, CohereClient, LlamafileClient, NomicClient, OllamaClient, OpenAiClient, GoogleAiClient};
use clients_vtab::ClientsTable;
use sqlite_loadable::{
api, define_scalar_function, define_scalar_function_with_aux, define_virtual_table_writeablex,
Expand Down Expand Up @@ -91,6 +91,13 @@ pub fn rembed_client_options(
options.get("url").cloned(),
)),
"llamafile" => Client::Llamafile(LlamafileClient::new(options.get("url").cloned())),
"googleai" => Client::GoogleAI(GoogleAiClient::new(
options
.get("model")
.ok_or_else(|| Error::new_message("'model' option is required"))?,
options.get("url").cloned(),
options.get("key").cloned(),
)?),
format => return Err(Error::new_message(format!("Unknown format '{format}'"))),
};

Expand Down Expand Up @@ -126,6 +133,7 @@ pub fn rembed(
let input_type = values.get(2).and_then(|v| api::value_text(v).ok());
client.infer_single(input, input_type)?
}
Client::GoogleAI(client) => client.infer_single(input)?,
};

api::result_blob(context, embedding.as_bytes());
Expand Down