Skip to content

Commit 64ab351

Browse files
authored
Merge pull request #104 from semi-technologies/bm25_hybrid
Add Bm25 search and hybrid search
2 parents 2e8726f + 3cd392f commit 64ab351

File tree

12 files changed

+511
-10
lines changed

12 files changed

+511
-10
lines changed

ci/docker-compose-okta.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ services:
1010
- --scheme
1111
- http
1212
- --write-timeout=600s
13-
image: semitechnologies/weaviate:1.15.4-b7811d4
13+
image: semitechnologies/weaviate:1.17.0-prealpha-8950d6f
1414
ports:
1515
- 8082:8082
1616
restart: on-failure:0

ci/docker-compose-wcs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ services:
1010
- --scheme
1111
- http
1212
- --write-timeout=600s
13-
image: semitechnologies/weaviate:preview-replace-shardingconfig-replicas-with-replication-factor-29e987d
13+
image: semitechnologies/weaviate:1.17.0-prealpha-8950d6f
1414
ports:
1515
- 8083:8083
1616
restart: on-failure:0

ci/docker-compose.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
version: '3.4'
33
services:
44
weaviate:
5-
image: semitechnologies/weaviate:preview-replace-shardingconfig-replicas-with-replication-factor-29e987d
5+
image: semitechnologies/weaviate:1.17.0-prealpha-8950d6f
66
restart: on-failure:0
77
ports:
88
- "8080:8080"

cluster/journey.test.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ const weaviate = require("../index");
22
const { createTestFoodSchemaAndData, cleanupTestFood, PIZZA_CLASS_NAME, SOUP_CLASS_NAME } = require("../utils/testData");
33

44
const EXPECTED_WEAVIATE_VERSION = "1.17.0-prealpha"
5-
const EXPECTED_WEAVIATE_GIT_HASH = "29e987d"
5+
const EXPECTED_WEAVIATE_GIT_HASH = "8950d6f"
66

