Skip to content

Commit 04730e3

Browse files
feat: ensure retriever returns an image and send it to the LLM base64 encoded
1 parent 399e394 commit 04730e3

File tree

12 files changed

+144
-340
lines changed

12 files changed

+144
-340
lines changed

examples/multimodal/data/1.jpg

-1.96 MB
Binary file not shown.

examples/multimodal/data/2.jpg

-4.77 MB
Binary file not shown.

examples/multimodal/data/3.jpg

-6.66 MB
Binary file not shown.

examples/multimodal/data/San Francisco.txt

Lines changed: 0 additions & 323 deletions
This file was deleted.

examples/multimodal/rag.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ async function main() {
3131

3232
const queryEngine = index.asQueryEngine({
3333
responseSynthesizer: new MultiModalResponseSynthesizer({ serviceContext }),
34-
// TODO: set imageSimilarityTopK: 1,
35-
retriever: index.asRetriever({ similarityTopK: 2 }),
34+
// TODO: set text similarity to a higher value than image similarity
35+
retriever: index.asRetriever({ similarityTopK: 1 }),
3636
});
3737
const result = await queryEngine.query(
3838
"what are Vincent van Gogh's famous paintings",

examples/multimodal/retrieve.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import {
55
TextNode,
66
VectorStoreIndex,
77
} from "llamaindex";
8-
import * as path from "path";
98

109
export async function createIndex() {
1110
// set up vector store index with two vector stores, one for text, the other for images
@@ -37,7 +36,7 @@ async function main() {
3736
continue;
3837
}
3938
if (node instanceof ImageNode) {
40-
console.log(`Image: ${path.join(__dirname, node.id_)}`);
39+
console.log(`Image: ${node.getUrl()}`);
4140
} else if (node instanceof TextNode) {
4241
console.log("Text:", (node as TextNode).text.substring(0, 128));
4342
}

packages/core/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"@xenova/transformers": "^2.10.0",
1111
"assemblyai": "^4.0.0",
1212
"crypto-js": "^4.2.0",
13+
"file-type": "^18.7.0",
1314
"js-tiktoken": "^1.0.8",
1415
"lodash": "^4.17.21",
1516
"mammoth": "^1.6.0",

packages/core/src/Node.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import CryptoJS from "crypto-js";
2+
import path from "path";
23
import { v4 as uuidv4 } from "uuid";
34

45
export enum NodeRelationship {
@@ -304,6 +305,12 @@ export class ImageNode<T extends Metadata = Metadata> extends TextNode<T> {
304305
getType(): ObjectType {
305306
return ObjectType.IMAGE;
306307
}
308+
309+
/**
 * Builds a `file://` URL for this image node.
 *
 * `id_` holds the (possibly relative) path of the image file; it is resolved
 * to an absolute path with `path.resolve` before being wrapped in a URL.
 *
 * NOTE(review): the path is interpolated into the URL string unescaped, so a
 * path containing `#` or `?` would presumably be split into hash/query by the
 * URL parser — consider `pathToFileURL` from "url"; confirm against callers.
 *
 * @returns the `file:` URL pointing at the image on disk
 */
getUrl(): URL {
  const fileLocation = `file://${path.resolve(this.id_)}`;
  return new URL(fileLocation);
}
307314
}
308315

309316
export class ImageDocument<T extends Metadata = Metadata> extends ImageNode<T> {

packages/core/src/embeddings/utils.ts

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import _ from "lodash";
22
import { ImageType } from "../Node";
33
import { DEFAULT_SIMILARITY_TOP_K } from "../constants";
4-
import { VectorStoreQueryMode } from "../storage";
4+
import { DEFAULT_FS, VectorStoreQueryMode } from "../storage";
55
import { SimilarityType } from "./types";
66

77
/**
@@ -185,6 +185,16 @@ export function getTopKMMREmbeddings(
185185
return [resultSimilarities, resultIds];
186186
}
187187

188+
/**
 * Encodes the bytes of a Blob as a base64 `data:` URL.
 *
 * The MIME type is detected from the file content via `file-type`. When the
 * content cannot be identified (file-type only recognizes binary signatures,
 * so text-based formats are not detected), the Blob's own declared `type` is
 * used as a fallback, provided it is an `image/*` type.
 *
 * @param input the image data
 * @returns a string of the form `data:<mime>;base64,<payload>`
 * @throws Error("Unsupported image type") if no image MIME type can be determined
 */
async function blobToDataUrl(input: Blob): Promise<string> {
  const { fileTypeFromBuffer } = await import("file-type");
  const buffer = Buffer.from(await input.arrayBuffer());
  const detected = await fileTypeFromBuffer(buffer);
  // Prefer the sniffed type; only trust the declared type if it claims to be an image.
  const declared = input.type.startsWith("image/") ? input.type : undefined;
  const mime = detected?.mime ?? declared;
  if (!mime) {
    throw new Error("Unsupported image type");
  }
  return `data:${mime};base64,${buffer.toString("base64")}`;
}
197+
188198
export async function readImage(input: ImageType) {
189199
const { RawImage } = await import("@xenova/transformers");
190200
if (input instanceof Blob) {
@@ -195,3 +205,53 @@ export async function readImage(input: ImageType) {
195205
throw new Error(`Unsupported input type: ${typeof input}`);
196206
}
197207
}
208+
209+
/**
 * Serializes an image value to a string.
 *
 * - a string is returned unchanged
 * - a URL is returned in its string form
 * - a Blob is converted to a base64 data URL
 *
 * @param input the image as string, URL or Blob
 * @returns the string representation of the image
 * @throws Error if the input is none of the supported types
 */
export async function imageToString(input: ImageType): Promise<string> {
  if (_.isString(input)) {
    return input;
  }
  if (input instanceof URL) {
    return input.toString();
  }
  if (input instanceof Blob) {
    // binary data has no natural string form, so embed it as a data URL
    return await blobToDataUrl(input);
  }
  throw new Error(`Unsupported input type: ${typeof input}`);
}
221+
222+
/**
 * Converts the string form of an image back into an image value.
 *
 * - `data:` URLs are decoded into a Blob that carries the MIME type named in
 *   the data URL header. Both base64 and percent-encoded payloads are handled.
 * - `http://`, `https://` and `file://` strings become `URL` objects, so that
 *   a URL serialized with `toString()` round-trips to a URL again.
 * - any other string is returned unchanged.
 *
 * @param input the serialized image
 * @returns the image as Blob, URL or string
 * @throws Error if a data URL has no `,` separator or the input is not a string
 */
export function stringToImage(input: string): ImageType {
  if (typeof input !== "string") {
    // runtime guard for untyped callers
    throw new Error(`Unsupported input type: ${typeof input}`);
  }
  if (input.startsWith("data:")) {
    const separatorIndex = input.indexOf(",");
    if (separatorIndex === -1) {
      throw new Error("Invalid data URL: missing ',' separator");
    }
    // header looks like "<mime>[;param]*[;base64]"
    const header = input.slice("data:".length, separatorIndex);
    const payload = input.slice(separatorIndex + 1);
    const [mimeType = "", ...params] = header.split(";");
    const isBase64 = params.some(
      (param) => param.trim().toLowerCase() === "base64",
    );
    const bytes = isBase64
      ? Buffer.from(payload, "base64")
      : Buffer.from(decodeURIComponent(payload), "utf8");
    return new Blob([bytes], { type: mimeType });
  }
  if (
    input.startsWith("http://") ||
    input.startsWith("https://") ||
    input.startsWith("file://")
  ) {
    return new URL(input);
  }
  return input;
}
236+
237+
/**
 * Converts an image value to a base64 data URL.
 *
 * - a Blob is encoded directly
 * - a string is treated as a file system path (or a `file://` URL string) and
 *   the file is read through `DEFAULT_FS`
 * - a `file:` URL is converted to a file system path and read the same way
 *
 * The path of a file URL is obtained with `fileURLToPath`, not `pathname`:
 * `pathname` is percent-encoded, so a path containing e.g. a space would not
 * be found on disk.
 *
 * @param input the image as Blob, path string or `file:` URL
 * @returns a string of the form `data:<mime>;base64,<payload>`
 * @throws Error for URLs with a protocol other than `file:` and for
 *   unsupported input types
 */
export async function imageToDataUrl(input: ImageType): Promise<string> {
  if (input instanceof Blob) {
    return await blobToDataUrl(input);
  }

  let filePath: string;
  if (_.isString(input)) {
    if (input.startsWith("file://")) {
      const { fileURLToPath } = await import("url");
      filePath = fileURLToPath(input);
    } else {
      filePath = input;
    }
  } else if (input instanceof URL) {
    if (input.protocol !== "file:") {
      throw new Error(`Unsupported URL with protocol: ${input.protocol}`);
    }
    const { fileURLToPath } = await import("url");
    filePath = fileURLToPath(input);
  } else {
    throw new Error(`Unsupported input type: ${typeof input}`);
  }

  const dataBuffer = await DEFAULT_FS.readFile(filePath);
  return await blobToDataUrl(new Blob([dataBuffer]));
}

packages/core/src/indices/vectorStore/VectorIndexRetriever.ts

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
import { Event } from "../../callbacks/CallbackManager";
2-
import { DEFAULT_SIMILARITY_TOP_K } from "../../constants";
3-
import { BaseEmbedding } from "../../embeddings";
41
import { globalsHelper } from "../../GlobalsHelper";
5-
import { Metadata, NodeWithScore } from "../../Node";
2+
import { ImageNode, Metadata, NodeWithScore } from "../../Node";
63
import { BaseRetriever } from "../../Retriever";
74
import { ServiceContext } from "../../ServiceContext";
5+
import { Event } from "../../callbacks/CallbackManager";
6+
import { DEFAULT_SIMILARITY_TOP_K } from "../../constants";
7+
import { BaseEmbedding } from "../../embeddings";
88
import {
99
VectorStoreQuery,
1010
VectorStoreQueryMode,
@@ -108,6 +108,12 @@ export class VectorIndexRetriever implements BaseRetriever {
108108
}
109109

110110
const node = this.index.indexStruct.nodesDict[result.ids[i]];
111+
// XXX: Hack, if it's an image node, we reconstruct the image from the URL
112+
// Alternative: Store image in doc store and retrieve it here
113+
if (node instanceof ImageNode) {
114+
node.image = node.getUrl();
115+
}
116+
111117
nodesWithScores.push({
112118
node: node,
113119
score: result.similarities[i],

0 commit comments

Comments
 (0)