Skip to content

Commit 92dbaae

Browse files
committed
Fix BC for queries, add generative support, at hybrid support, use latest preview image
1 parent 4f6c048 commit 92dbaae

File tree

8 files changed

+389
-124
lines changed

8 files changed

+389
-124
lines changed

.github/workflows/main.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ on:
99
env:
1010
WEAVIATE_124: 1.24.19
1111
WEAVIATE_125: 1.25.5
12-
WEAVIATE_126: preview--4e2eb3a
12+
WEAVIATE_126: preview--c2bfc40
1313

1414
jobs:
1515
checks:

src/collections/generate/index.ts

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,12 @@ class GenerateManager<T> implements Generate<T> {
7676
if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message(query));
7777
};
7878

79+
private checkSupportForHybridNearTextAndNearVectorSubSearches = async (opts?: HybridOptions<T>) => {
80+
if (opts?.vector === undefined || Array.isArray(opts.vector)) return;
81+
const check = await this.dbVersionSupport.supportsHybridNearTextAndNearVectorSubsearchQueries();
82+
if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message);
83+
};
84+
7985
private checkSupportForMultiTargetVectorSearch = async (opts?: BaseNearOptions<T>) => {
8086
if (!Serialize.isMultiTargetVector(opts)) return false;
8187
const check = await this.dbVersionSupport.supportsMultiTargetVectorSearch();
@@ -94,6 +100,19 @@ class GenerateManager<T> implements Generate<T> {
94100
};
95101
};
96102