77
describe("cluster nodes endpoint", () => {
88
const client = weaviate.client({

graphql/bm25.js

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import { isValidStringArray, isValidStringProperty } from "../validation/string";
2+
3+
export default class GraphQLBm25 {
4+
5+
constructor(bm25Obj) {
6+
this.source = bm25Obj;
7+
}
8+
9+
toString() {
10+
this.parse();
11+
this.validate();
12+
13+
let args = [`query:${JSON.stringify(this.query)}`]; // query must always be set
14+
15+
if (this.properties !== undefined) {
16+
args = [...args, `properties:${JSON.stringify(this.properties)}`];
17+
}
18+
19+
return `{${args.join(",")}}`;
20+
}
21+
22+
parse() {
23+
for (let key in this.source) {
24+
switch (key) {
25+
case "query":
26+
this.parseQuery(this.source[key]);
27+
break;
28+
case "properties":
29+
this.parseProperties(this.source[key]);
30+
break;
31+
default:
32+
throw new Error(`bm25 filter: unrecognized key '${key}'`);
33+
}
34+
}
35+
}
36+
37+
parseQuery(query) {
38+
if (!isValidStringProperty(query)) {
39+
throw new Error("bm25 filter: query must be a string");
40+
}
41+
42+
this.query = query;
43+
}
44+
45+
parseProperties(properties) {
46+
if (!isValidStringArray(properties)) {
47+
throw new Error("bm25 filter: properties must be an array of strings");
48+
}
49+
50+
this.properties = properties;
51+
}
52+
53+
validate() {
54+
if (!this.query) {
55+
throw new Error("bm25 filter: query cannot be empty");
56+
}
57+
}
58+
}

graphql/getter.js

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import Where from "./where";
22
import NearText from "./nearText";
33
import NearVector from "./nearVector";
4+
import Bm25 from "./bm25";
5+
import Hybrid from "./hybrid";
46
import NearObject from "./nearObject";
57
import NearImage from "./nearImage";
68
import Ask from "./ask";
@@ -60,6 +62,27 @@ export default class Getter {
6062
return this;
6163
};
6264

65+
withBm25 = (bm25Obj) => {
66+
try {
67+
this.bm25String = new Bm25(bm25Obj).toString();
68+
} catch (e) {
69+
this.errors = [...this.errors, e];
70+
}
71+
72+
return this;
73+
};
74+
75+
withHybrid = (hybridObj) => {
76+
try {
77+
this.hybridString = new Hybrid(hybridObj).toString();
78+
} catch (e) {
79+
this.errors = [...this.errors, e];
80+
}
81+
82+
return this;
83+
};
84+
85+
6386
withNearObject = (nearObjectObj) => {
6487
if (this.includesNearMediaFilter) {
6588
throw new Error(
@@ -172,8 +195,10 @@ export default class Getter {
172195
this.nearTextString ||
173196
this.nearObjectString ||
174197
this.nearVectorString ||
175-
this.askString ||
176198
this.nearImageString ||
199+
this.askString ||
200+
this.bm25String ||
201+
this.hybridString ||
177202
this.limit ||
178203
this.offset ||
179204
this.groupString ||
@@ -205,6 +230,14 @@ export default class Getter {
205230
args = [...args, `nearVector:${this.nearVectorString}`];
206231
}
207232

233+
if (this.bm25String) {
234+
args = [...args, `bm25:${this.bm25String}`];
235+
}
236+
237+
if (this.hybridString) {
238+
args = [...args, `hybrid:${this.hybridString}`];
239+
}
240+
208241
if (this.groupString) {
209242
args = [...args, `group:${this.groupString}`];
210243
}

graphql/getter.test.js

Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1377,3 +1377,231 @@ describe("invalid sort filters", () => {
13771377
});
13781378
});
13791379
});
1380+
1381+
describe("bm25 valid searchers", () => {
1382+
const mockClient = {
1383+
query: jest.fn(),
1384+
};
1385+
1386+
test("query and no properties", () => {
1387+
const expectedQuery =
1388+
`{Get{Person` + `(bm25:{query:"accountant"})` + `{name}}}`;
1389+
1390+
new Getter(mockClient)
1391+
.withClassName("Person")
1392+
.withFields("name")
1393+
.withBm25({ query: "accountant" })
1394+
.do();
1395+
1396+
expect(mockClient.query).toHaveBeenCalledWith(expectedQuery);
1397+
});
1398+
1399+
test("query and properties", () => {
1400+
const expectedQuery =
1401+
`{Get{Person` + `(bm25:{query:"accountant",properties:["profession","position"]})` + `{name}}}`;
1402+
1403+
new Getter(mockClient)
1404+
.withClassName("Person")
1405+
.withFields("name")
1406+
.withBm25({ query: "accountant", properties: ["profession", "position"] })
1407+
.do();
1408+
1409+
expect(mockClient.query).toHaveBeenCalledWith(expectedQuery);
1410+
});
1411+
1412+
test("query and empty properties", () => {
1413+
const expectedQuery =
1414+
`{Get{Person` + `(bm25:{query:"accountant",properties:[]})` + `{name}}}`;
1415+
1416+
new Getter(mockClient)
1417+
.withClassName("Person")
1418+
.withFields("name")
1419+
.withBm25({ query: "accountant", properties: [] })
1420+
.do();
1421+
1422+
expect(mockClient.query).toHaveBeenCalledWith(expectedQuery);
1423+
});
1424+
});
1425+
1426+
describe("bm25 invalid searchers", () => {
1427+
const mockClient = {
1428+
query: jest.fn(),
1429+
};
1430+
1431+
const tests = [
1432+
{
1433+
title: "an empty bm25",
1434+
bm25: {},
1435+
msg: "bm25 filter: query cannot be empty",
1436+
},
1437+
{
1438+
title: "an empty query",
1439+
bm25: { query: ""},
1440+
msg: "bm25 filter: query must be a string",
1441+
},
1442+
{
1443+
title: "query of wrong type",
1444+
bm25: { query: {} },
1445+
msg: "bm25 filter: query must be a string",
1446+
},
1447+
{
1448+
title: "an empty property",
1449+
bm25: { query: "query", properties: [""] },
1450+
msg: "bm25 filter: properties must be an array of strings",
1451+
},
1452+
{
1453+
title: "property of wrong type",
1454+
bm25: { query: "query", properties: [123] },
1455+
msg: "bm25 filter: properties must be an array of strings",
1456+
},
1457+
{
1458+
title: "properties of wrong type",
1459+
bm25: { query: "query", properties: {} },
1460+
msg: "bm25 filter: properties must be an array of strings",
1461+
},
1462+
];
1463+
1464+
tests.forEach((t) => {
1465+
test(t.title, () => {
1466+
new Getter(mockClient)
1467+
.withClassName("Person")
1468+
.withFields("name")
1469+
.withBm25(t.bm25)
1470+
.do()
1471+
.then(() => fail("it should have error'd"))
1472+
.catch((e) => {
1473+
expect(e.toString()).toContain(t.msg);
1474+
});
1475+
});
1476+
});
1477+
});
1478+
1479+
1480+
describe("hybrid valid searchers", () => {
1481+
const mockClient = {
1482+
query: jest.fn(),
1483+
};
1484+
1485+
test("query and no alpha, no vector", () => {
1486+
const expectedQuery =
1487+
`{Get{Person` + `(hybrid:{query:"accountant"})` + `{name}}}`;
1488+
1489+
new Getter(mockClient)
1490+
.withClassName("Person")
1491+
.withFields("name")
1492+
.withHybrid({ query: "accountant" })
1493+
.do();
1494+
1495+
expect(mockClient.query).toHaveBeenCalledWith(expectedQuery);
1496+
});
1497+
1498+
test("query and alpha, no vector", () => {
1499+
const expectedQuery =
1500+
`{Get{Person` + `(hybrid:{query:"accountant",alpha:0.75})` + `{name}}}`;
1501+
1502+
new Getter(mockClient)
1503+
.withClassName("Person")
1504+
.withFields("name")
1505+
.withHybrid({ query: "accountant", alpha: 0.75 })
1506+
.do();
1507+
1508+
expect(mockClient.query).toHaveBeenCalledWith(expectedQuery);
1509+
});
1510+
1511+
test("query and alpha 0, no vector", () => {
1512+
const expectedQuery =
1513+
`{Get{Person` + `(hybrid:{query:"accountant",alpha:0})` + `{name}}}`;
1514+
1515+
new Getter(mockClient)
1516+
.withClassName("Person")
1517+
.withFields("name")
1518+
.withHybrid({ query: "accountant", alpha: 0 })
1519+
.do();
1520+
1521+
expect(mockClient.query).toHaveBeenCalledWith(expectedQuery);
1522+
});
1523+
1524+
test("query and vector, no alpha", () => {
1525+
const expectedQuery =
1526+
`{Get{Person` + `(hybrid:{query:"accountant",vector:[1,2,3]})` + `{name}}}`;
1527+
1528+
new Getter(mockClient)
1529+
.withClassName("Person")
1530+
.withFields("name")
1531+
.withHybrid({ query: "accountant", vector: [1,2,3] })
1532+
.do();
1533+
1534+
expect(mockClient.query).toHaveBeenCalledWith(expectedQuery);
1535+
});
1536+
1537+
test("query and alpha and vector", () => {
1538+
const expectedQuery =
1539+
`{Get{Person` + `(hybrid:{query:"accountant",alpha:0.75,vector:[1,2,3]})` + `{name}}}`;
1540+
1541+
new Getter(mockClient)
1542+
.withClassName("Person")
1543+
.withFields("name")
1544+
.withHybrid({ query: "accountant", alpha: 0.75, vector: [1,2,3] })
1545+
.do();
1546+
1547+
expect(mockClient.query).toHaveBeenCalledWith(expectedQuery);
1548+
});
1549+
});
1550+
1551+
describe("hybrid invalid searchers", () => {
1552+
const mockClient = {
1553+
query: jest.fn(),
1554+
};
1555+
1556+
const tests = [
1557+
{
1558+
title: "an empty hybrid",
1559+
hybrid: {},
1560+
msg: "hybrid filter: query cannot be empty",
1561+
},
1562+
{
1563+
title: "an empty query",
1564+
hybrid: { query: ""},
1565+
msg: "hybrid filter: query must be a string",
1566+
},
1567+
{
1568+
title: "query of wrong type",
1569+
hybrid: { query: {} },
1570+
msg: "hybrid filter: query must be a string",
1571+
},
1572+
{
1573+
title: "alpha on wrong type",
1574+
hybrid: { query: "query", alpha: "alpha" },
1575+
msg: "hybrid filter: alpha must be a number",
1576+
},
1577+
{
1578+
title: "an empty vector",
1579+
hybrid: { query: "query", vector: [] },
1580+
msg: "hybrid filter: vector must be an array of numbers",
1581+
},
1582+
{
1583+
title: "vector element of wrong type",
1584+
hybrid: { query: "query", vector: ["vector"] },
1585+
msg: "hybrid filter: vector must be an array of numbers",
1586+
},
1587+
{
1588+
title: "vector of wrong type",
1589+
hybrid: { query: "query", vector: {} },
1590+
msg: "hybrid filter: vector must be an array of numbers",
1591+
},
1592+
];
1593+
1594+
tests.forEach((t) => {
1595+
test(t.title, () => {
1596+
new Getter(mockClient)
1597+
.withClassName("Person")
1598+
.withFields("name")
1599+
.withHybrid(t.hybrid)
1600+
.do()
1601+
.then(() => fail("it should have error'd"))
1602+
.catch((e) => {
1603+
expect(e.toString()).toContain(t.msg);
1604+
});
1605+
});
1606+
});
1607+
});

0 commit comments

Comments
 (0)