Skip to content

Commit 94bea82

Browse files
committed
Handle multi target vectors in searches
1 parent 4f13fd7 commit 94bea82

File tree

11 files changed

+261
-153
lines changed

11 files changed

+261
-153
lines changed

src/collections/collection/index.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import query, { Query } from '../query/index.js';
1515
import sort, { Sort } from '../sort/index.js';
1616
import tenants, { TenantInput, Tenants } from '../tenants/index.js';
1717
import { QueryMetadata, QueryProperty, QueryReference } from '../types/index.js';
18+
import multiTargetVector, { MultiTargetVector } from '../vectors/multiTargetVector.js';
1819

1920
export interface Collection<T = undefined, N = string> {
2021
/** This namespace includes all the querying methods available to you when using Weaviate's standard aggregation capabilities. */
@@ -39,6 +40,8 @@ export interface Collection<T = undefined, N = string> {
3940
sort: Sort<T>;
4041
/** This namespace includes all the CRUD methods available to you when modifying the tenants of a multi-tenancy-enabled collection in Weaviate. */
4142
tenants: Tenants;
43+
/** This namespaces includes the methods by which you cna create the `MultiTargetVectorJoin` values for use when performing multi-target vector searches over your collection. */
44+
multiTargetVector: MultiTargetVector;
4245
/**
4346
* Use this method to check if the collection exists in Weaviate.
4447
*
@@ -116,6 +119,7 @@ const collection = <T, N>(
116119
filter: filter<T extends undefined ? any : T>(),
117120
generate: generate<T>(connection, capitalizedName, dbVersionSupport, consistencyLevel, tenant),
118121
metrics: metrics<T>(),
122+
multiTargetVector: multiTargetVector(),
119123
name: name,
120124
query: queryCollection,
121125
sort: sort<T>(),

src/collections/generate/index.ts

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,24 @@ class GenerateManager<T> implements Generate<T> {
7676
if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message(query));
7777
};
7878

79+
private checkSupportForMultiTargetVectorSearch = async (opts?: BaseNearOptions<T>) => {
80+
if (!Serialize.isMultiTargetVector(opts)) return false;
81+
const check = await this.dbVersionSupport.supportsMultiTargetVectorSearch();
82+
if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message);
83+
return check.supports;
84+
};
85+
86+
private nearSearch = async (opts?: BaseNearOptions<T>) => {
87+
const [_, supportsTargets] = await Promise.all([
88+
this.checkSupportForNamedVectors(opts),
89+
this.checkSupportForMultiTargetVectorSearch(opts),
90+
]);
91+
return {
92+
search: await this.connection.search(this.name, this.consistencyLevel, this.tenant),
93+
supportsTargets,
94+
};
95+
};
96+
7997
private async parseReply(reply: SearchReply) {
8098
const deserialize = await Deserialize.use(this.dbVersionSupport);
8199
return deserialize.generate<T>(reply);
@@ -150,7 +168,7 @@ class GenerateManager<T> implements Generate<T> {
150168
.then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant))
151169
.then((search) =>
152170
search.withHybrid({
153-
...Serialize.hybrid({ query, ...opts }),
171+
...Serialize.hybrid({ query, supportsTargets: false, ...opts }),
154172
generative: Serialize.generative(generate),
155173
groupBy: Serialize.isGroupBy<GroupByHybridOptions<T>>(opts)
156174
? Serialize.groupBy(opts.groupBy)
@@ -175,12 +193,11 @@ class GenerateManager<T> implements Generate<T> {
175193
generate: GenerateOptions<T>,
176194
opts?: NearOptions<T>
177195
): GenerateReturn<T> {
178-
return this.checkSupportForNamedVectors(opts)
179-
.then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant))
180-
.then((search) =>
196+
return this.nearSearch(opts)
197+
.then(({ search, supportsTargets }) =>
181198
toBase64FromMedia(image).then((image) =>
182199
search.withNearImage({
183-
...Serialize.nearImage({ image, ...(opts ? opts : {}) }),
200+
...Serialize.nearImage({ image, supportsTargets, ...(opts ? opts : {}) }),
184201
generative: Serialize.generative(generate),
185202
groupBy: Serialize.isGroupBy<GroupByNearOptions<T>>(opts)
186203
? Serialize.groupBy(opts.groupBy)
@@ -202,11 +219,10 @@ class GenerateManager<T> implements Generate<T> {
202219
opts: GroupByNearOptions<T>
203220
): Promise<GenerativeGroupByReturn<T>>;
204221
public nearObject(id: string, generate: GenerateOptions<T>, opts?: NearOptions<T>): GenerateReturn<T> {
205-
return this.checkSupportForNamedVectors(opts)
206-
.then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant))
207-
.then((search) =>
222+
return this.nearSearch(opts)
223+
.then(({ search, supportsTargets }) =>
208224
search.withNearObject({
209-
...Serialize.nearObject({ id, ...(opts ? opts : {}) }),
225+
...Serialize.nearObject({ id, supportsTargets, ...(opts ? opts : {}) }),
210226
generative: Serialize.generative(generate),
211227
groupBy: Serialize.isGroupBy<GroupByNearOptions<T>>(opts)
212228
? Serialize.groupBy(opts.groupBy)
@@ -231,11 +247,10 @@ class GenerateManager<T> implements Generate<T> {
231247
generate: GenerateOptions<T>,
232248
opts?: NearOptions<T>
233249
): GenerateReturn<T> {
234-
return this.checkSupportForNamedVectors(opts)
235-
.then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant))
236-
.then((search) =>
250+
return this.nearSearch(opts)
251+
.then(({ search, supportsTargets }) =>
237252
search.withNearText({
238-
...Serialize.nearText({ query, ...(opts ? opts : {}) }),
253+
...Serialize.nearText({ query, supportsTargets, ...(opts ? opts : {}) }),
239254
generative: Serialize.generative(generate),
240255
groupBy: Serialize.isGroupBy<GroupByNearOptions<T>>(opts)
241256
? Serialize.groupBy(opts.groupBy)
@@ -260,11 +275,10 @@ class GenerateManager<T> implements Generate<T> {
260275
generate: GenerateOptions<T>,
261276
opts?: NearOptions<T>
262277
): GenerateReturn<T> {
263-
return this.checkSupportForNamedVectors(opts)
264-
.then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant))
265-
.then((search) =>
278+
return this.nearSearch(opts)
279+
.then(({ search, supportsTargets }) =>
266280
search.withNearVector({
267-
...Serialize.nearVector({ vector, ...(opts ? opts : {}) }),
281+
...Serialize.nearVector({ vector, supportsTargets, ...(opts ? opts : {}) }),
268282
generative: Serialize.generative(generate),
269283
groupBy: Serialize.isGroupBy<GroupByNearOptions<T>>(opts)
270284
? Serialize.groupBy(opts.groupBy)
@@ -292,10 +306,10 @@ class GenerateManager<T> implements Generate<T> {
292306
generate: GenerateOptions<T>,
293307
opts?: NearOptions<T>
294308
): GenerateReturn<T> {
295-
return this.checkSupportForNamedVectors(opts)
296-
.then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant))
297-
.then((search) => {
309+
return this.nearSearch(opts)
310+
.then(({ search, supportsTargets }) => {
298311
let reply: Promise<SearchReply>;
312+
const args = { supportsTargets, ...(opts ? opts : {}) };
299313
const generative = Serialize.generative(generate);
300314
const groupBy = Serialize.isGroupBy<GroupByNearOptions<T>>(opts)
301315
? Serialize.groupBy(opts.groupBy)
@@ -304,7 +318,7 @@ class GenerateManager<T> implements Generate<T> {
304318
case 'audio':
305319
reply = toBase64FromMedia(media).then((media) =>
306320
search.withNearAudio({
307-
...Serialize.nearAudio({ audio: media, ...(opts ? opts : {}) }),
321+
...Serialize.nearAudio({ audio: media, ...args }),
308322
generative,
309323
groupBy,
310324
})
@@ -313,7 +327,7 @@ class GenerateManager<T> implements Generate<T> {
313327
case 'depth':
314328
reply = toBase64FromMedia(media).then((media) =>
315329
search.withNearDepth({
316-
...Serialize.nearDepth({ depth: media, ...(opts ? opts : {}) }),
330+
...Serialize.nearDepth({ depth: media, ...args }),
317331
generative,
318332
groupBy,
319333
})
@@ -322,7 +336,7 @@ class GenerateManager<T> implements Generate<T> {
322336
case 'image':
323337
reply = toBase64FromMedia(media).then((media) =>
324338
search.withNearImage({
325-
...Serialize.nearImage({ image: media, ...(opts ? opts : {}) }),
339+
...Serialize.nearImage({ image: media, ...args }),
326340
generative,
327341
groupBy,
328342
})
@@ -331,7 +345,7 @@ class GenerateManager<T> implements Generate<T> {
331345
case 'imu':
332346
reply = toBase64FromMedia(media).then((media) =>
333347
search.withNearIMU({
334-
...Serialize.nearIMU({ imu: media, ...(opts ? opts : {}) }),
348+
...Serialize.nearIMU({ imu: media, ...args }),
335349
generative,
336350
groupBy,
337351
})
@@ -340,7 +354,7 @@ class GenerateManager<T> implements Generate<T> {
340354
case 'thermal':
341355
reply = toBase64FromMedia(media).then((media) =>
342356
search.withNearThermal({
343-
...Serialize.nearThermal({ thermal: media, ...(opts ? opts : {}) }),
357+
...Serialize.nearThermal({ thermal: media, ...args }),
344358
generative,
345359
groupBy,
346360
})
@@ -349,7 +363,7 @@ class GenerateManager<T> implements Generate<T> {
349363
case 'video':
350364
reply = toBase64FromMedia(media).then((media) =>
351365
search.withNearVideo({
352-
...Serialize.nearVideo({ video: media, ...(opts ? opts : {}) }),
366+
...Serialize.nearVideo({ video: media, ...args }),
353367
generative,
354368
groupBy,
355369
})

src/collections/index.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,4 +306,4 @@ export * from './references/index.js';
306306
export * from './sort/index.js';
307307
export * from './tenants/index.js';
308308
export * from './types/index.js';
309-
export * from './vectors/index.js';
309+
export * from './vectors/multiTargetVector.js';

src/collections/query/index.ts

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,24 @@ class QueryManager<T> implements Query<T> {
8585
if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message);
8686
};
8787

88+
private checkSupportForMultiTargetVectorSearch = async (opts?: BaseNearOptions<T>) => {
89+
if (!Serialize.isMultiTargetVector(opts)) return false;
90+
const check = await this.dbVersionSupport.supportsMultiTargetVectorSearch();
91+
if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message);
92+
return check.supports;
93+
};
94+
95+
private nearSearch = async (opts?: BaseNearOptions<T>) => {
96+
const [_, supportsTargets] = await Promise.all([
97+
this.checkSupportForNamedVectors(opts),
98+
this.checkSupportForMultiTargetVectorSearch(opts),
99+
]);
100+
return {
101+
search: await this.connection.search(this.name, this.consistencyLevel, this.tenant),
102+
supportsTargets,
103+
};
104+
};
105+
88106
private async parseReply(reply: SearchReply) {
89107
const deserialize = await Deserialize.use(this.dbVersionSupport);
90108
return deserialize.query<T>(reply);
@@ -143,7 +161,7 @@ class QueryManager<T> implements Query<T> {
143161
.then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant))
144162
.then((search) =>
145163
search.withHybrid({
146-
...Serialize.hybrid({ query, ...opts }),
164+
...Serialize.hybrid({ query, supportsTargets: false, ...opts }),
147165
groupBy: Serialize.isGroupBy<GroupByHybridOptions<T>>(opts)
148166
? Serialize.groupBy(opts.groupBy)
149167
: undefined,
@@ -155,12 +173,11 @@ class QueryManager<T> implements Query<T> {
155173
public nearImage(image: string | Buffer, opts?: BaseNearOptions<T>): Promise<WeaviateReturn<T>>;
156174
public nearImage(image: string | Buffer, opts: GroupByNearOptions<T>): Promise<GroupByReturn<T>>;
157175
public nearImage(image: string | Buffer, opts?: NearOptions<T>): QueryReturn<T> {
158-
return this.checkSupportForNamedVectors(opts)
159-
.then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant))
160-
.then((search) => {
176+
return this.nearSearch(opts)
177+
.then(({ search, supportsTargets }) => {
161178
return toBase64FromMedia(image).then((image) =>
162179
search.withNearImage({
163-
...Serialize.nearImage({ image, ...(opts ? opts : {}) }),
180+
...Serialize.nearImage({ image, supportsTargets, ...(opts ? opts : {}) }),
164181
groupBy: Serialize.isGroupBy<GroupByNearOptions<T>>(opts)
165182
? Serialize.groupBy(opts.groupBy)
166183
: undefined,
@@ -181,39 +198,39 @@ class QueryManager<T> implements Query<T> {
181198
opts: GroupByNearOptions<T>
182199
): Promise<GroupByReturn<T>>;
183200
public nearMedia(media: string | Buffer, type: NearMediaType, opts?: NearOptions<T>): QueryReturn<T> {
184-
return this.checkSupportForNamedVectors(opts)
185-
.then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant))
186-
.then((search) => {
201+
return this.nearSearch(opts)
202+
.then(({ search, supportsTargets }) => {
203+
const args = { supportsTargets, ...(opts ? opts : {}) };
187204
let reply: Promise<SearchReply>;
188205
switch (type) {
189206
case 'audio':
190207
reply = toBase64FromMedia(media).then((media) =>
191-
search.withNearAudio(Serialize.nearAudio({ audio: media, ...(opts ? opts : {}) }))
208+
search.withNearAudio(Serialize.nearAudio({ audio: media, ...args }))
192209
);
193210
break;
194211
case 'depth':
195212
reply = toBase64FromMedia(media).then((media) =>
196-
search.withNearDepth(Serialize.nearDepth({ depth: media, ...(opts ? opts : {}) }))
213+
search.withNearDepth(Serialize.nearDepth({ depth: media, ...args }))
197214
);
198215
break;
199216
case 'image':
200217
reply = toBase64FromMedia(media).then((media) =>
201-
search.withNearImage(Serialize.nearImage({ image: media, ...(opts ? opts : {}) }))
218+
search.withNearImage(Serialize.nearImage({ image: media, ...args }))
202219
);
203220
break;
204221
case 'imu':
205222
reply = toBase64FromMedia(media).then((media) =>
206-
search.withNearIMU(Serialize.nearIMU({ imu: media, ...(opts ? opts : {}) }))
223+
search.withNearIMU(Serialize.nearIMU({ imu: media, ...args }))
207224
);
208225
break;
209226
case 'thermal':
210227
reply = toBase64FromMedia(media).then((media) =>
211-
search.withNearThermal(Serialize.nearThermal({ thermal: media, ...(opts ? opts : {}) }))
228+
search.withNearThermal(Serialize.nearThermal({ thermal: media, ...args }))
212229
);
213230
break;
214231
case 'video':
215232
reply = toBase64FromMedia(media).then((media) =>
216-
search.withNearVideo(Serialize.nearVideo({ video: media, ...(opts ? opts : {}) }))
233+
search.withNearVideo(Serialize.nearVideo({ video: media, ...args }))
217234
);
218235
break;
219236
default:
@@ -227,11 +244,10 @@ class QueryManager<T> implements Query<T> {
227244
public nearObject(id: string, opts?: BaseNearOptions<T>): Promise<WeaviateReturn<T>>;
228245
public nearObject(id: string, opts: GroupByNearOptions<T>): Promise<GroupByReturn<T>>;
229246
public nearObject(id: string, opts?: NearOptions<T>): QueryReturn<T> {
230-
return this.checkSupportForNamedVectors(opts)
231-
.then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant))
232-
.then((search) =>
247+
return this.nearSearch(opts)
248+
.then(({ search, supportsTargets }) =>
233249
search.withNearObject({
234-
...Serialize.nearObject({ id, ...(opts ? opts : {}) }),
250+
...Serialize.nearObject({ id, supportsTargets, ...(opts ? opts : {}) }),
235251
groupBy: Serialize.isGroupBy<GroupByNearOptions<T>>(opts)
236252
? Serialize.groupBy(opts.groupBy)
237253
: undefined,
@@ -243,11 +259,10 @@ class QueryManager<T> implements Query<T> {
243259
public nearText(query: string | string[], opts?: BaseNearTextOptions<T>): Promise<WeaviateReturn<T>>;
244260
public nearText(query: string | string[], opts: GroupByNearTextOptions<T>): Promise<GroupByReturn<T>>;
245261
public nearText(query: string | string[], opts?: NearTextOptions<T>): QueryReturn<T> {
246-
return this.checkSupportForNamedVectors(opts)
247-
.then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant))
248-
.then((search) =>
262+
return this.nearSearch(opts)
263+
.then(({ search, supportsTargets }) =>
249264
search.withNearText({
250-
...Serialize.nearText({ query, ...(opts ? opts : {}) }),
265+
...Serialize.nearText({ query, supportsTargets, ...(opts ? opts : {}) }),
251266
groupBy: Serialize.isGroupBy<GroupByNearOptions<T>>(opts)
252267
? Serialize.groupBy(opts.groupBy)
253268
: undefined,
@@ -259,11 +274,10 @@ class QueryManager<T> implements Query<T> {
259274
public nearVector(vector: NearVectorInputType, opts?: BaseNearOptions<T>): Promise<WeaviateReturn<T>>;
260275
public nearVector(vector: NearVectorInputType, opts: GroupByNearOptions<T>): Promise<GroupByReturn<T>>;
261276
public nearVector(vector: NearVectorInputType, opts?: NearOptions<T>): QueryReturn<T> {
262-
return this.checkSupportForNamedVectors(opts)
263-
.then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant))
264-
.then((search) =>
277+
return this.nearSearch(opts)
278+
.then(({ search, supportsTargets }) =>
265279
search.withNearVector({
266-
...Serialize.nearVector({ vector, ...(opts ? opts : {}) }),
280+
...Serialize.nearVector({ vector, supportsTargets, ...(opts ? opts : {}) }),
267281
groupBy: Serialize.isGroupBy<GroupByNearOptions<T>>(opts)
268282
? Serialize.groupBy(opts.groupBy)
269283
: undefined,

0 commit comments

Comments
 (0)