diff --git a/src/utils/CrossOriginStorage.js b/src/utils/CrossOriginStorage.js new file mode 100644 index 000000000..d3472a6a8 --- /dev/null +++ b/src/utils/CrossOriginStorage.js @@ -0,0 +1,76 @@ +const HASH_ALGORITHM = "SHA-256"; + +class CrossOriginStorage { + static isAvailable = () => + typeof navigator !== "undefined" && "crossOriginStorage" in navigator; + + match = async (request) => { + const hashValue = await this._getFileHash(request); + if (!hashValue) { + return undefined; + } + const hash = { algorithm: HASH_ALGORITHM, value: hashValue }; + try { + // @ts-expect-error + const [handle] = await navigator.crossOriginStorage.requestFileHandles([ + hash, + ]); + const blob = await handle.getFile(); + return new Response(blob); + } catch (err) { + return undefined; + } + }; + put = async (request, response) => { + const blob = await response.blob(); + const hash = await this._getBlobHash(blob); + // @ts-expect-error + const [handle] = await navigator.crossOriginStorage.requestFileHandles( + [hash], + { create: true }, + ); + const writableStream = await handle.createWritable(); + await writableStream.write(blob); + await writableStream.close(); + }; + + // Gets the SHA-256 hash for large resources as per + // https://huggingface.co/docs/hub/en/storage-backends#xet. + _getFileHash = async (url) => { + if (/\/resolve\/main\/onnx\//.test(url)) { + const rawUrl = url.replace(/\/resolve\//, "/raw/"); + const text = await fetch(rawUrl).then((response) => response.text()); + if (!text.includes("oid sha256:")) { + return null; + } + return text.replace(/.*?\n^oid sha256:(\w+)\n.*?$/gm, "$1") || null; + } + return null; + }; + + _getBlobHash = async (blob) => { + const hashAlgorithmIdentifier = "SHA-256"; + + // Get the contents of the blob as binary data contained in an ArrayBuffer. + const arrayBuffer = await blob.arrayBuffer(); + + // Hash the arrayBuffer using SHA-256. + const hashBuffer = await crypto.subtle.digest( + hashAlgorithmIdentifier, + arrayBuffer, + ); + + // Convert the ArrayBuffer to a hex string. + const hashArray = Array.from(new Uint8Array(hashBuffer)); + const hashHex = hashArray + .map((byte) => byte.toString(16).padStart(2, "0")) + .join(""); + + return { + algorithm: hashAlgorithmIdentifier, + value: hashHex, + }; + }; +} + +export default CrossOriginStorage; diff --git a/src/utils/hub.js b/src/utils/hub.js index a56d1bf7a..43bda7620 100755 --- a/src/utils/hub.js +++ b/src/utils/hub.js @@ -10,6 +10,7 @@ import path from 'node:path'; import { apis, env } from '../env.js'; import { dispatchCallback } from './core.js'; +import CrossOriginStorage from './CrossOriginStorage.js' /** * @typedef {boolean|number} ExternalData Whether to load the model using the external data format (used for models >= 2GB in size). @@ -479,6 +480,10 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti filename ); + if(CrossOriginStorage.isAvailable()){ + cache = new CrossOriginStorage(); + } + /** @type {string} */ let cacheKey; const proposedCacheKey = cache instanceof FileCache