103+
private hybridSearch = async (opts?: BaseHybridOptions<T>) => {
104+
const [supportsTargets] = await Promise.all([
105+
this.checkSupportForMultiTargetVectorSearch(opts),
106+
this.checkSupportForNamedVectors(opts),
107+
this.checkSupportForBm25AndHybridGroupByQueries('Hybrid', opts),
108+
this.checkSupportForHybridNearTextAndNearVectorSubSearches(opts),
109+
]);
110+
return {
111+
search: await this.connection.search(this.name, this.consistencyLevel, this.tenant),
112+
supportsTargets,
113+
};
114+
};
115+
97116
private async parseReply(reply: SearchReply) {
98117
const deserialize = await Deserialize.use(this.dbVersionSupport);
99118
return deserialize.generate<T>(reply);
@@ -161,14 +180,10 @@ class GenerateManager<T> implements Generate<T> {
161180
opts: GroupByHybridOptions<T>
162181
): Promise<GenerativeGroupByReturn<T>>;
163182
public hybrid(query: string, generate: GenerateOptions<T>, opts?: HybridOptions<T>): GenerateReturn<T> {
164-
return Promise.all([
165-
this.checkSupportForNamedVectors(opts),
166-
this.checkSupportForBm25AndHybridGroupByQueries('Bm25', opts),
167-
])
168-
.then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant))
169-
.then((search) =>
183+
return this.hybridSearch(opts)
184+
.then(({ search, supportsTargets }) =>
170185
search.withHybrid({
171-
...Serialize.hybrid({ query, supportsTargets: false, ...opts }),
186+
...Serialize.hybrid({ query, supportsTargets, ...opts }),
172187
generative: Serialize.generative(generate),
173188
groupBy: Serialize.isGroupBy<GroupByHybridOptions<T>>(opts)
174189
? Serialize.groupBy(opts.groupBy)

src/collections/generate/integration.test.ts

Lines changed: 124 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,15 @@ import { GenerateOptions, GroupByOptions } from '../types/index.js';
77

88
const maybe = process.env.OPENAI_APIKEY ? describe : describe.skip;
99

10+
const makeOpenAIClient = () =>
11+
weaviate.connectToLocal({
12+
port: 8086,
13+
grpcPort: 50057,
14+
headers: {
15+
'X-Openai-Api-Key': process.env.OPENAI_APIKEY!,
16+
},
17+
});
18+
1019
maybe('Testing of the collection.generate methods with a simple collection', () => {
1120
let client: WeaviateClient;
1221
let collection: Collection<TestCollectionGenerateSimple, 'TestCollectionGenerateSimple'>;
@@ -32,13 +41,7 @@ maybe('Testing of the collection.generate methods with a simple collection', ()
3241
});
3342

3443
beforeAll(async () => {
35-
client = await weaviate.connectToLocal({
36-
port: 8086,
37-
grpcPort: 50057,
38-
headers: {
39-
'X-Openai-Api-Key': process.env.OPENAI_APIKEY!,
40-
},
41-
});
44+
client = await makeOpenAIClient();
4245
collection = client.collections.get(collectionName);
4346
id = await client.collections
4447
.create({
@@ -179,13 +182,7 @@ maybe('Testing of the groupBy collection.generate methods with a simple collecti
179182
});
180183

181184
beforeAll(async () => {
182-
client = await weaviate.connectToLocal({
183-
port: 8086,
184-
grpcPort: 50057,
185-
headers: {
186-
'X-Openai-Api-Key': process.env.OPENAI_APIKEY!,
187-
},
188-
});
185+
client = await makeOpenAIClient();
189186
collection = client.collections.get(collectionName);
190187
id = await client.collections
191188
.create({
@@ -314,3 +311,116 @@ maybe('Testing of the groupBy collection.generate methods with a simple collecti
314311
expect(ret.objects[0].belongsToGroup).toEqual('test');
315312
});
316313
});
314+
315+
maybe('Testing of the collection.generate methods with a multi vector collection', () => {
316+
let client: WeaviateClient;
317+
let collection: Collection;
318+
const collectionName = 'TestCollectionQueryWithMultiVector';
319+
320+
let id1: string;
321+
let id2: string;
322+
let titleVector: number[];
323+
let title2Vector: number[];
324+
325+
afterAll(() => {
326+
return client.collections.delete(collectionName).catch((err) => {
327+
console.error(err);
328+
throw err;
329+
});
330+
});
331+
332+
beforeAll(async () => {
333+
client = await makeOpenAIClient();
334+
collection = client.collections.get(collectionName);
335+
const query = () =>
336+
client.collections
337+
.create({
338+
name: collectionName,
339+
properties: [
340+
{
341+
name: 'title',
342+
dataType: 'text',
343+
vectorizePropertyName: false,
344+
},
345+
],
346+
vectorizers: [
347+
weaviate.configure.vectorizer.text2VecOpenAI({
348+
name: 'title',
349+
sourceProperties: ['title'],
350+
}),
351+
weaviate.configure.vectorizer.text2VecOpenAI({
352+
name: 'title2',
353+
sourceProperties: ['title'],
354+
}),
355+
],
356+
})
357+
.then(async () => {
358+
id1 = await collection.data.insert({
359+
properties: {
360+
title: 'test',
361+
},
362+
});
363+
id2 = await collection.data.insert({
364+
properties: {
365+
title: 'other',
366+
},
367+
});
368+
const res = await collection.query.fetchObjectById(id1, { includeVector: true });
369+
titleVector = res!.vectors.title!;
370+
title2Vector = res!.vectors.title2!;
371+
});
372+
if (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 24, 0))) {
373+
await expect(query()).rejects.toThrow(WeaviateUnsupportedFeatureError);
374+
return;
375+
}
376+
return query();
377+
});
378+
379+
it('should generate with a near vector search on multi vectors', async () => {
380+
const query = () =>
381+
collection.generate.nearVector(
382+
[titleVector, title2Vector],
383+
{
384+
groupedTask: 'What is the value of title here? {title}',
385+
groupedProperties: ['title'],
386+
singlePrompt: 'Write a haiku about ducks for {title}',
387+
},
388+
{
389+
targetVector: ['title', 'title2'],
390+
}
391+
);
392+
if (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 26, 0))) {
393+
await expect(query()).rejects.toThrow(WeaviateUnsupportedFeatureError);
394+
return;
395+
}
396+
const ret = await query();
397+
expect(ret.objects.length).toEqual(2);
398+
expect(ret.generated).toBeDefined();
399+
expect(ret.objects[0].generated).toBeDefined();
400+
expect(ret.objects[1].generated).toBeDefined();
401+
});
402+
403+
it('should generate with a near vector search on multi vectors', async () => {
404+
const query = () =>
405+
collection.generate.nearVector(
406+
{ title: titleVector, title2: title2Vector },
407+
{
408+
groupedTask: 'What is the value of title here? {title}',
409+
groupedProperties: ['title'],
410+
singlePrompt: 'Write a haiku about ducks for {title}',
411+
},
412+
{
413+
targetVector: ['title', 'title2'],
414+
}
415+
);
416+
if (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 26, 0))) {
417+
await expect(query()).rejects.toThrow(WeaviateUnsupportedFeatureError);
418+
return;
419+
}
420+
const ret = await query();
421+
expect(ret.objects.length).toEqual(2);
422+
expect(ret.generated).toBeDefined();
423+
expect(ret.objects[0].generated).toBeDefined();
424+
expect(ret.objects[1].generated).toBeDefined();
425+
});
426+
});

src/collections/generate/types.ts

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import {
1313
NearMediaType,
1414
NearOptions,
1515
NearTextOptions,
16+
NearVectorInputType,
1617
} from '../query/types.js';
1718
import {
1819
GenerateOptions,
@@ -301,13 +302,13 @@ interface NearVector<T> {
301302
*
302303
* This overload is for performing a search without the `groupBy` param.
303304
*
304-
* @param {number[]} vector - The vector to search for.
305+
* @param {NearVectorInputType} vector - The vector(s) to search for.
305306
* @param {GenerateOptions<T>} generate - The available options for performing the generation.
306307
* @param {BaseNearOptions<T>} [opts] - The available options for performing the near-vector search.
307308
* @return {Promise<GenerativeReturn<T>>} - The results of the search including the generated data.
308309
*/
309310
nearVector(
310-
vector: number[],
311+
vector: NearVectorInputType,
311312
generate: GenerateOptions<T>,
312313
opts?: BaseNearOptions<T>
313314
): Promise<GenerativeReturn<T>>;
@@ -318,13 +319,13 @@ interface NearVector<T> {
318319
*
319320
* This overload is for performing a search with the `groupBy` param.
320321
*
321-
* @param {number[]} vector - The vector to search for.
322+
* @param {NearVectorInputType} vector - The vector(s) to search for.
322323
* @param {GenerateOptions<T>} generate - The available options for performing the generation.
323324
* @param {GroupByNearOptions<T>} opts - The available options for performing the near-vector search.
324325
* @return {Promise<GenerativeGroupByReturn<T>>} - The results of the search including the generated data grouped by the specified properties.
325326
*/
326327
nearVector(
327-
vector: number[],
328+
vector: NearVectorInputType,
328329
generate: GenerateOptions<T>,
329330
opts: GroupByNearOptions<T>
330331
): Promise<GenerativeGroupByReturn<T>>;
@@ -335,12 +336,16 @@ interface NearVector<T> {
335336
*
336337
* This overload is for performing a search with a programmatically defined `opts` param.
337338
*
338-
* @param {number[]} vector - The vector to search for.
339+
* @param {NearVectorInputType} vector - The vector(s) to search for.
339340
* @param {GenerateOptions<T>} generate - The available options for performing the generation.
340341
* @param {NearOptions<T>} [opts] - The available options for performing the near-vector search.
341342
* @return {GenerateReturn<T>} - The results of the search including the generated data.
342343
*/
343-
nearVector(vector: number[], generate: GenerateOptions<T>, opts?: NearOptions<T>): GenerateReturn<T>;
344+
nearVector(
345+
vector: NearVectorInputType,
346+
generate: GenerateOptions<T>,
347+
opts?: NearOptions<T>
348+
): GenerateReturn<T>;
344349
}
345350

346351
export interface Generate<T>

src/collections/query/index.ts

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,22 @@ class QueryManager<T> implements Query<T> {
9393
};
9494

9595
private nearSearch = async (opts?: BaseNearOptions<T>) => {
96-
const [_, supportsTargets] = await Promise.all([
96+
const [supportsTargets] = await Promise.all([
97+
this.checkSupportForMultiTargetVectorSearch(opts),
9798
this.checkSupportForNamedVectors(opts),
99+
]);
100+
return {
101+
search: await this.connection.search(this.name, this.consistencyLevel, this.tenant),
102+
supportsTargets,
103+
};
104+
};
105+
106+
private hybridSearch = async (opts?: BaseHybridOptions<T>) => {
107+
const [supportsTargets] = await Promise.all([
98108
this.checkSupportForMultiTargetVectorSearch(opts),
109+
this.checkSupportForNamedVectors(opts),
110+
this.checkSupportForBm25AndHybridGroupByQueries('Hybrid', opts),
111+
this.checkSupportForHybridNearTextAndNearVectorSubSearches(opts),
99112
]);
100113
return {
101114
search: await this.connection.search(this.name, this.consistencyLevel, this.tenant),
@@ -153,15 +166,10 @@ class QueryManager<T> implements Query<T> {
153166
public hybrid(query: string, opts?: BaseHybridOptions<T>): Promise<WeaviateReturn<T>>;
154167
public hybrid(query: string, opts: GroupByHybridOptions<T>): Promise<GroupByReturn<T>>;
155168
public hybrid(query: string, opts?: HybridOptions<T>): QueryReturn<T> {
156-
return Promise.all([
157-
this.checkSupportForNamedVectors(opts),
158-
this.checkSupportForBm25AndHybridGroupByQueries('Hybrid', opts),
159-
this.checkSupportForHybridNearTextAndNearVectorSubSearches(opts),
160-
])
161-
.then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant))
162-
.then((search) =>
169+
return this.hybridSearch(opts)
170+
.then(({ search, supportsTargets }) =>
163171
search.withHybrid({
164-
...Serialize.hybrid({ query, supportsTargets: false, ...opts }),
172+
...Serialize.hybrid({ query, supportsTargets, ...opts }),
165173
groupBy: Serialize.isGroupBy<GroupByHybridOptions<T>>(opts)
166174
? Serialize.groupBy(opts.groupBy)
167175
: undefined,

0 commit comments

Comments
 (0)