Skip to content

Commit 494070f

Browse files
authored
chore: Add new session-level service for getting embeddings of a specific collection MCP-246 (#626)
1 parent fdedb02 commit 494070f

25 files changed

+852
-89
lines changed

src/common/config.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ const OPTIONS = {
5858
boolean: [
5959
"apiDeprecationErrors",
6060
"apiStrict",
61+
"disableEmbeddingsValidation",
6162
"help",
6263
"indexCheck",
6364
"ipv6",
@@ -183,6 +184,7 @@ export interface UserConfig extends CliOptions {
183184
maxBytesPerQuery: number;
184185
atlasTemporaryDatabaseUserLifetimeMs: number;
185186
voyageApiKey: string;
187+
disableEmbeddingsValidation: boolean;
186188
vectorSearchDimensions: number;
187189
vectorSearchSimilarityFunction: "cosine" | "euclidean" | "dotProduct";
188190
}
@@ -216,6 +218,7 @@ export const defaultUserConfig: UserConfig = {
216218
maxBytesPerQuery: 16 * 1024 * 1024, // By default, we only return ~16 mb of data per query / aggregation
217219
atlasTemporaryDatabaseUserLifetimeMs: 4 * 60 * 60 * 1000, // 4 hours
218220
voyageApiKey: "",
221+
disableEmbeddingsValidation: false,
219222
vectorSearchDimensions: 1024,
220223
vectorSearchSimilarityFunction: "euclidean",
221224
};

src/common/connectionManager.ts

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ export interface ConnectionState {
3232
connectedAtlasCluster?: AtlasClusterConnectionInfo;
3333
}
3434

35+
const MCP_TEST_DATABASE = "#mongodb-mcp";
3536
export class ConnectionStateConnected implements ConnectionState {
3637
public tag = "connected" as const;
3738

@@ -46,11 +47,11 @@ export class ConnectionStateConnected implements ConnectionState {
4647
public async isSearchSupported(): Promise<boolean> {
4748
if (this._isSearchSupported === undefined) {
4849
try {
49-
const dummyDatabase = "test";
50-
const dummyCollection = "test";
5150
// If a cluster supports search indexes, the call below will succeed
52-
// with a cursor otherwise will throw an Error
53-
await this.serviceProvider.getSearchIndexes(dummyDatabase, dummyCollection);
51+
// with a cursor otherwise will throw an Error.
52+
// the Search Index Management Service might not be ready yet, but
53+
// we assume that the agent can retry in that situation.
54+
await this.serviceProvider.getSearchIndexes(MCP_TEST_DATABASE, "test");
5455
this._isSearchSupported = true;
5556
} catch {
5657
this._isSearchSupported = false;

src/common/errors.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ export enum ErrorCodes {
33
MisconfiguredConnectionString = 1_000_001,
44
ForbiddenCollscan = 1_000_002,
55
ForbiddenWriteOperation = 1_000_003,
6+
AtlasSearchNotSupported = 1_000_004,
67
}
78

89
export class MongoDBError<ErrorCode extends ErrorCodes = ErrorCodes> extends Error {
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver";
2+
import { BSON, type Document } from "bson";
3+
import type { UserConfig } from "../config.js";
4+
import type { ConnectionManager } from "../connectionManager.js";
5+
6+
export type VectorFieldIndexDefinition = {
7+
type: "vector";
8+
path: string;
9+
numDimensions: number;
10+
quantization: "none" | "scalar" | "binary";
11+
similarity: "euclidean" | "cosine" | "dotProduct";
12+
};
13+
14+
export type EmbeddingNamespace = `${string}.${string}`;
15+
export class VectorSearchEmbeddingsManager {
16+
constructor(
17+
private readonly config: UserConfig,
18+
private readonly connectionManager: ConnectionManager,
19+
private readonly embeddings: Map<EmbeddingNamespace, VectorFieldIndexDefinition[]> = new Map()
20+
) {
21+
connectionManager.events.on("connection-close", () => {
22+
this.embeddings.clear();
23+
});
24+
}
25+
26+
cleanupEmbeddingsForNamespace({ database, collection }: { database: string; collection: string }): void {
27+
const embeddingDefKey: EmbeddingNamespace = `${database}.${collection}`;
28+
this.embeddings.delete(embeddingDefKey);
29+
}
30+
31+
async embeddingsForNamespace({
32+
database,
33+
collection,
34+
}: {
35+
database: string;
36+
collection: string;
37+
}): Promise<VectorFieldIndexDefinition[]> {
38+
const provider = await this.assertAtlasSearchIsAvailable();
39+
if (!provider) {
40+
return [];
41+
}
42+
43+
// We only need the embeddings for validation now, so don't query them if
44+
// validation is disabled.
45+
if (this.config.disableEmbeddingsValidation) {
46+
return [];
47+
}
48+
49+
const embeddingDefKey: EmbeddingNamespace = `${database}.${collection}`;
50+
const definition = this.embeddings.get(embeddingDefKey);
51+
52+
if (!definition) {
53+
const allSearchIndexes = await provider.getSearchIndexes(database, collection);
54+
const vectorSearchIndexes = allSearchIndexes.filter((index) => index.type === "vectorSearch");
55+
const vectorFields = vectorSearchIndexes
56+
// eslint-disable-next-line @typescript-eslint/no-unsafe-member-access
57+
.flatMap<Document>((index) => (index.latestDefinition?.fields as Document) ?? [])
58+
.filter((field) => this.isVectorFieldIndexDefinition(field));
59+
60+
this.embeddings.set(embeddingDefKey, vectorFields);
61+
return vectorFields;
62+
}
63+
64+
return definition;
65+
}
66+
67+
async findFieldsWithWrongEmbeddings(
68+
{
69+
database,
70+
collection,
71+
}: {
72+
database: string;
73+
collection: string;
74+
},
75+
document: Document
76+
): Promise<VectorFieldIndexDefinition[]> {
77+
const provider = await this.assertAtlasSearchIsAvailable();
78+
if (!provider) {
79+
return [];
80+
}
81+
82+
// While we can do our best effort to ensure that the embedding validation is correct
83+
// based on https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-quantization/
84+
// it's a complex process so we will also give the user the ability to disable this validation
85+
if (this.config.disableEmbeddingsValidation) {
86+
return [];
87+
}
88+
89+
const embeddings = await this.embeddingsForNamespace({ database, collection });
90+
return embeddings.filter((emb) => !this.documentPassesEmbeddingValidation(emb, document));
91+
}
92+
93+
private async assertAtlasSearchIsAvailable(): Promise<NodeDriverServiceProvider | null> {
94+
const connectionState = this.connectionManager.currentConnectionState;
95+
if (connectionState.tag === "connected") {
96+
if (await connectionState.isSearchSupported()) {
97+
return connectionState.serviceProvider;
98+
}
99+
}
100+
101+
return null;
102+
}
103+
104+
private isVectorFieldIndexDefinition(doc: Document): doc is VectorFieldIndexDefinition {
105+
return doc["type"] === "vector";
106+
}
107+
108+
private documentPassesEmbeddingValidation(definition: VectorFieldIndexDefinition, document: Document): boolean {
109+
const fieldPath = definition.path.split(".");
110+
let fieldRef: unknown = document;
111+
112+
for (const field of fieldPath) {
113+
if (fieldRef && typeof fieldRef === "object" && field in fieldRef) {
114+
fieldRef = (fieldRef as Record<string, unknown>)[field];
115+
} else {
116+
return true;
117+
}
118+
}
119+
120+
switch (definition.quantization) {
121+
// Because quantization is not defined by the user
122+
// we have to trust them in the format they use.
123+
case "none":
124+
return true;
125+
case "scalar":
126+
case "binary":
127+
if (fieldRef instanceof BSON.Binary) {
128+
try {
129+
const elements = fieldRef.toFloat32Array();
130+
return elements.length === definition.numDimensions;
131+
} catch {
132+
// bits are also supported
133+
try {
134+
const bits = fieldRef.toBits();
135+
return bits.length === definition.numDimensions;
136+
} catch {
137+
return false;
138+
}
139+
}
140+
} else {
141+
if (!Array.isArray(fieldRef)) {
142+
return false;
143+
}
144+
145+
if (fieldRef.length !== definition.numDimensions) {
146+
return false;
147+
}
148+
149+
if (!fieldRef.every((e) => this.isANumber(e))) {
150+
return false;
151+
}
152+
}
153+
154+
break;
155+
}
156+
157+
return true;
158+
}
159+
160+
private isANumber(value: unknown): boolean {
161+
if (typeof value === "number") {
162+
return true;
163+
}
164+
165+
if (
166+
value instanceof BSON.Int32 ||
167+
value instanceof BSON.Decimal128 ||
168+
value instanceof BSON.Double ||
169+
value instanceof BSON.Long
170+
) {
171+
return true;
172+
}
173+
174+
return false;
175+
}
176+
}

src/common/session.ts

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-d
1616
import { ErrorCodes, MongoDBError } from "./errors.js";
1717
import type { ExportsManager } from "./exportsManager.js";
1818
import type { Keychain } from "./keychain.js";
19+
import type { VectorSearchEmbeddingsManager } from "./search/vectorSearchEmbeddingsManager.js";
1920

2021
export interface SessionOptions {
2122
apiBaseUrl: string;
@@ -25,6 +26,7 @@ export interface SessionOptions {
2526
exportsManager: ExportsManager;
2627
connectionManager: ConnectionManager;
2728
keychain: Keychain;
29+
vectorSearchEmbeddingsManager: VectorSearchEmbeddingsManager;
2830
}
2931

3032
export type SessionEvents = {
@@ -40,6 +42,7 @@ export class Session extends EventEmitter<SessionEvents> {
4042
readonly connectionManager: ConnectionManager;
4143
readonly apiClient: ApiClient;
4244
readonly keychain: Keychain;
45+
readonly vectorSearchEmbeddingsManager: VectorSearchEmbeddingsManager;
4346

4447
mcpClient?: {
4548
name?: string;
@@ -57,6 +60,7 @@ export class Session extends EventEmitter<SessionEvents> {
5760
connectionManager,
5861
exportsManager,
5962
keychain,
63+
vectorSearchEmbeddingsManager,
6064
}: SessionOptions) {
6165
super();
6266

@@ -73,6 +77,7 @@ export class Session extends EventEmitter<SessionEvents> {
7377
this.apiClient = new ApiClient({ baseUrl: apiBaseUrl, credentials }, logger);
7478
this.exportsManager = exportsManager;
7579
this.connectionManager = connectionManager;
80+
this.vectorSearchEmbeddingsManager = vectorSearchEmbeddingsManager;
7681
this.connectionManager.events.on("connection-success", () => this.emit("connect"));
7782
this.connectionManager.events.on("connection-time-out", (error) => this.emit("connection-error", error));
7883
this.connectionManager.events.on("connection-close", () => this.emit("disconnect"));
@@ -141,13 +146,25 @@ export class Session extends EventEmitter<SessionEvents> {
141146
return this.connectionManager.currentConnectionState.tag === "connected";
142147
}
143148

144-
isSearchSupported(): Promise<boolean> {
149+
async isSearchSupported(): Promise<boolean> {
145150
const state = this.connectionManager.currentConnectionState;
146151
if (state.tag === "connected") {
147-
return state.isSearchSupported();
152+
return await state.isSearchSupported();
148153
}
149154

150-
return Promise.resolve(false);
155+
return false;
156+
}
157+
158+
async assertSearchSupported(): Promise<void> {
159+
const availability = await this.isSearchSupported();
160+
if (!availability) {
161+
throw new MongoDBError(
162+
ErrorCodes.AtlasSearchNotSupported,
163+
"Atlas Search is not supported in the current cluster."
164+
);
165+
}
166+
167+
return;
151168
}
152169

153170
get serviceProvider(): NodeDriverServiceProvider {

src/tools/mongodb/create/createIndex.ts

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import { z } from "zod";
22
import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
33
import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js";
4-
import type { ToolCategory } from "../../tool.js";
54
import { type ToolArgs, type OperationType, FeatureFlags } from "../../tool.js";
65
import type { IndexDirection } from "mongodb";
76

@@ -113,25 +112,7 @@ export class CreateIndexTool extends MongoDBToolBase {
113112
break;
114113
case "vectorSearch":
115114
{
116-
const isVectorSearchSupported = await this.session.isSearchSupported();
117-
if (!isVectorSearchSupported) {
118-
// TODO: remove hacky casts once we merge the local dev tools
119-
const isLocalAtlasAvailable =
120-
(this.server?.tools.filter((t) => t.category === ("atlas-local" as unknown as ToolCategory))
121-
.length ?? 0) > 0;
122-
123-
const CTA = isLocalAtlasAvailable ? "`atlas-local` tools" : "Atlas CLI";
124-
return {
125-
content: [
126-
{
127-
text: `The connected MongoDB deployment does not support vector search indexes. Either connect to a MongoDB Atlas cluster or use the ${CTA} to create and manage a local Atlas deployment.`,
128-
type: "text",
129-
},
130-
],
131-
isError: true,
132-
};
133-
}
134-
115+
await this.ensureSearchIsSupported();
135116
indexes = await provider.createSearchIndexes(database, collection, [
136117
{
137118
name,
@@ -144,6 +125,8 @@ export class CreateIndexTool extends MongoDBToolBase {
144125

145126
responseClarification =
146127
" Since this is a vector search index, it may take a while for the index to build. Use the `list-indexes` tool to check the index status.";
128+
// clean up the embeddings cache so it considers the new index
129+
this.session.vectorSearchEmbeddingsManager.cleanupEmbeddingsForNamespace({ database, collection });
147130
}
148131

149132
break;

0 commit comments

Comments
 (0)