Skip to content

Commit 2c992c5

Browse files
authored
Merge pull request #163 from weaviate/dev/support-multi-vector-search
DevV3/support multi vector search
2 parents 04f9235 + e1deecf commit 2c992c5

File tree

14 files changed

+762
-155
lines changed

14 files changed

+762
-155
lines changed

src/collections/collection/index.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import query, { Query } from '../query/index.js';
1515
import sort, { Sort } from '../sort/index.js';
1616
import tenants, { TenantBase, Tenants } from '../tenants/index.js';
1717
import { QueryMetadata, QueryProperty, QueryReference } from '../types/index.js';
18+
import multiTargetVector, { MultiTargetVector } from '../vectors/multiTargetVector.js';
1819

1920
export interface Collection<T = undefined, N = string> {
2021
/** This namespace includes all the querying methods available to you when using Weaviate's standard aggregation capabilities. */
@@ -39,6 +40,8 @@ export interface Collection<T = undefined, N = string> {
3940
sort: Sort<T>;
4041
/** This namespace includes all the CRUD methods available to you when modifying the tenants of a multi-tenancy-enabled collection in Weaviate. */
4142
tenants: Tenants;
43+
/** This namespaces includes the methods by which you cna create the `MultiTargetVectorJoin` values for use when performing multi-target vector searches over your collection. */
44+
multiTargetVector: MultiTargetVector;
4245
/**
4346
* Use this method to check if the collection exists in Weaviate.
4447
*
@@ -117,6 +120,7 @@ const collection = <T, N>(
117120
filter: filter<T extends undefined ? any : T>(),
118121
generate: generate<T>(connection, capitalizedName, dbVersionSupport, consistencyLevel, tenant),
119122
metrics: metrics<T>(),
123+
multiTargetVector: multiTargetVector(),
120124
name: name,
121125
query: queryCollection,
122126
sort: sort<T>(),

src/collections/generate/index.ts

Lines changed: 61 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,43 @@ class GenerateManager<T> implements Generate<T> {
7676
if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message(query));
7777
};
7878

79+
private checkSupportForHybridNearTextAndNearVectorSubSearches = async (opts?: HybridOptions<T>) => {
80+
if (opts?.vector === undefined || Array.isArray(opts.vector)) return;
81+
const check = await this.dbVersionSupport.supportsHybridNearTextAndNearVectorSubsearchQueries();
82+
if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message);
83+
};
84+
85+
private checkSupportForMultiTargetVectorSearch = async (opts?: BaseNearOptions<T>) => {
86+
if (!Serialize.isMultiTargetVector(opts)) return false;
87+
const check = await this.dbVersionSupport.supportsMultiTargetVectorSearch();
88+
if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message);
89+
return check.supports;
90+
};
91+
92+
private nearSearch = async (opts?: BaseNearOptions<T>) => {
93+
const [_, supportsTargets] = await Promise.all([
94+
this.checkSupportForNamedVectors(opts),
95+
this.checkSupportForMultiTargetVectorSearch(opts),
96+
]);
97+
return {
98+
search: await this.connection.search(this.name, this.consistencyLevel, this.tenant),
99+
supportsTargets,
100+
};
101+
};
102+
103+
private hybridSearch = async (opts?: BaseHybridOptions<T>) => {
104+
const [supportsTargets] = await Promise.all([
105+
this.checkSupportForMultiTargetVectorSearch(opts),
106+
this.checkSupportForNamedVectors(opts),
107+
this.checkSupportForBm25AndHybridGroupByQueries('Hybrid', opts),
108+
this.checkSupportForHybridNearTextAndNearVectorSubSearches(opts),
109+
]);
110+
return {
111+
search: await this.connection.search(this.name, this.consistencyLevel, this.tenant),
112+
supportsTargets,
113+
};
114+
};
115+
79116
private async parseReply(reply: SearchReply) {
80117
const deserialize = await Deserialize.use(this.dbVersionSupport);
81118
return deserialize.generate<T>(reply);
@@ -143,14 +180,10 @@ class GenerateManager<T> implements Generate<T> {
143180
opts: GroupByHybridOptions<T>
144181
): Promise<GenerativeGroupByReturn<T>>;
145182
public hybrid(query: string, generate: GenerateOptions<T>, opts?: HybridOptions<T>): GenerateReturn<T> {
146-
return Promise.all([
147-
this.checkSupportForNamedVectors(opts),
148-
this.checkSupportForBm25AndHybridGroupByQueries('Bm25', opts),
149-
])
150-
.then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant))
151-
.then((search) =>
183+
return this.hybridSearch(opts)
184+
.then(({ search, supportsTargets }) =>
152185
search.withHybrid({
153-
...Serialize.hybrid({ query, ...opts }),
186+
...Serialize.hybrid({ query, supportsTargets, ...opts }),
154187
generative: Serialize.generative(generate),
155188
groupBy: Serialize.isGroupBy<GroupByHybridOptions<T>>(opts)
156189
? Serialize.groupBy(opts.groupBy)
@@ -175,12 +208,11 @@ class GenerateManager<T> implements Generate<T> {
175208
generate: GenerateOptions<T>,
176209
opts?: NearOptions<T>
177210
): GenerateReturn<T> {
178-
return this.checkSupportForNamedVectors(opts)
179-
.then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant))
180-
.then((search) =>
211+
return this.nearSearch(opts)
212+
.then(({ search, supportsTargets }) =>
181213
toBase64FromMedia(image).then((image) =>
182214
search.withNearImage({
183-
...Serialize.nearImage({ image, ...(opts ? opts : {}) }),
215+
...Serialize.nearImage({ image, supportsTargets, ...(opts ? opts : {}) }),
184216
generative: Serialize.generative(generate),
185217
groupBy: Serialize.isGroupBy<GroupByNearOptions<T>>(opts)
186218
? Serialize.groupBy(opts.groupBy)
@@ -202,11 +234,10 @@ class GenerateManager<T> implements Generate<T> {
202234
opts: GroupByNearOptions<T>
203235
): Promise<GenerativeGroupByReturn<T>>;
204236
public nearObject(id: string, generate: GenerateOptions<T>, opts?: NearOptions<T>): GenerateReturn<T> {
205-
return this.checkSupportForNamedVectors(opts)
206-
.then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant))
207-
.then((search) =>
237+
return this.nearSearch(opts)
238+
.then(({ search, supportsTargets }) =>
208239
search.withNearObject({
209-
...Serialize.nearObject({ id, ...(opts ? opts : {}) }),
240+
...Serialize.nearObject({ id, supportsTargets, ...(opts ? opts : {}) }),
210241
generative: Serialize.generative(generate),
211242
groupBy: Serialize.isGroupBy<GroupByNearOptions<T>>(opts)
212243
? Serialize.groupBy(opts.groupBy)
@@ -231,11 +262,10 @@ class GenerateManager<T> implements Generate<T> {
231262
generate: GenerateOptions<T>,
232263
opts?: NearOptions<T>
233264
): GenerateReturn<T> {
234-
return this.checkSupportForNamedVectors(opts)
235-
.then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant))
236-
.then((search) =>
265+
return this.nearSearch(opts)
266+
.then(({ search, supportsTargets }) =>
237267
search.withNearText({
238-
...Serialize.nearText({ query, ...(opts ? opts : {}) }),
268+
...Serialize.nearText({ query, supportsTargets, ...(opts ? opts : {}) }),
239269
generative: Serialize.generative(generate),
240270
groupBy: Serialize.isGroupBy<GroupByNearOptions<T>>(opts)
241271
? Serialize.groupBy(opts.groupBy)
@@ -260,11 +290,10 @@ class GenerateManager<T> implements Generate<T> {
260290
generate: GenerateOptions<T>,
261291
opts?: NearOptions<T>
262292
): GenerateReturn<T> {
263-
return this.checkSupportForNamedVectors(opts)
264-
.then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant))
265-
.then((search) =>
293+
return this.nearSearch(opts)
294+
.then(({ search, supportsTargets }) =>
266295
search.withNearVector({
267-
...Serialize.nearVector({ vector, ...(opts ? opts : {}) }),
296+
...Serialize.nearVector({ vector, supportsTargets, ...(opts ? opts : {}) }),
268297
generative: Serialize.generative(generate),
269298
groupBy: Serialize.isGroupBy<GroupByNearOptions<T>>(opts)
270299
? Serialize.groupBy(opts.groupBy)
@@ -292,10 +321,10 @@ class GenerateManager<T> implements Generate<T> {
292321
generate: GenerateOptions<T>,
293322
opts?: NearOptions<T>
294323
): GenerateReturn<T> {
295-
return this.checkSupportForNamedVectors(opts)
296-
.then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant))
297-
.then((search) => {
324+
return this.nearSearch(opts)
325+
.then(({ search, supportsTargets }) => {
298326
let reply: Promise<SearchReply>;
327+
const args = { supportsTargets, ...(opts ? opts : {}) };
299328
const generative = Serialize.generative(generate);
300329
const groupBy = Serialize.isGroupBy<GroupByNearOptions<T>>(opts)
301330
? Serialize.groupBy(opts.groupBy)
@@ -304,7 +333,7 @@ class GenerateManager<T> implements Generate<T> {
304333
case 'audio':
305334
reply = toBase64FromMedia(media).then((media) =>
306335
search.withNearAudio({
307-
...Serialize.nearAudio({ audio: media, ...(opts ? opts : {}) }),
336+
...Serialize.nearAudio({ audio: media, ...args }),
308337
generative,
309338
groupBy,
310339
})
@@ -313,7 +342,7 @@ class GenerateManager<T> implements Generate<T> {
313342
case 'depth':
314343
reply = toBase64FromMedia(media).then((media) =>
315344
search.withNearDepth({
316-
...Serialize.nearDepth({ depth: media, ...(opts ? opts : {}) }),
345+
...Serialize.nearDepth({ depth: media, ...args }),
317346
generative,
318347
groupBy,
319348
})
@@ -322,7 +351,7 @@ class GenerateManager<T> implements Generate<T> {
322351
case 'image':
323352
reply = toBase64FromMedia(media).then((media) =>
324353
search.withNearImage({
325-
...Serialize.nearImage({ image: media, ...(opts ? opts : {}) }),
354+
...Serialize.nearImage({ image: media, ...args }),
326355
generative,
327356
groupBy,
328357
})
@@ -331,7 +360,7 @@ class GenerateManager<T> implements Generate<T> {
331360
case 'imu':
332361
reply = toBase64FromMedia(media).then((media) =>
333362
search.withNearIMU({
334-
...Serialize.nearIMU({ imu: media, ...(opts ? opts : {}) }),
363+
...Serialize.nearIMU({ imu: media, ...args }),
335364
generative,
336365
groupBy,
337366
})
@@ -340,7 +369,7 @@ class GenerateManager<T> implements Generate<T> {
340369
case 'thermal':
341370
reply = toBase64FromMedia(media).then((media) =>
342371
search.withNearThermal({
343-
...Serialize.nearThermal({ thermal: media, ...(opts ? opts : {}) }),
372+
...Serialize.nearThermal({ thermal: media, ...args }),
344373
generative,
345374
groupBy,
346375
})
@@ -349,7 +378,7 @@ class GenerateManager<T> implements Generate<T> {
349378
case 'video':
350379
reply = toBase64FromMedia(media).then((media) =>
351380
search.withNearVideo({
352-
...Serialize.nearVideo({ video: media, ...(opts ? opts : {}) }),
381+
...Serialize.nearVideo({ video: media, ...args }),
353382
generative,
354383
groupBy,
355384
})

src/collections/generate/integration.test.ts

Lines changed: 124 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,15 @@ import { GenerateOptions, GroupByOptions } from '../types/index.js';
77

88
const maybe = process.env.OPENAI_APIKEY ? describe : describe.skip;
99

10+
const makeOpenAIClient = () =>
11+
weaviate.connectToLocal({
12+
port: 8086,
13+
grpcPort: 50057,
14+
headers: {
15+
'X-Openai-Api-Key': process.env.OPENAI_APIKEY!,
16+
},
17+
});
18+
1019
maybe('Testing of the collection.generate methods with a simple collection', () => {
1120
let client: WeaviateClient;
1221
let collection: Collection<TestCollectionGenerateSimple, 'TestCollectionGenerateSimple'>;
@@ -32,13 +41,7 @@ maybe('Testing of the collection.generate methods with a simple collection', ()
3241
});
3342

3443
beforeAll(async () => {
35-
client = await weaviate.connectToLocal({
36-
port: 8086,
37-
grpcPort: 50057,
38-
headers: {
39-
'X-Openai-Api-Key': process.env.OPENAI_APIKEY!,
40-
},
41-
});
44+
client = await makeOpenAIClient();
4245
collection = client.collections.get(collectionName);
4346
id = await client.collections
4447
.create({
@@ -179,13 +182,7 @@ maybe('Testing of the groupBy collection.generate methods with a simple collecti
179182
});
180183

181184
beforeAll(async () => {
182-
client = await weaviate.connectToLocal({
183-
port: 8086,
184-
grpcPort: 50057,
185-
headers: {
186-
'X-Openai-Api-Key': process.env.OPENAI_APIKEY!,
187-
},
188-
});
185+
client = await makeOpenAIClient();
189186
collection = client.collections.get(collectionName);
190187
id = await client.collections
191188
.create({
@@ -314,3 +311,116 @@ maybe('Testing of the groupBy collection.generate methods with a simple collecti
314311
expect(ret.objects[0].belongsToGroup).toEqual('test');
315312
});
316313
});
314+
315+
maybe('Testing of the collection.generate methods with a multi vector collection', () => {
316+
let client: WeaviateClient;
317+
let collection: Collection;
318+
const collectionName = 'TestCollectionQueryWithMultiVector';
319+
320+
let id1: string;
321+
let id2: string;
322+
let titleVector: number[];
323+
let title2Vector: number[];
324+
325+
afterAll(() => {
326+
return client.collections.delete(collectionName).catch((err) => {
327+
console.error(err);
328+
throw err;
329+
});
330+
});
331+
332+
beforeAll(async () => {
333+
client = await makeOpenAIClient();
334+
collection = client.collections.get(collectionName);
335+
const query = () =>
336+
client.collections
337+
.create({
338+
name: collectionName,
339+
properties: [
340+
{
341+
name: 'title',
342+
dataType: 'text',
343+
vectorizePropertyName: false,
344+
},
345+
],
346+
vectorizers: [
347+
weaviate.configure.vectorizer.text2VecOpenAI({
348+
name: 'title',
349+
sourceProperties: ['title'],
350+
}),
351+
weaviate.configure.vectorizer.text2VecOpenAI({
352+
name: 'title2',
353+
sourceProperties: ['title'],
354+
}),
355+
],
356+
})
357+
.then(async () => {
358+
id1 = await collection.data.insert({
359+
properties: {
360+
title: 'test',
361+
},
362+
});
363+
id2 = await collection.data.insert({
364+
properties: {
365+
title: 'other',
366+
},
367+
});
368+
const res = await collection.query.fetchObjectById(id1, { includeVector: true });
369+
titleVector = res!.vectors.title!;
370+
title2Vector = res!.vectors.title2!;
371+
});
372+
if (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 24, 0))) {
373+
await expect(query()).rejects.toThrow(WeaviateUnsupportedFeatureError);
374+
return;
375+
}
376+
return query();
377+
});
378+
379+
it('should generate with a near vector search on multi vectors', async () => {
380+
const query = () =>
381+
collection.generate.nearVector(
382+
[titleVector, title2Vector],
383+
{
384+
groupedTask: 'What is the value of title here? {title}',
385+
groupedProperties: ['title'],
386+
singlePrompt: 'Write a haiku about ducks for {title}',
387+
},
388+
{
389+
targetVector: ['title', 'title2'],
390+
}
391+
);
392+
if (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 26, 0))) {
393+
await expect(query()).rejects.toThrow(WeaviateUnsupportedFeatureError);
394+
return;
395+
}
396+
const ret = await query();
397+
expect(ret.objects.length).toEqual(2);
398+
expect(ret.generated).toBeDefined();
399+
expect(ret.objects[0].generated).toBeDefined();
400+
expect(ret.objects[1].generated).toBeDefined();
401+
});
402+
403+
it('should generate with a near vector search on multi vectors', async () => {
404+
const query = () =>
405+
collection.generate.nearVector(
406+
{ title: titleVector, title2: title2Vector },
407+
{
408+
groupedTask: 'What is the value of title here? {title}',
409+
groupedProperties: ['title'],
410+
singlePrompt: 'Write a haiku about ducks for {title}',
411+
},
412+
{
413+
targetVector: ['title', 'title2'],
414+
}
415+
);
416+
if (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 26, 0))) {
417+
await expect(query()).rejects.toThrow(WeaviateUnsupportedFeatureError);
418+
return;
419+
}
420+
const ret = await query();
421+
expect(ret.objects.length).toEqual(2);
422+
expect(ret.generated).toBeDefined();
423+
expect(ret.objects[0].generated).toBeDefined();
424+
expect(ret.objects[1].generated).toBeDefined();
425+
});
426+
});

0 commit comments

Comments
 (0)