From 54b1f5d11c929aed9d6d939e0e77b9d020fe0a1b Mon Sep 17 00:00:00 2001 From: Akaash Parthasarathy Date: Thu, 23 Oct 2025 22:54:50 -0400 Subject: [PATCH 1/4] [Refactor] Refactor for compatibility with TVM FFI updates --- src/cache_util.ts | 6 ++-- src/embedding.ts | 4 +-- src/engine.ts | 2 +- src/llm_chat.ts | 72 +++++++++++++++++++++++------------------------ 4 files changed, 42 insertions(+), 42 deletions(-) diff --git a/src/cache_util.ts b/src/cache_util.ts index fce49ff5..22245bdc 100644 --- a/src/cache_util.ts +++ b/src/cache_util.ts @@ -29,7 +29,7 @@ export async function hasModelInCache( const modelRecord = findModelRecord(modelId, appConfig); const modelUrl = cleanModelUrl(modelRecord.model); const cacheType = appConfig.useIndexedDBCache ? "indexeddb" : "cache"; - return tvmjs.hasNDArrayInCache(modelUrl, "webllm/model", cacheType); + return tvmjs.hasTensorInCache(modelUrl, "webllm/model", cacheType); } export async function deleteModelAllInfoInCache( @@ -60,10 +60,10 @@ export async function deleteModelInCache( const modelUrl = cleanModelUrl(modelRecord.model); let modelCache: tvmjs.ArtifactCacheTemplate; if (appConfig.useIndexedDBCache) { - tvmjs.deleteNDArrayCache(modelUrl, "webllm/model", "indexeddb"); + tvmjs.deleteTensorCache(modelUrl, "webllm/model", "indexeddb"); modelCache = new tvmjs.ArtifactIndexedDBCache("webllm/model"); } else { - tvmjs.deleteNDArrayCache(modelUrl, "webllm/model", "cache"); + tvmjs.deleteTensorCache(modelUrl, "webllm/model", "cache"); modelCache = new tvmjs.ArtifactCache("webllm/model"); } await modelCache.deleteInCache(new URL("tokenizer.model", modelUrl).href); diff --git a/src/embedding.ts b/src/embedding.ts index ae5d9123..38d54a1e 100644 --- a/src/embedding.ts +++ b/src/embedding.ts @@ -204,7 +204,7 @@ export class EmbeddingPipeline { maskNDArray = maskNDArray.view([curBatchSize, maxInputSize]); // 3.5 Actual forwarding on GPU, logits of shape (curBatchSize, maxInputSize, hidden_size) - const logitsCurBatchOnGPU: tvmjs.NDArray = this.prefill( + const logitsCurBatchOnGPU: tvmjs.Tensor = this.prefill( inputNDArray, maskNDArray, this.params, @@ -213,7 +213,7 @@ export class EmbeddingPipeline { // 3.6 Copy logits to CPU, flatten to curBatchSize * maxInputSize * hidden_size const hidden_size = logitsCurBatchOnGPU.shape[2]; - let logitsCurBatchOnCPU: tvmjs.NDArray = this.tvm.empty( + let logitsCurBatchOnCPU: tvmjs.Tensor = this.tvm.empty( logitsCurBatchOnGPU.shape, logitsCurBatchOnGPU.dtype, this.tvm.cpu(), diff --git a/src/engine.ts b/src/engine.ts index 6609f3e5..ddf6d2f2 100644 --- a/src/engine.ts +++ b/src/engine.ts @@ -367,7 +367,7 @@ export class MLCEngine implements MLCEngineInterface { this.logger, ); const cacheType = this.appConfig.useIndexedDBCache ? "indexeddb" : "cache"; - await tvm.fetchNDArrayCache( + await tvm.fetchTensorCache( modelUrl, tvm.webgpu(), "webllm/model", diff --git a/src/llm_chat.ts b/src/llm_chat.ts index 8f7620a8..056cd847 100644 --- a/src/llm_chat.ts +++ b/src/llm_chat.ts @@ -65,7 +65,7 @@ export class LLMChatPipeline { // parameter states private params: tvmjs.TVMObject; private kvCache: tvmjs.TVMObject; - private logitsOnCPU?: tvmjs.NDArray = undefined; + private logitsOnCPU?: tvmjs.Tensor = undefined; private filledKVCacheLength = 0; // meta data @@ -224,7 +224,7 @@ export class LLMChatPipeline { // 2. 
Get json stored in the vm's metadata function const fgetMetadata = this.vm.getFunction("_metadata"); const ret_value = fgetMetadata(); - const metadataStr = this.tvm.detachFromCurrentScope(ret_value).toString(); + const metadataStr = ret_value.toString(); const metadata = JSON.parse(metadataStr); // 3. Load parameters by name @@ -671,7 +671,7 @@ export class LLMChatPipeline { // 2. Prefill each chunk this.tvm.beginScope(); - let logits: tvmjs.NDArray; + let logits: tvmjs.Tensor; for (let i = 0; i < chunks.length; i++) { const chunk = chunks[i]; const chunkLen = chunkLens[i]; @@ -860,7 +860,7 @@ export class LLMChatPipeline { * @note precondition: inputTokens.length <= prefillChunkSize, since we take care of * chunking in `getChunkedPrefillInputData()`. */ - private getTokensEmbeddings(inputTokens: number[]): tvmjs.NDArray { + private getTokensEmbeddings(inputTokens: number[]): tvmjs.Tensor { this.tvm.beginScope(); if (inputTokens.length > this.prefillChunkSize) { throw new Error( @@ -873,7 +873,7 @@ export class LLMChatPipeline { this.device, ); inputData.copyFrom(inputTokens); - const embed: tvmjs.NDArray = this.tvm.detachFromCurrentScope( + const embed: tvmjs.Tensor = this.tvm.detachFromCurrentScope( this.embed!(inputData, this.params), ); this.tvm.endScope(); @@ -886,9 +886,9 @@ export class LLMChatPipeline { */ private async getImageEmbeddings( inputImage: ImageURL, - ): Promise { + ): Promise { this.tvm.beginScope(); - // 1. Transform ImageURL into image input in NDArray + // 1. Transform ImageURL into image input in TVMArray const url = inputImage.url; // url starting with `data:image` and `http` share the same loading method const imgData: ImageData = await getImageDataFromURL(url); @@ -900,7 +900,7 @@ export class LLMChatPipeline { .view([1, imgData.height, imgData.width, 3]); // NHWC // 2. Call image embed kernel - const embed: tvmjs.NDArray = this.tvm.detachFromCurrentScope( + const embed: tvmjs.Tensor = this.tvm.detachFromCurrentScope( this.image_embed!(pixelArray, this.params), ); if (embed.shape[0] !== IMAGE_EMBED_SIZE) { @@ -920,14 +920,14 @@ export class LLMChatPipeline { * * @param inputData data to embed and forward * @param inputDataLen length of this inputData, should smaller than prefill chunk size. - * @returns The logits returned by this forward as tvmjs.NDArray on GPU. + * @returns The logits returned by this forward as tvmjs.Tensor on GPU. * * @note Precondition: inputData's data length is smaller than prefill chunk size */ private async embedAndForward( inputData: Array | ImageURL>, inputDataLen: number, - ): Promise { + ): Promise { if (inputDataLen > this.prefillChunkSize) { throw new Error( "InternalError: expect inputDataLen <= this.prefillChunkSize.", @@ -938,18 +938,18 @@ export class LLMChatPipeline { // 1. Embed all inputData this.tvm.beginScope(); - const embeddings: tvmjs.NDArray[] = []; + const embeddings: tvmjs.Tensor[] = []; for (let i = 0; i < inputData.length; i++) { const data = inputData[i]; if (Array.isArray(data)) { - embeddings.push(this.getTokensEmbeddings(data)); + embeddings.push(await this.getTokensEmbeddings(data)); } else { embeddings.push(await this.getImageEmbeddings(data)); } } // 2. 
Concatenate embeddings - let allEmbeddings: tvmjs.NDArray; + let allEmbeddings: tvmjs.Tensor; if (embeddings.length === 1) { allEmbeddings = embeddings[0]; } else { @@ -983,7 +983,7 @@ export class LLMChatPipeline { } // NOTE: caller must call device.sync() - private updateLogitsOnCPU(logits: tvmjs.NDArray): tvmjs.NDArray { + private updateLogitsOnCPU(logits: tvmjs.Tensor): tvmjs.Tensor { if (this.logitsOnCPU == undefined) { this.logitsOnCPU = this.tvm.detachFromCurrentScope( this.tvm.empty(logits.shape, logits.dtype, this.tvm.cpu()), @@ -998,7 +998,7 @@ export class LLMChatPipeline { } private async sampleTokenFromLogits( - logitsOnGPU: tvmjs.NDArray, + logitsOnGPU: tvmjs.Tensor, genConfig?: GenerationConfig, ) { // 0. Get value of temperature, top_p, and various penalties, possibly overridden by genConfig @@ -1160,7 +1160,7 @@ export class LLMChatPipeline { const logitBiasBegin = performance.now(); const numTokens = Object.keys(logit_bias ?? {}).length; - const pos2seq_id = new Int32Array(numTokens).fill(0); + const pos2seqIds = new Int32Array(numTokens).fill(0); const tokenIds = new Int32Array(numTokens); const tokenLogitBias = new Float32Array(numTokens); @@ -1173,23 +1173,23 @@ export class LLMChatPipeline { this.tvm.beginScope(); - const pos2seqIdsArray = this.tvm + const pos2seqIdsDevice = this.tvm .empty([numTokens], "int32", this.device) - .copyFrom(pos2seq_id); + .copyFrom(pos2seqIds); - const tokenIdsArray = this.tvm + const tokenIdsDevice = this.tvm .empty([numTokens], "int32", this.device) .copyFrom(tokenIds); - const tokenLogitBiasArray = this.tvm + const tokenLogitBiasDevice = this.tvm .empty([numTokens], "float32", this.device) .copyFrom(tokenLogitBias); this.fapplyLogitBias( logitsOnGPU.view([1, this.fullVocabSize]), - pos2seqIdsArray, - tokenIdsArray, - tokenLogitBiasArray, + pos2seqIdsDevice, + tokenIdsDevice, + tokenLogitBiasDevice, ); this.tvm.endScope(); @@ -1215,7 +1215,7 @@ export class LLMChatPipeline { if (numTokens > 0) { const penaltyBegin = performance.now(); - const pos2seq_id = new Int32Array(numTokens).fill(0); + const pos2seqIds = new Int32Array(numTokens).fill(0); const tokenIds = new Int32Array(numTokens).fill(0); const tokenCnt = new Int32Array(numTokens).fill(0); const penalties = new Float32Array([ @@ -1232,29 +1232,29 @@ export class LLMChatPipeline { .empty([1], "int32", this.device) .copyFrom([0]); - const pos2seqIdsArray = this.tvm + const pos2seqIdsDevice = this.tvm .empty([numTokens], "int32", this.device) - .copyFrom(pos2seq_id); + .copyFrom(pos2seqIds); - const tokenIdsArray = this.tvm + const tokenIdsDevice = this.tvm .empty([numTokens], "int32", this.device) .copyFrom(tokenIds); - const tokenCntArray = this.tvm + const tokenCntDevice = this.tvm .empty([numTokens], "int32", this.device) .copyFrom(tokenCnt); - const penaltiesArray = this.tvm + const penaltiesDevice = this.tvm .empty([1, 3], "float32", this.device) .copyFrom(penalties); this.fapplyPenalty( logitsOnGPU.view([1, this.fullVocabSize]), seqIdsArray, - pos2seqIdsArray, - tokenIdsArray, - tokenCntArray, - penaltiesArray, + pos2seqIdsDevice, + tokenIdsDevice, + tokenCntDevice, + penaltiesDevice, ); this.tvm.endScope(); @@ -1280,13 +1280,13 @@ export class LLMChatPipeline { const temperatures = new Float32Array([temperature]); this.tvm.beginScope(); - const temperaturesArray = this.tvm + const temperaturesDevice = this.tvm .empty([numSeqs], "float32", this.device) .copyFrom(temperatures); const probs = this.fsoftmaxWithTemperature( logitsOnGPU.view([numSeqs, 1, this.fullVocabSize]), - 
temperaturesArray, + temperaturesDevice, ); this.updateLogitsOnCPU(probs); this.tvm.endScope(); @@ -1458,7 +1458,7 @@ export class LLMChatPipeline { const chunkLens: Array = retGetChunks[1]; // 2. Prefill each chunk - let logitsOnGPU: tvmjs.NDArray; + let logitsOnGPU: tvmjs.Tensor; for (let i = 0; i < chunks.length; i++) { const chunk = chunks[i]; const chunkLen = chunkLens[i]; From f6b29792bdbea7d5ac18178133aa0beb8bf4e9cc Mon Sep 17 00:00:00 2001 From: Akaash Parthasarathy Date: Fri, 24 Oct 2025 15:35:54 -0400 Subject: [PATCH 2/4] Add support for cross-origin storage caching --- src/cache_util.ts | 169 ++++++++++++++++++---- src/config.ts | 10 ++ src/cross_origin_storage.ts | 225 ++++++++++++++++++++++++++++++ src/cross_origin_storage_cache.ts | 92 ++++++++++++ src/engine.ts | 27 ++-- src/utils.ts | 3 + 6 files changed, 484 insertions(+), 42 deletions(-) create mode 100644 src/cross_origin_storage.ts create mode 100644 src/cross_origin_storage_cache.ts diff --git a/src/cache_util.ts b/src/cache_util.ts index 22245bdc..027e3fbf 100644 --- a/src/cache_util.ts +++ b/src/cache_util.ts @@ -4,10 +4,139 @@ import { ChatConfig, ModelRecord, prebuiltAppConfig, + getCacheBackend, } from "./config"; import { cleanModelUrl } from "./support"; import { ModelNotFoundError, UnsupportedTokenizerFilesError } from "./error"; import { Tokenizer } from "@mlc-ai/web-tokenizers"; +import CrossOriginStorage from "./cross_origin_storage"; +import CrossOriginStorageCache from "./cross_origin_storage_cache"; + +type CacheScope = "webllm/model" | "webllm/config" | "webllm/wasm"; + +let crossOriginUnavailableLogged = false; + +function shouldUseCrossOrigin(appConfig: AppConfig): boolean { + return ( + getCacheBackend(appConfig) === "cross-origin" && + CrossOriginStorage.isAvailable() + ); +} + +export function getArtifactCache( + scope: CacheScope, + appConfig: AppConfig, + logger: (msg: string) => void = console.warn, +): tvmjs.ArtifactCacheTemplate { + const backend = getCacheBackend(appConfig); + if (backend === "cross-origin") { + if (CrossOriginStorage.isAvailable()) { + return new CrossOriginStorageCache(scope); + } + // Fallback to Cache API + if (!crossOriginUnavailableLogged) { + logger( + "Cross-origin storage backend requested but unavailable; falling back to Cache API.", + ); + crossOriginUnavailableLogged = true; + } + } + if (backend === "indexeddb") { + return new tvmjs.ArtifactIndexedDBCache(scope); + } + return new tvmjs.ArtifactCache(scope); +} + +async function hasTensorCache( + cache: tvmjs.ArtifactCacheTemplate, + tensorCacheUrl: string, +): Promise { + const jsonUrl = new URL("tensor-cache.json", tensorCacheUrl).href; + const hasManifest = await cache.hasAllKeys([jsonUrl]); + if (!hasManifest) { + return false; + } + const manifest = await cache.fetchWithCache(jsonUrl, "json"); + const records = manifest?.records ?? 
[]; + if (!Array.isArray(records) || records.length === 0) { + return false; + } + const shardUrls = records.map( + (entry: { dataPath: string }) => + new URL(entry.dataPath, tensorCacheUrl).href, + ); + return cache.hasAllKeys(shardUrls); +} + +async function deleteTensorCacheEntries( + cache: tvmjs.ArtifactCacheTemplate, + tensorCacheUrl: string, +): Promise { + const jsonUrl = new URL("tensor-cache.json", tensorCacheUrl).href; + const hasManifest = await cache.hasAllKeys([jsonUrl]); + if (!hasManifest) { + return; + } + let manifest: { records?: Array<{ dataPath: string }> }; + try { + manifest = await cache.fetchWithCache(jsonUrl, "json"); + } catch (err) { + return; + } + const records = manifest?.records ?? []; + await Promise.all( + records.map(async (entry) => { + if (!entry?.dataPath) { + return; + } + const dataUrl = new URL(entry.dataPath, tensorCacheUrl).href; + await cache.deleteInCache(dataUrl); + }), + ); + await cache.deleteInCache(jsonUrl); +} + +export async function fetchModelArtifacts( + tvm: tvmjs.Instance, + tensorCacheUrl: string, + device: tvmjs.DLDevice, + appConfig: AppConfig, + signal?: AbortSignal, +): Promise { + if (!shouldUseCrossOrigin(appConfig)) { + const backend = getCacheBackend(appConfig); + const cacheType = backend === "indexeddb" ? "indexeddb" : "cache"; + return tvm.fetchTensorCache( + tensorCacheUrl, + device, + "webllm/model", + cacheType, + signal, + ); + } + + const artifactCache = getArtifactCache("webllm/model", appConfig); + const jsonUrl = new URL("tensor-cache.json", tensorCacheUrl).href; + const manifest = await artifactCache.fetchWithCache(jsonUrl, "json", signal); + const records = ( + Array.isArray(manifest?.records) ? manifest.records : [] + ) as Array; + await (tvm as any).fetchTensorCacheInternal( + tensorCacheUrl, + records, + device, + artifactCache, + signal, + ); + if (manifest?.metadata !== undefined) { + const runtime = tvm as any; + runtime.cacheMetadata = { + ...runtime.cacheMetadata, + ...(manifest.metadata as Record), + }; + } + return manifest; +} function findModelRecord(modelId: string, appConfig?: AppConfig): ModelRecord { const matchedItem = appConfig?.model_list.find( @@ -28,7 +157,12 @@ export async function hasModelInCache( } const modelRecord = findModelRecord(modelId, appConfig); const modelUrl = cleanModelUrl(modelRecord.model); - const cacheType = appConfig.useIndexedDBCache ? "indexeddb" : "cache"; + if (shouldUseCrossOrigin(appConfig)) { + const cache = getArtifactCache("webllm/model", appConfig); + return hasTensorCache(cache, modelUrl); + } + const backend = getCacheBackend(appConfig); + const cacheType = backend === "indexeddb" ? "indexeddb" : "cache"; return tvmjs.hasTensorInCache(modelUrl, "webllm/model", cacheType); } @@ -58,13 +192,13 @@ export async function deleteModelInCache( } const modelRecord = findModelRecord(modelId, appConfig); const modelUrl = cleanModelUrl(modelRecord.model); - let modelCache: tvmjs.ArtifactCacheTemplate; - if (appConfig.useIndexedDBCache) { - tvmjs.deleteTensorCache(modelUrl, "webllm/model", "indexeddb"); - modelCache = new tvmjs.ArtifactIndexedDBCache("webllm/model"); + const modelCache = getArtifactCache("webllm/model", appConfig); + if (shouldUseCrossOrigin(appConfig)) { + await deleteTensorCacheEntries(modelCache, modelUrl); } else { - tvmjs.deleteTensorCache(modelUrl, "webllm/model", "cache"); - modelCache = new tvmjs.ArtifactCache("webllm/model"); + const backend = getCacheBackend(appConfig); + const cacheType = backend === "indexeddb" ? 
"indexeddb" : "cache"; + await tvmjs.deleteTensorCache(modelUrl, "webllm/model", cacheType); } await modelCache.deleteInCache(new URL("tokenizer.model", modelUrl).href); await modelCache.deleteInCache(new URL("tokenizer.json", modelUrl).href); @@ -79,12 +213,7 @@ export async function deleteChatConfigInCache( appConfig = prebuiltAppConfig; } const modelRecord = findModelRecord(modelId, appConfig); - let configCache: tvmjs.ArtifactCacheTemplate; - if (appConfig.useIndexedDBCache) { - configCache = new tvmjs.ArtifactIndexedDBCache("webllm/config"); - } else { - configCache = new tvmjs.ArtifactCache("webllm/config"); - } + const configCache = getArtifactCache("webllm/config", appConfig); const modelUrl = cleanModelUrl(modelRecord.model); const configUrl = new URL("mlc-chat-config.json", modelUrl).href; await configCache.deleteInCache(configUrl); @@ -99,12 +228,7 @@ export async function deleteModelWasmInCache( appConfig = prebuiltAppConfig; } const modelRecord = findModelRecord(modelId, appConfig); - let wasmCache: tvmjs.ArtifactCacheTemplate; - if (appConfig.useIndexedDBCache) { - wasmCache = new tvmjs.ArtifactIndexedDBCache("webllm/wasm"); - } else { - wasmCache = new tvmjs.ArtifactCache("webllm/wasm"); - } + const wasmCache = getArtifactCache("webllm/wasm", appConfig); await wasmCache.deleteInCache(modelRecord.model_lib); } @@ -122,12 +246,7 @@ export async function asyncLoadTokenizer( appConfig: AppConfig, logger: (msg: string) => void = console.log, ): Promise { - let modelCache: tvmjs.ArtifactCacheTemplate; - if (appConfig.useIndexedDBCache) { - modelCache = new tvmjs.ArtifactIndexedDBCache("webllm/model"); - } else { - modelCache = new tvmjs.ArtifactCache("webllm/model"); - } + const modelCache = getArtifactCache("webllm/model", appConfig, logger); if (config.tokenizer_files.includes("tokenizer.json")) { const url = new URL("tokenizer.json", baseUrl).href; diff --git a/src/config.ts b/src/config.ts index 0ca51eba..eaff9e02 100644 --- a/src/config.ts +++ b/src/config.ts @@ -276,9 +276,19 @@ export interface ModelRecord { * * @note Note that the Cache API is more well-tested in WebLLM as of now. */ +export type CacheBackend = "cache" | "indexeddb" | "cross-origin"; + export interface AppConfig { model_list: Array; useIndexedDBCache?: boolean; + cacheBackend?: CacheBackend; +} + +export function getCacheBackend(appConfig: AppConfig): CacheBackend { + if (appConfig.cacheBackend !== undefined) { + return appConfig.cacheBackend; + } + return appConfig.useIndexedDBCache ? 
"indexeddb" : "cache"; } /** diff --git a/src/cross_origin_storage.ts b/src/cross_origin_storage.ts new file mode 100644 index 00000000..9d4c2659 --- /dev/null +++ b/src/cross_origin_storage.ts @@ -0,0 +1,225 @@ +const HASH_ALGORITHM = "SHA-256"; +const HASH_MATCH_REGEX = /[A-Fa-f0-9]{64}/; + +export interface CrossOriginHashDescriptor { + algorithm: string; + value: string; +} + +interface CrossOriginStorageHandle { + getFile(): Promise; + createWritable(): Promise; +} + +interface CrossOriginStorageAPI { + requestFileHandles( + descriptors: CrossOriginHashDescriptor[], + options?: { create?: boolean }, + ): Promise; + removeFileHandles?(descriptors: CrossOriginHashDescriptor[]): Promise; +} + +type RequestLike = string | URL | Request | { url?: string }; + +declare global { + interface Navigator { + crossOriginStorage?: CrossOriginStorageAPI; + } +} + +export default class CrossOriginStorage { + private hashCache: Map; + + constructor() { + this.hashCache = new Map(); + } + + static isAvailable(): boolean { + return ( + typeof navigator !== "undefined" && + "crossOriginStorage" in navigator && + navigator.crossOriginStorage !== undefined + ); + } + + async match(request: RequestLike): Promise { + const url = this.normalizeRequest(request); + const hash = await this.resolveHashDescriptor(url); + if (!hash) { + return undefined; + } + try { + const api = this.getApi(); + if (!api) { + return undefined; + } + const handles = await api.requestFileHandles([hash]); + const handle = handles[0]; + if (!handle) { + return undefined; + } + const blob = await handle.getFile(); + return new Response(blob); + } catch { + return undefined; + } + } + + async put(request: RequestLike, response: Response): Promise { + const url = this.normalizeRequest(request); + const blob = await response.blob(); + const hash = await this.getBlobHash(blob); + const api = this.getApi(); + if (!api) { + throw new Error("Cross-origin storage API unavailable."); + } + const handles = await api.requestFileHandles([hash], { create: true }); + const handle = handles[0]; + if (!handle) { + throw new Error("Cross-origin storage API returned no handles."); + } + const writableStream = await handle.createWritable(); + await writableStream.write(blob); + await writableStream.close(); + this.hashCache.set(url, hash); + } + + async delete(request: RequestLike): Promise { + const url = this.normalizeRequest(request); + const hash = await this.resolveHashDescriptor(url); + if (!hash) { + return; + } + const api = this.getApi(); + if (api && typeof api.removeFileHandles === "function") { + await api.removeFileHandles([hash]); + } + this.hashCache.delete(url); + } + + private getApi(): CrossOriginStorageAPI | undefined { + if (!CrossOriginStorage.isAvailable()) { + return undefined; + } + return navigator.crossOriginStorage; + } + + private normalizeRequest(request: RequestLike): string { + if (typeof request === "string") { + return request; + } + if (request instanceof URL) { + return request.href; + } + if (request instanceof Request) { + return request.url; + } + if (request && typeof request.url === "string") { + return request.url; + } + throw new Error("CrossOriginStorage: Unsupported request type."); + } + + private async resolveHashDescriptor( + url: string, + ): Promise { + const cached = this.hashCache.get(url); + if (cached) { + return cached; + } + const hashValue = await this.getFileHash(url); + if (!hashValue) { + return null; + } + const descriptor: CrossOriginHashDescriptor = { + algorithm: HASH_ALGORITHM, + value: 
hashValue, + }; + this.hashCache.set(url, descriptor); + return descriptor; + } + + // Gets the SHA-256 hash for large resources using request metadata. + private async getFileHash(url: string): Promise { + const metadataHash = await this.extractHashFromHead(url); + if (metadataHash) { + return metadataHash; + } + if (/\/resolve\/main\//.test(url)) { + const pointerHash = await this.extractHashFromPointer(url); + if (pointerHash) { + return pointerHash; + } + } + return null; + } + + private async extractHashFromHead(url: string): Promise { + try { + const response = await fetch(url, { method: "HEAD" }); + if (!response.ok) { + return null; + } + const headerNames = [ + "x-linked-etag", + "x-linked-hash", + "x-amz-meta-sha256", + "x-oss-meta-sha256", + "x-sha256", + "etag", + ]; + for (const name of headerNames) { + const value = response.headers.get(name); + const hash = this.extractSha256(value); + if (hash) { + return hash; + } + } + } catch { + // Swallow errors; fall back to other strategies. + } + return null; + } + + private async extractHashFromPointer(url: string): Promise { + try { + const rawUrl = url.replace(/\/resolve\//, "/raw/"); + const response = await fetch(rawUrl, { + headers: { Range: "bytes=0-1023" }, + }); + if (!response.ok) { + return null; + } + const text = await response.text(); + if (!text.includes("oid sha256:")) { + return null; + } + const match = text.match(/oid sha256:([A-Fa-f0-9]+)/); + return match ? match[1] : null; + } catch { + return null; + } + } + + private extractSha256(value: string | null): string | null { + if (!value) { + return null; + } + const match = value.match(HASH_MATCH_REGEX); + return match ? match[0].toLowerCase() : null; + } + + private async getBlobHash(blob: Blob): Promise { + const arrayBuffer = await blob.arrayBuffer(); + const hashBuffer = await crypto.subtle.digest(HASH_ALGORITHM, arrayBuffer); + const hashArray = Array.from(new Uint8Array(hashBuffer)); + const hashHex = hashArray + .map((byte) => byte.toString(16).padStart(2, "0")) + .join(""); + + return { + algorithm: HASH_ALGORITHM, + value: hashHex, + }; + } +} diff --git a/src/cross_origin_storage_cache.ts b/src/cross_origin_storage_cache.ts new file mode 100644 index 00000000..4bfab8f3 --- /dev/null +++ b/src/cross_origin_storage_cache.ts @@ -0,0 +1,92 @@ +import * as tvmjs from "@mlc-ai/web-runtime"; +import CrossOriginStorage from "./cross_origin_storage"; + +type StoreType = string | undefined; + +const DEFAULT_FETCH_OPTIONS: RequestInit = { method: "GET" }; + +export class CrossOriginStorageCache implements tvmjs.ArtifactCacheTemplate { + private storage: CrossOriginStorage; + + constructor( + _scope: string, + storage: CrossOriginStorage = new CrossOriginStorage(), + ) { + this.storage = storage; + } + + async fetchWithCache( + url: string, + storetype?: StoreType, + signal?: AbortSignal, + ): Promise { + const cachedResponse = await this.storage.match(url); + if (cachedResponse !== undefined) { + return this.responseToStoreType(cachedResponse, storetype); + } + + await this.addToCache(url, storetype, signal); + const hydrated = await this.storage.match(url); + if (hydrated === undefined) { + throw new Error(`CrossOriginStorageCache: failed to hydrate ${url}`); + } + return this.responseToStoreType(hydrated, storetype); + } + + async addToCache( + url: string, + _storetype?: StoreType, + signal?: AbortSignal, + ): Promise { + const existing = await this.storage.match(url); + if (existing !== undefined) { + return; + } + const request = new Request( + url, + signal ? 
{ ...DEFAULT_FETCH_OPTIONS, signal } : DEFAULT_FETCH_OPTIONS, + ); + const response = await fetch(request); + if (!response.ok) { + throw new Error( + `CrossOriginStorageCache: Unable to fetch ${url}, received status ${response.status}`, + ); + } + const cloned = response.clone(); + await this.storage.put(url, cloned); + } + + async hasAllKeys(keys: string[]): Promise { + const results = await Promise.all( + keys.map(async (key) => { + const cached = await this.storage.match(key); + return cached !== undefined; + }), + ); + return results.every((item) => item); + } + + async deleteInCache(_url: string): Promise { + // no delete API currently provided by Cross-Origin Storage + return; + } + + private async responseToStoreType( + response: Response, + storetype?: StoreType, + ): Promise { + if (storetype === undefined) { + return response; + } + const format = storetype.toLowerCase(); + if (format === "json") { + return response.json(); + } + if (format === "arraybuffer") { + return response.arrayBuffer(); + } + return response; + } +} + +export default CrossOriginStorageCache; diff --git a/src/engine.ts b/src/engine.ts index ddf6d2f2..865e794f 100644 --- a/src/engine.ts +++ b/src/engine.ts @@ -69,7 +69,11 @@ import { SpecifiedModelNotFoundError, ModelNotLoadedError, } from "./error"; -import { asyncLoadTokenizer } from "./cache_util"; +import { + asyncLoadTokenizer, + fetchModelArtifacts, + getArtifactCache, +} from "./cache_util"; import { EmbeddingPipeline } from "./embedding"; /** @@ -260,12 +264,7 @@ export class MLCEngine implements MLCEngineInterface { this.loadedModelIdToModelType.set(modelId, modelType); // instantiate cache - let configCache: tvmjs.ArtifactCacheTemplate; - if (this.appConfig.useIndexedDBCache) { - configCache = new tvmjs.ArtifactIndexedDBCache("webllm/config"); - } else { - configCache = new tvmjs.ArtifactCache("webllm/config"); - } + const configCache = getArtifactCache("webllm/config", this.appConfig); // load config const configUrl = new URL("mlc-chat-config.json", modelUrl).href; @@ -281,12 +280,7 @@ export class MLCEngine implements MLCEngineInterface { this.loadedModelIdToChatConfig.set(modelId, curModelConfig); // load tvm wasm - let wasmCache: tvmjs.ArtifactCacheTemplate; - if (this.appConfig.useIndexedDBCache) { - wasmCache = new tvmjs.ArtifactIndexedDBCache("webllm/wasm"); - } else { - wasmCache = new tvmjs.ArtifactCache("webllm/wasm"); - } + const wasmCache = getArtifactCache("webllm/wasm", this.appConfig); const wasmUrl = modelRecord.model_lib; if (wasmUrl === undefined) { @@ -366,12 +360,11 @@ export class MLCEngine implements MLCEngineInterface { this.appConfig, this.logger, ); - const cacheType = this.appConfig.useIndexedDBCache ? 
"indexeddb" : "cache"; - await tvm.fetchTensorCache( + await fetchModelArtifacts( + tvm, modelUrl, tvm.webgpu(), - "webllm/model", - cacheType, + this.appConfig, this.reloadController?.signal, ); diff --git a/src/utils.ts b/src/utils.ts index 7c688927..2c91b1e8 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -80,6 +80,9 @@ export function areAppConfigsEqual( if (config1.useIndexedDBCache !== config2.useIndexedDBCache) { return false; } + if (config1.cacheBackend !== config2.cacheBackend) { + return false; + } // Check if both configurations have the same number of model records if (config1.model_list.length !== config2.model_list.length) { From dae0890799bf38dfbd3f8e6af14f5596c73b2224 Mon Sep 17 00:00:00 2001 From: Akaash Parthasarathy Date: Fri, 24 Oct 2025 15:36:21 -0400 Subject: [PATCH 3/4] Update READMEs to describe cross-origin storage --- examples/README.md | 3 +-- examples/cache-usage/README.md | 8 ++++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/README.md b/examples/README.md index 0d7dad42..766e3fc7 100644 --- a/examples/README.md +++ b/examples/README.md @@ -46,8 +46,7 @@ These examples demonstrate various capabilities via WebLLM's OpenAI-like API. #### Others - [logit-processor](logit-processor): while `logit_bias` is supported, we additionally support stateful logit processing where users can specify their own rules. We also expose low-level API `forwardTokensAndSample()`. -- [cache-usage](cache-usage): demonstrates how WebLLM supports both the [Cache API](https://developer.mozilla.org/en-US/docs/Web/API/Cache) and [IndexedDB cache](https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API), and - users can pick with `appConfig.useIndexedDBCache`. Also demonstrates various cache utils such as checking +- [cache-usage](cache-usage): demonstrates how WebLLM supports multiple cache backends. Choose between the [Cache API](https://developer.mozilla.org/en-US/docs/Web/API/Cache), [IndexedDB cache](https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API), or the experimental Chrome [Cross-Origin Storage](https://github.com/explainers-by-googlers/cross-origin-storage) extension via `appConfig.cacheBackend`. Also demonstrates various cache utils such as checking whether a model is cached, deleting a model's weights from cache, deleting a model library wasm from cache, etc. - [simple-chat-upload](simple-chat-upload): demonstrates how to upload local models to WebLLM instead of downloading via a URL link diff --git a/examples/cache-usage/README.md b/examples/cache-usage/README.md index dab6d623..51205acc 100644 --- a/examples/cache-usage/README.md +++ b/examples/cache-usage/README.md @@ -1,9 +1,13 @@ # WebLLM Cache Usage -WebLLM supports both the Cache API and IndexedDB, which you can specify via `AppConfig.useIndexedDBCache`. -This folder provides an example on how Cache and IndexedDB Cache are used in WebLLM. We also +WebLLM supports multiple persistent cache backends. You can pick the classic Cache API, IndexedDB, or the experimental Chrome [Cross-Origin Storage](https://github.com/explainers-by-googlers/cross-origin-storage) extension by +setting `AppConfig.cacheBackend` to `"cache"`, `"indexeddb"`, or `"cross-origin"`. (`AppConfig.useIndexedDBCache` +is still honored for backward compatibility.) +This folder provides an example on how different caches are used in WebLLM. We also demonstrate the utility cache functions such as deleting models, checking if models are in cache, etc. 
+> **Note:** The cross-origin backend requires Chrome's cross-origin storage experiment or the community browser extension to be installed and granted access to the domains that host your model artifacts (e.g. huggingface.co). + For more information about the two caches, see: https://developer.mozilla.org/en-US/docs/Web/API/Storage_API/Storage_quotas_and_eviction_criteria#what_technologies_store_data_in_the_browser. To inspect the downloaded artifacts in your browser, open up developer console, go to application, From 700b1c5700430d1e25ac4ecfc353870485270a58 Mon Sep 17 00:00:00 2001 From: Akaash Parthasarathy Date: Wed, 29 Oct 2025 01:03:08 -0400 Subject: [PATCH 4/4] Remove backward compatibility with useIndexedDBCache flag --- examples/cache-usage/README.md | 3 +- examples/cache-usage/package.json | 6 +- src/cache_util.ts | 28 ++++++--- src/config.ts | 9 ++- src/cross_origin_storage.ts | 61 +++++++++++++++----- src/cross_origin_storage_cache.ts | 3 +- src/utils.ts | 5 +- tests/scripts/sanity_checks/sanity_checks.ts | 2 +- 8 files changed, 80 insertions(+), 37 deletions(-) diff --git a/examples/cache-usage/README.md b/examples/cache-usage/README.md index 51205acc..7db09833 100644 --- a/examples/cache-usage/README.md +++ b/examples/cache-usage/README.md @@ -1,8 +1,7 @@ # WebLLM Cache Usage WebLLM supports multiple persistent cache backends. You can pick the classic Cache API, IndexedDB, or the experimental Chrome [Cross-Origin Storage](https://github.com/explainers-by-googlers/cross-origin-storage) extension by -setting `AppConfig.cacheBackend` to `"cache"`, `"indexeddb"`, or `"cross-origin"`. (`AppConfig.useIndexedDBCache` -is still honored for backward compatibility.) +setting `AppConfig.cacheBackend` to `"cache"`, `"indexeddb"`, or `"cross-origin"`. This folder provides an example on how different caches are used in WebLLM. We also demonstrate the utility cache functions such as deleting models, checking if models are in cache, etc. diff --git a/examples/cache-usage/package.json b/examples/cache-usage/package.json index cebca412..dca52450 100644 --- a/examples/cache-usage/package.json +++ b/examples/cache-usage/package.json @@ -3,18 +3,18 @@ "version": "0.1.0", "private": true, "scripts": { - "start": "parcel src/cache_usage.html --port 8888", + "start": "parcel src/cache_usage.html --port 8889", "build": "parcel build src/cache_usage.html --dist-dir lib" }, "devDependencies": { "buffer": "^5.7.1", - "parcel": "^2.8.3", + "parcel": "2.8.3", "process": "^0.11.10", "tslib": "^2.3.1", "typescript": "^4.9.5", "url": "^0.11.3" }, "dependencies": { - "@mlc-ai/web-llm": "^0.2.79" + "@mlc-ai/web-llm": "file:../.." 
} } diff --git a/src/cache_util.ts b/src/cache_util.ts index 027e3fbf..89e20651 100644 --- a/src/cache_util.ts +++ b/src/cache_util.ts @@ -15,6 +15,26 @@ import CrossOriginStorageCache from "./cross_origin_storage_cache"; type CacheScope = "webllm/model" | "webllm/config" | "webllm/wasm"; let crossOriginUnavailableLogged = false; +let crossOriginAvailabilityWait: Promise | null = null; + +function scheduleCrossOriginFallbackWarning( + logger: (msg: string) => void, +): void { + if (crossOriginUnavailableLogged || crossOriginAvailabilityWait) { + return; + } + crossOriginAvailabilityWait = (async () => { + const availableSoon = await CrossOriginStorage.waitForAvailability(); + crossOriginAvailabilityWait = null; + if (availableSoon || crossOriginUnavailableLogged) { + return; + } + logger( + "Cross-origin storage backend is not yet available; temporarily falling back to the Cache API.", + ); + crossOriginUnavailableLogged = true; + })(); +} function shouldUseCrossOrigin(appConfig: AppConfig): boolean { return ( @@ -33,13 +53,7 @@ export function getArtifactCache( if (CrossOriginStorage.isAvailable()) { return new CrossOriginStorageCache(scope); } - // Fallback to Cache API - if (!crossOriginUnavailableLogged) { - logger( - "Cross-origin storage backend requested but unavailable; falling back to Cache API.", - ); - crossOriginUnavailableLogged = true; - } + scheduleCrossOriginFallbackWarning(logger); } if (backend === "indexeddb") { return new tvmjs.ArtifactIndexedDBCache(scope); diff --git a/src/config.ts b/src/config.ts index eaff9e02..2f76ca76 100644 --- a/src/config.ts +++ b/src/config.ts @@ -270,8 +270,8 @@ export interface ModelRecord { * passed to the load. * * @param model_list: models to be used. - * @param useIndexedDBCache: if true, will use IndexedDBCache to cache models and other artifacts. - * If false or unspecified, will use the Cache API. For more information of the two, see: + * @param cacheBackend: the backend to use for caching models and other artifacts. + * If unspecified, will use the Cache API. For more information, see: * https://developer.mozilla.org/en-US/docs/Web/API/Storage_API/Storage_quotas_and_eviction_criteria#what_technologies_store_data_in_the_browser * * @note Note that the Cache API is more well-tested in WebLLM as of now. @@ -280,7 +280,6 @@ export type CacheBackend = "cache" | "indexeddb" | "cross-origin"; export interface AppConfig { model_list: Array; - useIndexedDBCache?: boolean; cacheBackend?: CacheBackend; } @@ -288,7 +287,7 @@ export function getCacheBackend(appConfig: AppConfig): CacheBackend { if (appConfig.cacheBackend !== undefined) { return appConfig.cacheBackend; } - return appConfig.useIndexedDBCache ? "indexeddb" : "cache"; + return "cache"; } /** @@ -320,7 +319,7 @@ export const functionCallingModelIds = [ * current WebLLM npm version. 
*/ export const prebuiltAppConfig: AppConfig = { - useIndexedDBCache: false, + cacheBackend: "cache", model_list: [ // Llama-3.2 { diff --git a/src/cross_origin_storage.ts b/src/cross_origin_storage.ts index 9d4c2659..d2d28acc 100644 --- a/src/cross_origin_storage.ts +++ b/src/cross_origin_storage.ts @@ -1,5 +1,17 @@ const HASH_ALGORITHM = "SHA-256"; const HASH_MATCH_REGEX = /[A-Fa-f0-9]{64}/; +const AVAILABILITY_POLL_INTERVAL_MS = 100; +const DEFAULT_AVAILABILITY_TIMEOUT_MS = 3000; +const HASH_CACHE_SYMBOL = Symbol.for("mlc.crossOriginStorage.hashCache"); + +const globalScope = globalThis as Record; +if (!globalScope[HASH_CACHE_SYMBOL]) { + globalScope[HASH_CACHE_SYMBOL] = new Map(); +} +const GLOBAL_HASH_CACHE = globalScope[HASH_CACHE_SYMBOL] as Map< + string, + CrossOriginHashDescriptor +>; export interface CrossOriginHashDescriptor { algorithm: string; @@ -16,7 +28,6 @@ interface CrossOriginStorageAPI { descriptors: CrossOriginHashDescriptor[], options?: { create?: boolean }, ): Promise; - removeFileHandles?(descriptors: CrossOriginHashDescriptor[]): Promise; } type RequestLike = string | URL | Request | { url?: string }; @@ -25,13 +36,16 @@ declare global { interface Navigator { crossOriginStorage?: CrossOriginStorageAPI; } + interface WorkerNavigator { + crossOriginStorage?: CrossOriginStorageAPI; + } } export default class CrossOriginStorage { private hashCache: Map; constructor() { - this.hashCache = new Map(); + this.hashCache = GLOBAL_HASH_CACHE; } static isAvailable(): boolean { @@ -42,6 +56,35 @@ export default class CrossOriginStorage { ); } + static async waitForAvailability( + timeoutMs: number = DEFAULT_AVAILABILITY_TIMEOUT_MS, + ): Promise { + if (CrossOriginStorage.isAvailable()) { + return true; + } + if (typeof navigator === "undefined") { + return false; + } + if (typeof setTimeout === "undefined") { + return false; + } + return new Promise((resolve) => { + const deadline = Date.now() + timeoutMs; + const tick = () => { + if (CrossOriginStorage.isAvailable()) { + resolve(true); + return; + } + if (Date.now() >= deadline) { + resolve(false); + return; + } + setTimeout(tick, AVAILABILITY_POLL_INTERVAL_MS); + }; + setTimeout(tick, AVAILABILITY_POLL_INTERVAL_MS); + }); + } + async match(request: RequestLike): Promise { const url = this.normalizeRequest(request); const hash = await this.resolveHashDescriptor(url); @@ -85,16 +128,8 @@ export default class CrossOriginStorage { } async delete(request: RequestLike): Promise { - const url = this.normalizeRequest(request); - const hash = await this.resolveHashDescriptor(url); - if (!hash) { - return; - } - const api = this.getApi(); - if (api && typeof api.removeFileHandles === "function") { - await api.removeFileHandles([hash]); - } - this.hashCache.delete(url); + // Currently no delete API provided by Cross-Origin Storage Extension + return; } private getApi(): CrossOriginStorageAPI | undefined { @@ -145,7 +180,7 @@ export default class CrossOriginStorage { if (metadataHash) { return metadataHash; } - if (/\/resolve\/main\//.test(url)) { + if (/\/resolve\//.test(url)) { const pointerHash = await this.extractHashFromPointer(url); if (pointerHash) { return pointerHash; diff --git a/src/cross_origin_storage_cache.ts b/src/cross_origin_storage_cache.ts index 4bfab8f3..211f0424 100644 --- a/src/cross_origin_storage_cache.ts +++ b/src/cross_origin_storage_cache.ts @@ -67,8 +67,7 @@ export class CrossOriginStorageCache implements tvmjs.ArtifactCacheTemplate { } async deleteInCache(_url: string): Promise { - // no delete API currently 
provided by Cross-Origin Storage - return; + await this.storage.delete(_url); } private async responseToStoreType( diff --git a/src/utils.ts b/src/utils.ts index 2c91b1e8..17a86925 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -76,10 +76,7 @@ export function areAppConfigsEqual( return config1 === config2; } - // Check if both configurations have the same IndexedDB cache usage - if (config1.useIndexedDBCache !== config2.useIndexedDBCache) { - return false; - } + // Check if both configurations have the same cache backend if (config1.cacheBackend !== config2.cacheBackend) { return false; } diff --git a/tests/scripts/sanity_checks/sanity_checks.ts b/tests/scripts/sanity_checks/sanity_checks.ts index da842353..2f96e051 100644 --- a/tests/scripts/sanity_checks/sanity_checks.ts +++ b/tests/scripts/sanity_checks/sanity_checks.ts @@ -157,7 +157,7 @@ async function testLogprobs(modelId: string, appConfig: webllm.AppConfig) { async function main() { const modelId = "Qwen3-0.6B-q0f32-MLC"; const appConfig = webllm.prebuiltAppConfig; - appConfig.useIndexedDBCache = true; + appConfig.cacheBackend = "indexeddb"; setLabel("gpu-test-label", "Running tests..."); let passed = 0, total = 0;
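
---

For reference, a minimal usage sketch of the `cacheBackend` option this series introduces (not part of the patches themselves). It assumes `hasModelInCache` and `deleteModelInCache` from `src/cache_util.ts` are re-exported from the package entry point, as the existing cache-usage example relies on; treat those import paths as assumptions.

```ts
import * as webllm from "@mlc-ai/web-llm";

async function main() {
  // Pick a cache backend: "cache" (the default), "indexeddb", or the
  // experimental "cross-origin" backend added in this series. When the
  // Cross-Origin Storage API is unavailable, getArtifactCache() falls
  // back to the Cache API.
  const appConfig: webllm.AppConfig = {
    ...webllm.prebuiltAppConfig,
    cacheBackend: "cross-origin",
  };

  const modelId = "Qwen3-0.6B-q0f32-MLC";

  // The cache utilities route through the same backend selection.
  const cached = await webllm.hasModelInCache(modelId, appConfig);
  console.log(`${modelId} in cache: ${cached}`);
  if (cached) {
    // Note: the cross-origin backend currently exposes no delete API,
    // so this call is a no-op for that backend.
    await webllm.deleteModelInCache(modelId, appConfig);
  }
}

main();
```

Since PATCH 4/4 removes the `useIndexedDBCache` flag entirely, callers that previously set `appConfig.useIndexedDBCache = true` should migrate to `appConfig.cacheBackend = "indexeddb"`, as done in `tests/scripts/sanity_checks/sanity_checks.ts`.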