Commit 3d1410e

Merge pull request #107 from SyedZawwarAhmed/feat/google-ai-embeddings
Feat: add Google AI embedding provider support for vector db
2 parents e7d51fa + ab2bbdc commit 3d1410e

File tree: 4 files changed (+488, -0 lines)

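In short, the new GoogleEmbeds connector wraps the @google/generative-ai SDK's embedContent call. A minimal standalone sketch of that underlying call, assuming GOOGLE_AI_API_KEY is set (the embedOnce helper name is illustrative, not part of the commit):

import { GoogleGenerativeAI } from '@google/generative-ai';

// Illustrative helper, not part of the commit; mirrors what GoogleEmbeds.embed() does per text.
async function embedOnce(text: string): Promise<number[]> {
    const apiKey = process.env.GOOGLE_AI_API_KEY;
    if (!apiKey) throw new Error('GOOGLE_AI_API_KEY is not set');

    const genAI = new GoogleGenerativeAI(apiKey);
    const model = genAI.getGenerativeModel({ model: 'gemini-embedding-001' });

    // embedContent() resolves to { embedding: { values: number[] } }
    const result = await model.embedContent(text);
    return result.embedding.values;
}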

packages/core/src/index.ts

Lines changed: 1 addition & 0 deletions
@@ -166,6 +166,7 @@ export * from './subsystems/IO/VectorDB.service/connectors/MilvusVectorDB.class'
 export * from './subsystems/IO/VectorDB.service/connectors/PineconeVectorDB.class';
 export * from './subsystems/IO/VectorDB.service/connectors/RAMVecrtorDB.class';
 export * from './subsystems/IO/VectorDB.service/embed/BaseEmbedding';
+export * from './subsystems/IO/VectorDB.service/embed/GoogleEmbedding';
 export * from './subsystems/IO/VectorDB.service/embed/index';
 export * from './subsystems/IO/VectorDB.service/embed/OpenAIEmbedding';
 export * from './subsystems/LLMManager/LLM.service/connectors/Anthropic.class';
packages/core/src/subsystems/IO/VectorDB.service/embed/GoogleEmbedding.ts

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
+import { GoogleGenerativeAI } from '@google/generative-ai';
+import { BaseEmbedding, TEmbeddings } from './BaseEmbedding';
+import { AccessCandidate } from '@sre/Security/AccessControl/AccessCandidate.class';
+import { getLLMCredentials } from '@sre/LLMManager/LLM.service/LLMCredentials.helper';
+import { TLLMCredentials, TLLMModel, BasicCredentials } from '@sre/types/LLM.types';
+
+const DEFAULT_MODEL = 'gemini-embedding-001';
+
+export class GoogleEmbeds extends BaseEmbedding {
+    protected client: GoogleGenerativeAI;
+
+    public static models = ['gemini-embedding-001'];
+    public canSpecifyDimensions = true;
+
+    constructor(private settings?: Partial<TEmbeddings>) {
+        super({ model: settings?.model ?? DEFAULT_MODEL, ...settings });
+    }
+
+    async embedTexts(texts: string[], candidate: AccessCandidate): Promise<number[][]> {
+        const batches = this.chunkArr(this.processTexts(texts), this.chunkSize);
+
+        const batchRequests = batches.map((batch) => {
+            return this.embed(batch, candidate);
+        });
+        const batchResponses = await Promise.all(batchRequests);
+
+        const embeddings: number[][] = [];
+        for (let i = 0; i < batchResponses.length; i += 1) {
+            const batch = batches[i];
+            const batchResponse = batchResponses[i];
+            for (let j = 0; j < batch.length; j += 1) {
+                embeddings.push(batchResponse[j]);
+            }
+        }
+        return embeddings;
+    }
+
+    async embedText(text: string, candidate: AccessCandidate): Promise<number[]> {
+        const processedText = this.processTexts([text])[0];
+        const embeddings = await this.embed([processedText], candidate);
+        return embeddings[0];
+    }
+
+    protected async embed(texts: string[], candidate: AccessCandidate): Promise<number[][]> {
+        let apiKey: string | undefined;
+
+        // Try to get from credentials first
+        try {
+            const modelInfo: TLLMModel = {
+                provider: 'GoogleAI',
+                modelId: this.model,
+                credentials: this.settings?.credentials as unknown as TLLMCredentials,
+            };
+            const credentials = await getLLMCredentials(candidate, modelInfo);
+            apiKey = (credentials as BasicCredentials)?.apiKey;
+        } catch (e) {
+            // If credential system fails, fall back to environment variable
+        }
+
+        // Fall back to environment variable if not found in credentials
+        if (!apiKey) {
+            apiKey = process.env.GOOGLE_AI_API_KEY;
+        }
+
+        if (!apiKey) {
+            throw new Error('Please provide an API key for Google AI embeddings via credentials or GOOGLE_AI_API_KEY environment variable');
+        }
+
+        if (!this.client) {
+            this.client = new GoogleGenerativeAI(apiKey);
+        }
+
+        try {
+            const model = this.client.getGenerativeModel({ model: this.model });
+
+            const embeddings: number[][] = [];
+
+            for (const text of texts) {
+                const result = await model.embedContent(text);
+                if (result?.embedding?.values) {
+                    embeddings.push(result.embedding.values);
+                } else {
+                    throw new Error('Invalid embedding response from Google AI');
+                }
+            }
+
+            return embeddings;
+        } catch (e) {
+            throw new Error(`Google Embeddings API error: ${e.message || e}`);
+        }
+    }
+}
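
For orientation, a minimal sketch of how the new connector might be driven from surrounding SRE code, assuming the caller already has an AccessCandidate (its construction is outside this diff) and that the relative import path matches the file location:

// Hypothetical caller; `candidate` must be an AccessCandidate provided by the host subsystem.
import { AccessCandidate } from '@sre/Security/AccessControl/AccessCandidate.class';
import { GoogleEmbeds } from './GoogleEmbedding';

async function embedDocuments(candidate: AccessCandidate, docs: string[]): Promise<number[][]> {
    // Omitting `model` falls back to DEFAULT_MODEL ('gemini-embedding-001').
    const embedder = new GoogleEmbeds({ model: 'gemini-embedding-001' });

    // embedTexts() splits the inputs into chunks (chunkArr/chunkSize from BaseEmbedding),
    // fires the per-chunk requests in parallel, and flattens the results in order.
    return embedder.embedTexts(docs, candidate);
}

The API key is resolved inside embed(): first via getLLMCredentials for the candidate, then from the GOOGLE_AI_API_KEY environment variable.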

packages/core/src/subsystems/IO/VectorDB.service/embed/index.ts

Lines changed: 5 additions & 0 deletions
@@ -1,4 +1,5 @@
 import { OpenAIEmbeds } from './OpenAIEmbedding';
+import { GoogleEmbeds } from './GoogleEmbedding';
 import { TEmbeddings } from './BaseEmbedding';

 // a factory to get the correct embedding provider based on the provider name
@@ -7,6 +8,10 @@ const supportedProviders = {
         embedder: OpenAIEmbeds,
         models: OpenAIEmbeds.models,
     },
+    GoogleAI: {
+        embedder: GoogleEmbeds,
+        models: GoogleEmbeds.models,
+    },
 } as const;

 export type SupportedProviders = keyof typeof supportedProviders;
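
The hunk shows only the supportedProviders registry; the factory function that consumes it is not part of this diff. A hedged sketch of the kind of lookup such a factory could perform (getEmbedder is a hypothetical name, not necessarily the module's real export):

// Hypothetical helper, for illustration only; assumes both embedder classes
// accept a Partial<TEmbeddings> settings object in their constructors.
function getEmbedder(provider: SupportedProviders, settings?: Partial<TEmbeddings>) {
    const { embedder: Embedder } = supportedProviders[provider];
    return new Embedder(settings);
}

// Usage: getEmbedder('GoogleAI') would return a GoogleEmbeds instance using the default model.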
