diff --git a/src/configs.js b/src/configs.js index e5c660a73..2932b40c2 100644 --- a/src/configs.js +++ b/src/configs.js @@ -179,6 +179,7 @@ function getNormalizedConfig(config) { // Encoder-decoder models case 't5': + case 'chronos2': case 'mt5': case 'longt5': mapping['num_decoder_layers'] = 'num_decoder_layers'; diff --git a/src/models.js b/src/models.js index 2f7fd569c..0153d81e6 100644 --- a/src/models.js +++ b/src/models.js @@ -3009,6 +3009,57 @@ export class T5ForConditionalGeneration extends T5PreTrainedModel { } ////////////////////////////////////////////////// +////////////////////////////////////////////////// +// Chronos2 models +/** + * An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. + */ +export class Chronos2PreTrainedModel extends PreTrainedModel { + forward_params = [ + 'context', + 'group_ids', + 'attention_mask', + ]; +}; + +/** + * The Chronos-2 Model for time series forecasting. + * + * Chronos-2 is a family of pretrained time series forecasting models based on T5. + * It uses a patching mechanism to convert time series into tokens and predicts + * multiple quantiles for probabilistic forecasting. + * + * **Example:** Load and run a Chronos-2 model for forecasting. + * + * ```javascript + * import { Chronos2ForForecasting } from '@huggingface/transformers'; + * + * const model = await Chronos2ForForecasting.from_pretrained('amazon/chronos-2-small'); + * + * // Prepare time series input + * const context = new Float32Array([1.0, 2.0, 3.0, 4.0, ...]); // Your historical data + * const inputs = { + * context: context, + * group_ids: new BigInt64Array([0]), // Group ID for cross-learning + * attention_mask: new Float32Array(context.length).fill(1.0), + * }; + * + * // Generate forecasts + * const { quantile_preds } = await model(inputs); + * // Returns quantile predictions: [batch_size, num_quantiles, prediction_length] + * ``` + */ +export class Chronos2Model extends Chronos2PreTrainedModel { } + +/** + * Chronos2 Model with a forecasting head for time series prediction. + * + * This model outputs quantile predictions for probabilistic forecasting. + */ +export class Chronos2ForForecasting extends Chronos2PreTrainedModel { } +////////////////////////////////////////////////// + + ////////////////////////////////////////////////// // LONGT5 models /** @@ -7839,6 +7890,7 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([ const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([ ['t5', ['T5Model', T5Model]], + ['chronos2', ['Chronos2Model', Chronos2Model]], ['longt5', ['LongT5Model', LongT5Model]], ['mt5', ['MT5Model', MT5Model]], ['bart', ['BartModel', BartModel]], @@ -7957,6 +8009,7 @@ const MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = new Map([ const MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = new Map([ ['t5', ['T5ForConditionalGeneration', T5ForConditionalGeneration]], + ['chronos2', ['Chronos2ForForecasting', Chronos2ForForecasting]], ['longt5', ['LongT5ForConditionalGeneration', LongT5ForConditionalGeneration]], ['mt5', ['MT5ForConditionalGeneration', MT5ForConditionalGeneration]], ['bart', ['BartForConditionalGeneration', BartForConditionalGeneration]], @@ -8226,6 +8279,10 @@ const MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES = new Map([ ['jina_clip', ['JinaCLIPVisionModel', JinaCLIPVisionModel]], ]) +const MODEL_FOR_FORECASTING_MAPPING_NAMES = new Map([ + ['chronos2', ['Chronos2ForForecasting', Chronos2ForForecasting]], +]) + const MODEL_CLASS_TYPE_MAPPING = [ // MODEL_MAPPING_NAMES: [MODEL_MAPPING_NAMES_ENCODER_ONLY, MODEL_TYPES.EncoderOnly], @@ -8263,6 +8320,7 @@ const MODEL_CLASS_TYPE_MAPPING = [ [MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], + [MODEL_FOR_FORECASTING_MAPPING_NAMES, MODEL_TYPES.Seq2Seq], // Custom: [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], @@ -8565,6 +8623,30 @@ export class AutoModelForAudioTextToText extends PretrainedMixin { static MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_TEXT_TO_TEXT_MAPPING_NAMES]; } +/** + * Helper class which is used to instantiate time series forecasting models with the `from_pretrained` function. + * + * @example + * const model = await AutoModelForForecasting.from_pretrained('amazon/chronos-2-small'); + */ +export class AutoModelForForecasting extends PretrainedMixin { + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_FORECASTING_MAPPING_NAMES]; + + /** @type {typeof PreTrainedModel.from_pretrained} */ + static async from_pretrained(pretrained_model_name_or_path, options = {}) { + // First, load the config to check if it has chronos_config + const config = options.config || await AutoConfig.from_pretrained(pretrained_model_name_or_path, options); + + // If model has chronos_config, route to Chronos2ForForecasting regardless of model_type + if (config.chronos_config) { + return await Chronos2ForForecasting.from_pretrained(pretrained_model_name_or_path, { ...options, config }); + } + + // Otherwise, use the standard mapping-based routing + return await super.from_pretrained(pretrained_model_name_or_path, { ...options, config }); + } +} + ////////////////////////////////////////////////// ////////////////////////////////////////////////// diff --git a/src/pipelines.js b/src/pipelines.js index 6c84403e2..23b593ed6 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -41,6 +41,7 @@ import { AutoModelForImageToImage, AutoModelForDepthEstimation, AutoModelForImageFeatureExtraction, + AutoModelForForecasting, PreTrainedModel, } from './models.js'; import { @@ -3057,6 +3058,497 @@ export class DepthEstimationPipeline extends (/** @type {new (options: ImagePipe } } +/** + * @typedef {Object} Chronos2Input + * @property {number[]|Float32Array} target - Target time series + * @property {Object.} [past_covariates] - Historical covariates available during context + * @property {Object.} [future_covariates] - Known future covariates for forecast horizon + * + * @typedef {Object} Chronos2PipelineOptions Parameters specific to time series forecasting pipelines. + * @property {number} [prediction_length=16] The number of time steps to predict. + * @property {number[]} [quantile_levels=[0.1, 0.5, 0.9]] The quantile levels to predict. + * @property {boolean} [predict_batches_jointly=false] Enable cross-learning across batch items (shares information between series). + * @property {number} [batch_size=100] Batch size for joint prediction (recommended: ~100). + * + * @callback Chronos2PipelineCallback Forecast time series. + * @param {number[]|number[][]|Float32Array|Float32Array[]|Chronos2Input|Chronos2Input[]} inputs One or more time series to forecast. + * Each input can be: a 1D array (simple univariate), an object with target and optional covariates, or an array of either. + * @param {Chronos2PipelineOptions} [options] The options to use for forecasting. + * @returns {Promise} The forecasted quantiles for each input time series. + * + * @typedef {Object} Chronos2Output + * @property {number[][]} forecast Array of shape [prediction_length, num_quantiles] containing the forecasted quantiles. + * @property {number[]} quantile_levels The quantile levels corresponding to the forecast columns. + * + * @typedef {Object} Chronos2PipelineConstructorArgs + * @property {string} task The task of the pipeline. + * @property {PreTrainedModel} model The model used by the pipeline. + * + * @typedef {Chronos2PipelineConstructorArgs & Chronos2PipelineCallback & Disposable} Chronos2PipelineType + */ + +/** + * Time series forecasting pipeline using Chronos-2 models. + * + * Chronos-2 is a family of pretrained time series forecasting models based on T5. + * It uses a patching-based approach with instance normalization and arcsinh transformation. + * + * **Example:** Forecast M4 hourly energy consumption data. + * ```javascript + * const forecaster = await pipeline('time-series-forecasting', 'amazon/chronos-2-small'); + * + * // Historical time series data + * const timeSeries = [605, 586, 586, 559, 511, 487, 484, 458, ...]; // 100 timesteps + * + * // Generate 16-step forecast with quantiles + * const output = await forecaster(timeSeries, { + * prediction_length: 16, + * quantile_levels: [0.1, 0.5, 0.9], // 10th, 50th (median), 90th percentiles + * }); + * + * // Output format: { forecast: [[t1_q1, t1_q2, t1_q3], [t2_q1, t2_q2, t2_q3], ...], quantile_levels: [0.1, 0.5, 0.9] } + * console.log('Median forecast:', output.forecast.map(row => row[1])); // Extract median (50th percentile) + * ``` + * + * **Example:** Forecast with future covariates (known future information). + * ```javascript + * const forecaster = await pipeline('time-series-forecasting', 'amazon/chronos-2-small'); + * + * // Forecast with covariates (e.g., energy prices with load and weather forecasts) + * const output = await forecaster({ + * target: [42.5, 45.2, 48.1, 46.8, 44.3, 41.2, ...], // Historical prices + * future_covariates: { + * 'temperature': [20, 21, 22, 23, 24, 25, 26, 27], // Known future temperatures + * 'day_of_week': ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun', 'Mon'], // Categorical + * } + * }, { + * prediction_length: 8, + * }); + * + * // Covariates are stacked in batch dimension for cross-learning + * // Categorical values are automatically encoded to indices + * ``` + * + * **Example:** Batch forecasting for multiple time series. + * ```javascript + * const forecaster = await pipeline('time-series-forecasting', 'amazon/chronos-2-tiny'); + * + * const batch = [ + * [100, 110, 105, 115, 120, ...], // Series 1 + * [50, 55, 52, 58, 60, ...], // Series 2 + * ]; + * + * const outputs = await forecaster(batch); + * // Returns array of forecasts, one per input series + * ``` + */ +/** + * Time series forecasting pipeline supporting multiple forecasting models. + * Currently supports Chronos-2 and compatible architectures. + * + * @example + * const forecaster = await pipeline('time-series-forecasting', 'kashif/chronos-2-onnx'); + * const result = await forecaster(data, { prediction_length: 24 }); + */ +export class TimeSeriesForecastingPipeline extends (/** @type {new (options: Chronos2PipelineConstructorArgs) => Chronos2PipelineType} */ (Pipeline)) { + /** + * Create a new TimeSeriesForecastingPipeline. + * @param {Chronos2PipelineConstructorArgs} options An object used to instantiate the pipeline. + */ + constructor(options) { + super(options); + + // Get model config for default parameters + this.config = this.model.config; + + // Support different forecasting model configs + // Chronos-2 uses chronos_config, others might use different structures + this.model_config = this.config.chronos_config || this.config; + + // Default parameters from config (model-agnostic) + this.default_prediction_length = this.model_config.prediction_length || + this.model_config.prediction_horizon || 16; + this.patch_size = this.model_config.input_patch_size || + this.model_config.patch_size || 16; + this.output_patch_size = this.model_config.output_patch_size || this.patch_size; + this.quantile_levels = this.model_config.quantiles || + this.model_config.quantile_levels || + [0.1, 0.5, 0.9]; + + // Covariate vocabulary for categorical encoding + this.covariate_vocab = new Map(); + } + + /** + * Pad time series to be divisible by patch_size using repeat-padding. + * IMPORTANT: Uses repeat-padding (not NaN) to preserve normalization statistics! + * + * @param {Float32Array} timeSeries - The input time series + * @returns {Float32Array} - Padded time series + * @private + */ + _padTimeSeries(timeSeries) { + const length = timeSeries.length; + const remainder = length % this.patch_size; + + if (remainder === 0) return timeSeries; + + const paddingSize = this.patch_size - remainder; + const padded = new Float32Array(length + paddingSize); + + // Pad at the beginning by repeating the first value + for (let i = 0; i < paddingSize; i++) { + padded[i] = timeSeries[0]; + } + + // Copy original data + padded.set(timeSeries, paddingSize); + + return padded; + } + + /** + * Encode categorical covariate values to numerical indices. + * @param {string} covariateName - Name of the covariate + * @param {Array} values - Array of values (strings or numbers) + * @returns {Float32Array} - Encoded values + * @private + */ + _encodeCategorical(covariateName, values) { + if (!this.covariate_vocab.has(covariateName)) { + this.covariate_vocab.set(covariateName, new Map()); + } + + const vocab = this.covariate_vocab.get(covariateName); + const encoded = new Float32Array(values.length); + + for (let i = 0; i < values.length; i++) { + const value = values[i]; + if (typeof value === 'string') { + // Encode string to index + if (!vocab.has(value)) { + vocab.set(value, vocab.size); + } + encoded[i] = vocab.get(value); + } else { + // Already numerical + encoded[i] = value; + } + } + + return encoded; + } + + /** + * Process covariates: encode categorical values, align length, and concatenate. + * @param {Object.} covariates - Dictionary of covariate arrays + * @param {number} targetLength - Target length for alignment + * @returns {Float32Array|null} - Concatenated covariates or null + * @private + */ + _processCovariates(covariates, targetLength) { + if (!covariates || Object.keys(covariates).length === 0) { + return null; + } + + const covariateArrays = []; + + for (const [name, values] of Object.entries(covariates)) { + if (values === null || values === undefined) continue; + + // Encode categorical values + const encoded = this._encodeCategorical(name, values); + + // Pad to match target length + let padded; + if (encoded.length < targetLength) { + padded = new Float32Array(targetLength); + // Pad at start with first value + for (let i = 0; i < targetLength - encoded.length; i++) { + padded[i] = encoded[0]; + } + padded.set(encoded, targetLength - encoded.length); + } else if (encoded.length > targetLength) { + // Truncate from start + padded = encoded.slice(encoded.length - targetLength); + } else { + padded = encoded; + } + + covariateArrays.push(padded); + } + + if (covariateArrays.length === 0) return null; + + // Return array of covariate series (will be stacked in batch dimension) + return covariateArrays; + } + + /** + * Process a single input (can be array or object with covariates). + * @param {number[]|Float32Array|Chronos2Input} input - Input time series or object + * @param {number} prediction_length - Forecast horizon length + * @returns {Object} - Processed input with target, covariates, and metadata + * @private + */ + _processInput(input, prediction_length) { + let target, past_covariates, future_covariates; + + if (Array.isArray(input) || input instanceof Float32Array) { + // Simple array input + target = input instanceof Float32Array ? input : new Float32Array(input); + past_covariates = null; + future_covariates = null; + } else { + // Object with target and covariates + target = input.target instanceof Float32Array ? + input.target : new Float32Array(input.target); + past_covariates = input.past_covariates || null; + future_covariates = input.future_covariates || null; + } + + // Pad target + const paddedTarget = this._padTimeSeries(target); + const contextLength = paddedTarget.length; + + // Process past covariates (align with context length) + const processedPastCovariates = this._processCovariates( + past_covariates, + contextLength + ); + + // Process future covariates (align with padded prediction length) + // Future covariates must be padded to num_output_patches * output_patch_size + const output_patch_size = this.output_patch_size; + const num_output_patches = Math.ceil(prediction_length / output_patch_size); + const future_covariates_length = num_output_patches * output_patch_size; + + const processedFutureCovariates = this._processCovariates( + future_covariates, + future_covariates_length + ); + + return { + target: paddedTarget, + past_covariates: processedPastCovariates, + future_covariates: processedFutureCovariates, + context_length: contextLength, + }; + } + + /** @type {Chronos2PipelineCallback} */ + async _call(inputs, { + prediction_length = this.default_prediction_length, + quantile_levels = this.quantile_levels, + predict_batches_jointly = false, + batch_size = 100, + } = {}) { + // Detect if input is batched + const isBatched = Array.isArray(inputs) && + (Array.isArray(inputs[0]) || typeof inputs[0] === 'object'); + const inputsArray = isBatched ? inputs : [inputs]; + + // For joint prediction, use cross-learning + if (predict_batches_jointly && inputsArray.length > 1) { + return this._callJoint(inputsArray, { + prediction_length, + quantile_levels, + batch_size, + }); + } + + // Process each input independently + const results = []; + + for (const input of inputsArray) { + const processed = this._processInput(input, prediction_length); + + // Stack target + covariates in batch dimension (like PyTorch does) + const output_patch_size = this.output_patch_size; + const num_output_patches = Math.ceil(prediction_length / output_patch_size); + const future_covariates_length = num_output_patches * output_patch_size; + + // Build batch: [target, covariate1, covariate2, ...] + const contextBatch = [processed.target]; + const futureCovariatesBatch = [new Float32Array(future_covariates_length)]; // Target future (NaN → zeros) + + if (processed.future_covariates && Array.isArray(processed.future_covariates)) { + // Add each covariate as a separate series in the batch + for (const covariate of processed.future_covariates) { + // Context: pad covariate to match context length (repeat last value) + const contextCovariate = new Float32Array(processed.context_length); + for (let i = 0; i < processed.context_length; i++) { + contextCovariate[i] = covariate[Math.min(i, covariate.length - 1)]; + } + contextBatch.push(contextCovariate); + futureCovariatesBatch.push(covariate); + } + } + + const batchSize = contextBatch.length; + + // Flatten batch tensors + const contextFlat = new Float32Array(batchSize * processed.context_length); + const futureFlat = new Float32Array(batchSize * future_covariates_length); + const attentionFlat = new Float32Array(batchSize * processed.context_length).fill(1.0); + const groupIds = new BigInt64Array(batchSize).fill(0n); // All in same group + + for (let b = 0; b < batchSize; b++) { + contextFlat.set(contextBatch[b], b * processed.context_length); + futureFlat.set(futureCovariatesBatch[b], b * future_covariates_length); + } + + const model_inputs = { + context: new Tensor('float32', contextFlat, [batchSize, processed.context_length]), + group_ids: new Tensor('int64', groupIds, [batchSize]), + attention_mask: new Tensor('float32', attentionFlat, [batchSize, processed.context_length]), + future_covariates: new Tensor('float32', futureFlat, [batchSize, future_covariates_length]), + num_output_patches: new Tensor('int64', new BigInt64Array([BigInt(num_output_patches)]), []), + }; + + // Run inference + const outputs = await this.model(model_inputs); + const quantile_preds = outputs.quantile_preds; + const [batch_size_out, num_quantiles, pred_length] = quantile_preds.dims; + + // Extract forecast from first batch element only (target series) + // Remaining batch elements are covariate predictions (which we ignore) + const forecast = []; + const quantile_indices = quantile_levels.map(q => { + const idx = this.quantile_levels.indexOf(q); + if (idx === -1) { + throw new Error(`Quantile level ${q} not found`); + } + return idx; + }); + + const data = quantile_preds.data; + const batch_offset = 0; // First batch element (target) + const batch_stride = num_quantiles * pred_length; + + // Truncate to requested prediction_length (model may generate more due to patching) + const actual_length = Math.min(prediction_length, pred_length); + for (let t = 0; t < actual_length; t++) { + const row = quantile_indices.map(qi => + data[batch_offset * batch_stride + qi * pred_length + t] + ); + forecast.push(row); + } + + results.push({ + forecast, + quantile_levels, + }); + } + + return isBatched ? results : results[0]; + } + + /** + * Joint prediction with cross-learning. + * @param {Array} inputs - Array of inputs + * @param {Object} options - Prediction options + * @returns {Promise} - Array of forecasts + * @private + */ + async _callJoint(inputs, { + prediction_length, + quantile_levels, + batch_size, + }) { + const allResults = []; + + // Process inputs in batches + for (let i = 0; i < inputs.length; i += batch_size) { + const batchInputs = inputs.slice(i, Math.min(i + batch_size, inputs.length)); + const batchSize = batchInputs.length; + + // Process all inputs in batch + const processed = batchInputs.map(inp => + this._processInput(inp, prediction_length) + ); + + // Find max context length for padding + const maxLength = Math.max(...processed.map(p => p.context_length)); + + // Concatenate batch with padding + const batchContext = new Float32Array(batchSize * maxLength); + const batchMask = new Float32Array(batchSize * maxLength); + const batchGroupIds = new BigInt64Array(batchSize); + + for (let b = 0; b < batchSize; b++) { + const proc = processed[b]; + const padSize = maxLength - proc.context_length; + + // Pad at beginning + for (let j = 0; j < padSize; j++) { + batchContext[b * maxLength + j] = proc.target[0]; + batchMask[b * maxLength + j] = 0.0; // Masked + } + + // Copy actual data + for (let j = 0; j < proc.context_length; j++) { + batchContext[b * maxLength + padSize + j] = proc.target[j]; + batchMask[b * maxLength + padSize + j] = 1.0; + } + + // Same group ID for joint prediction + batchGroupIds[b] = 0n; + } + + // Run batch inference + // Calculate padded future_covariates length + const output_patch_size = this.output_patch_size; + const num_output_patches = Math.ceil(prediction_length / output_patch_size); + const future_covariates_length = num_output_patches * output_patch_size; + + const model_inputs = { + context: new Tensor('float32', batchContext, [batchSize, maxLength]), + group_ids: new Tensor('int64', batchGroupIds, [batchSize]), + attention_mask: new Tensor('float32', batchMask, [batchSize, maxLength]), + // Provide zeros for future_covariates (required by some models) + future_covariates: new Tensor('float32', + new Float32Array(batchSize * future_covariates_length), + [batchSize, future_covariates_length] + ), + num_output_patches: new Tensor('int64', new BigInt64Array([BigInt(num_output_patches)]), []), + }; + + const outputs = await this.model(model_inputs); + const quantile_preds = outputs.quantile_preds; + const [_, num_quantiles, pred_length] = quantile_preds.dims; + + // Extract results for each item in batch + const quantile_indices = quantile_levels.map(q => { + const idx = this.quantile_levels.indexOf(q); + if (idx === -1) throw new Error(`Quantile level ${q} not found`); + return idx; + }); + + const data = quantile_preds.data; + // Truncate to requested prediction_length (model may generate more due to patching) + const actual_length = Math.min(prediction_length, pred_length); + for (let b = 0; b < batchSize; b++) { + const forecast = []; + for (let t = 0; t < actual_length; t++) { + const row = quantile_indices.map(qi => { + const idx = b * (num_quantiles * pred_length) + qi * pred_length + t; + return data[idx]; + }); + forecast.push(row); + } + + allResults.push({ + forecast, + quantile_levels, + }); + } + } + + return allResults; + } +} + const SUPPORTED_TASKS = Object.freeze({ "text-classification": { "tokenizer": AutoTokenizer, @@ -3351,6 +3843,14 @@ const SUPPORTED_TASKS = Object.freeze({ }, "type": "image", }, + "time-series-forecasting": { + "pipeline": TimeSeriesForecastingPipeline, + "model": AutoModelForForecasting, + "default": { + "model": "kashif/chronos-2-onnx", + }, + "type": "time-series", + }, }) @@ -3542,3 +4042,7 @@ async function loadItems(mapping, model, pretrainedOptions) { return result; } + + +// Backward compatibility alias +export const Chronos2Pipeline = TimeSeriesForecastingPipeline;