Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/configs.js
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ export class PretrainedConfig {
cache_dir = null,
local_files_only = false,
revision = 'main',
abort_signal = undefined,
} = {}) {
if (config && !(config instanceof PretrainedConfig)) {
config = new PretrainedConfig(config);
Expand All @@ -378,6 +379,7 @@ export class PretrainedConfig {
cache_dir,
local_files_only,
revision,
abort_signal,
})
return new this(data);
}
Expand Down
4 changes: 4 additions & 0 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -980,6 +980,7 @@ export class PreTrainedModel extends Callable {
local_files_only = false,
revision = 'main',
model_file_name = null,
abort_signal = undefined,
subfolder = 'onnx',
device = null,
dtype = null,
Expand All @@ -999,6 +1000,7 @@ export class PreTrainedModel extends Callable {
dtype,
use_external_data_format,
session_options,
abort_signal,
}

const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this);
Expand Down Expand Up @@ -6999,6 +7001,7 @@ export class PretrainedMixin {
dtype = null,
use_external_data_format = null,
session_options = {},
abort_signal = undefined,
} = {}) {

const options = {
Expand All @@ -7013,6 +7016,7 @@ export class PretrainedMixin {
dtype,
use_external_data_format,
session_options,
abort_signal,
}
options.config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options);

Expand Down
8 changes: 4 additions & 4 deletions src/models/mgp_str/processing_mgp_str.js
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,12 @@ export class MgpstrProcessor extends Processor {
}
}
/** @type {typeof Processor.from_pretrained} */
static async from_pretrained(...args) {
const base = await super.from_pretrained(...args);
static async from_pretrained(pretrained_model_name_or_path, options) {
const base = await super.from_pretrained(pretrained_model_name_or_path, options);

// Load Transformers.js-compatible versions of the BPE and WordPiece tokenizers
const bpe_tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2") // openai-community/gpt2
const wp_tokenizer = await AutoTokenizer.from_pretrained("Xenova/bert-base-uncased") // google-bert/bert-base-uncased
const bpe_tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2", { abort_signal: options?.abort_signal }) // openai-community/gpt2
const wp_tokenizer = await AutoTokenizer.from_pretrained("Xenova/bert-base-uncased", { abort_signal: options?.abort_signal }) // google-bert/bert-base-uncased

// Update components
base.components = {
Expand Down
12 changes: 9 additions & 3 deletions src/pipelines.js
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,15 @@ export class Pipeline extends Callable {
* @param {PreTrainedModel} [options.model] The model used by the pipeline.
* @param {PreTrainedTokenizer} [options.tokenizer=null] The tokenizer used by the pipeline (if any).
* @param {Processor} [options.processor=null] The processor used by the pipeline (if any).
* @param {AbortSignal} [options.abort_signal=undefined] An optional AbortSignal to cancel the request.
*/
constructor({ task, model, tokenizer = null, processor = null }) {
constructor({ task, model, tokenizer = null, processor = null, abort_signal = undefined }) {
    super();
    // Store the pipeline configuration; `abort_signal` is kept so that later
    // lazy loads/fetches performed by the pipeline can be cancelled (see its
    // use with `fetch` and `from_pretrained` elsewhere in this file).
    Object.assign(this, { task, model, tokenizer, processor, abort_signal });
}

/** @type {DisposeType} */
Expand Down Expand Up @@ -210,6 +212,7 @@ export class Pipeline extends Callable {
* @property {PreTrainedModel} model The model used by the pipeline.
* @property {PreTrainedTokenizer} tokenizer The tokenizer used by the pipeline.
* @property {Processor} processor The processor used by the pipeline.
* @property {AbortSignal} [abort_signal=undefined] An optional AbortSignal to cancel the request.
*
* @typedef {ModelTokenizerProcessorConstructorArgs} TextAudioPipelineConstructorArgs An object used to instantiate a text- and audio-based pipeline.
* @typedef {ModelTokenizerProcessorConstructorArgs} TextImagePipelineConstructorArgs An object used to instantiate a text- and image-based pipeline.
Expand Down Expand Up @@ -2776,17 +2779,18 @@ export class TextToAudioPipeline extends (/** @type {new (options: TextToAudioPi

async _call_text_to_spectrogram(text_inputs, { speaker_embeddings }) {


// Load vocoder, if not provided
if (!this.vocoder) {
console.log('No vocoder specified, using default HifiGan vocoder.');
this.vocoder = await AutoModel.from_pretrained(this.DEFAULT_VOCODER_ID, { dtype: 'fp32' });
this.vocoder = await AutoModel.from_pretrained(this.DEFAULT_VOCODER_ID, { dtype: 'fp32' , abort_signal : this.abort_signal });
}

// Load speaker embeddings as Float32Array from path/URL
if (typeof speaker_embeddings === 'string' || speaker_embeddings instanceof URL) {
// Load from URL with fetch
speaker_embeddings = new Float32Array(
await (await fetch(speaker_embeddings)).arrayBuffer()
await (await fetch(speaker_embeddings, { signal: this.abort_signal })).arrayBuffer()
);
}

Expand Down Expand Up @@ -3301,6 +3305,7 @@ export async function pipeline(
dtype = null,
model_file_name = null,
session_options = {},
abort_signal = undefined,
} = {}
) {
// Helper method to construct pipeline
Expand Down Expand Up @@ -3331,6 +3336,7 @@ export async function pipeline(
dtype,
model_file_name,
session_options,
abort_signal,
}

const classes = new Map([
Expand Down
4 changes: 4 additions & 0 deletions src/tokenizers.js
Original file line number Diff line number Diff line change
Expand Up @@ -2682,6 +2682,7 @@ export class PreTrainedTokenizer extends Callable {
local_files_only = false,
revision = 'main',
legacy = null,
abort_signal = undefined,
} = {}) {

const info = await loadTokenizer(pretrained_model_name_or_path, {
Expand All @@ -2691,6 +2692,7 @@ export class PreTrainedTokenizer extends Callable {
local_files_only,
revision,
legacy,
abort_signal,
})

// @ts-ignore
Expand Down Expand Up @@ -4351,6 +4353,7 @@ export class AutoTokenizer {
local_files_only = false,
revision = 'main',
legacy = null,
abort_signal = undefined,
} = {}) {

const [tokenizerJSON, tokenizerConfig] = await loadTokenizer(pretrained_model_name_or_path, {
Expand All @@ -4360,6 +4363,7 @@ export class AutoTokenizer {
local_files_only,
revision,
legacy,
abort_signal,
})

// Some tokenizers are saved with the "Fast" suffix, so we remove that if present.
Expand Down
44 changes: 29 additions & 15 deletions src/utils/hub.js
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import { dispatchCallback } from './core.js';
* @property {string} [revision='main'] The specific model version to use. It can be a branch name, a tag name, or a commit id,
* since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
* NOTE: This setting is ignored for local requests.
* @property {AbortSignal} [abort_signal=undefined] An optional AbortSignal to cancel the request.
*/

/**
Expand Down Expand Up @@ -58,9 +59,11 @@ class FileResponse {
/**
* Creates a new `FileResponse` object.
* @param {string|URL} filePath
     * @param {AbortSignal} [abort_signal] An optional AbortSignal to cancel reading the file.
*/
constructor(filePath) {
constructor(filePath, abort_signal) {
this.filePath = filePath;
this.abort_signal = abort_signal;
this.headers = new Headers();

this.exists = fs.existsSync(filePath);
Expand All @@ -79,9 +82,16 @@ class FileResponse {
self.arrayBuffer().then(buffer => {
controller.enqueue(new Uint8Array(buffer));
controller.close();
})
}).catch(error => {
controller.error(error);
});

abort_signal?.addEventListener('abort', () => {
controller.error(new Error('Request aborted'));
});
}
});

} else {
this.status = 404;
this.statusText = 'Not Found';
Expand All @@ -105,7 +115,7 @@ class FileResponse {
* @returns {FileResponse} A new FileResponse object with the same properties as the current object.
*/
clone() {
let response = new FileResponse(this.filePath);
let response = new FileResponse(this.filePath, this.abort_signal);
response.exists = this.exists;
response.status = this.status;
response.statusText = this.statusText;
Expand Down Expand Up @@ -185,12 +195,13 @@ function isValidUrl(string, protocols = null, validHosts = null) {
* Helper function to get a file, using either the Fetch API or FileSystem API.
*
* @param {URL|string} urlOrPath The URL/path of the file to get.
 * @param {AbortSignal} [abort_signal] An optional AbortSignal to cancel the request.
* @returns {Promise<FileResponse|Response>} A promise that resolves to a FileResponse object (if the file is retrieved using the FileSystem API), or a Response object (if the file is retrieved using the Fetch API).
*/
export async function getFile(urlOrPath) {
export async function getFile(urlOrPath, abort_signal) {

if (env.useFS && !isValidUrl(urlOrPath, ['http:', 'https:', 'blob:'])) {
return new FileResponse(urlOrPath);
return new FileResponse(urlOrPath, abort_signal);

} else if (typeof process !== 'undefined' && process?.release?.name === 'node') {
const IS_CI = !!process.env?.TESTING_REMOTELY;
Expand All @@ -210,12 +221,12 @@ export async function getFile(urlOrPath) {
headers.set('Authorization', `Bearer ${token}`);
}
}
return fetch(urlOrPath, { headers });
return fetch(urlOrPath, { headers, signal: abort_signal });
} else {
// Running in a browser-environment, so we use default headers
// NOTE: We do not allow passing authorization headers in the browser,
// since this would require exposing the token to the client.
return fetch(urlOrPath);
return fetch(urlOrPath, { signal: abort_signal });
}
}

Expand Down Expand Up @@ -263,13 +274,15 @@ class FileCache {

/**
* Checks whether the given request is in the cache.
* @param {string} request
* @param {string} request
* @param {Object} options An object containing the following properties:
* @param {AbortSignal} [options.abort_signal] An optional AbortSignal to cancel the request.
* @returns {Promise<FileResponse | undefined>}
*/
async match(request) {
async match(request, { abort_signal = undefined } = {}) {

let filePath = path.join(this.path, request);
let file = new FileResponse(filePath);
let file = new FileResponse(filePath, abort_signal);

if (file.exists) {
return file;
Expand Down Expand Up @@ -309,13 +322,14 @@ class FileCache {
/**
*
* @param {FileCache|Cache} cache The cache to search
 * @param {AbortSignal} [abort_signal] An optional AbortSignal to cancel the cache lookups.
* @param {string[]} names The names of the item to search for
* @returns {Promise<FileResponse|Response|undefined>} The item from the cache, or undefined if not found.
*/
async function tryCache(cache, ...names) {
async function tryCache(cache, abort_signal, ...names) {
for (let name of names) {
try {
let result = await cache.match(name);
let result = await cache.match(name, {abort_signal});
if (result) return result;
} catch (e) {
continue;
Expand Down Expand Up @@ -433,7 +447,7 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti
// 1. We first try to get from cache using the local path. In some environments (like deno),
// non-URL cache keys are not allowed. In these cases, `response` will be undefined.
// 2. If no response is found, we try to get from cache using the remote URL or file system cache.
response = await tryCache(cache, localPath, proposedCacheKey);
response = await tryCache(cache, options?.abort_signal, localPath, proposedCacheKey);
}

const cacheHit = response !== undefined;
Expand All @@ -447,7 +461,7 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti
const isURL = isValidUrl(requestURL, ['http:', 'https:']);
if (!isURL) {
try {
response = await getFile(localPath);
response = await getFile(localPath, options?.abort_signal);
cacheKey = localPath; // Update the cache key to be the local path
} catch (e) {
// Something went wrong while trying to get the file locally.
Expand Down Expand Up @@ -479,7 +493,7 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti
}

// File not found locally, so we try to download it from the remote server
response = await getFile(remoteURL);
response = await getFile(remoteURL, options?.abort_signal);

if (response.status !== 200) {
return handleError(response.status, remoteURL, fatal);
Expand Down