Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/configs.js
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ function getNormalizedConfig(config) {

// Encoder-decoder models
case 't5':
case 'chronos2':
case 'mt5':
case 'longt5':
mapping['num_decoder_layers'] = 'num_decoder_layers';
Expand Down
82 changes: 82 additions & 0 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -3009,6 +3009,57 @@
//////////////////////////////////////////////////


//////////////////////////////////////////////////
// Chronos2 models
/**
* An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
*/
export class Chronos2PreTrainedModel extends PreTrainedModel {
forward_params = [
'context',
'group_ids',
'attention_mask',
];
};

/**
* The Chronos-2 Model for time series forecasting.
*
* Chronos-2 is a family of pretrained time series forecasting models based on T5.
* It uses a patching mechanism to convert time series into tokens and predicts
* multiple quantiles for probabilistic forecasting.
*
* **Example:** Load and run a Chronos-2 model for forecasting.
*
* ```javascript
* import { Chronos2ForForecasting } from '@huggingface/transformers';
*
* const model = await Chronos2ForForecasting.from_pretrained('amazon/chronos-2-small');
*
* // Prepare time series input
* const context = new Float32Array([1.0, 2.0, 3.0, 4.0, ...]); // Your historical data
* const inputs = {
* context: context,
* group_ids: new BigInt64Array([0]), // Group ID for cross-learning
* attention_mask: new Float32Array(context.length).fill(1.0),
* };
*
* // Generate forecasts
* const { quantile_preds } = await model(inputs);
* // Returns quantile predictions: [batch_size, num_quantiles, prediction_length]
* ```
*/
export class Chronos2Model extends Chronos2PreTrainedModel { }

/**
* Chronos2 Model with a forecasting head for time series prediction.
*
* This model outputs quantile predictions for probabilistic forecasting.
*/
export class Chronos2ForForecasting extends Chronos2PreTrainedModel { }
//////////////////////////////////////////////////


//////////////////////////////////////////////////
// LONGT5 models
/**
Expand Down Expand Up @@ -7839,6 +7890,7 @@

const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([
['t5', ['T5Model', T5Model]],
['chronos2', ['Chronos2Model', Chronos2Model]],
['longt5', ['LongT5Model', LongT5Model]],
['mt5', ['MT5Model', MT5Model]],
['bart', ['BartModel', BartModel]],
Expand Down Expand Up @@ -7957,6 +8009,7 @@

const MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = new Map([
['t5', ['T5ForConditionalGeneration', T5ForConditionalGeneration]],
['chronos2', ['Chronos2ForForecasting', Chronos2ForForecasting]],
['longt5', ['LongT5ForConditionalGeneration', LongT5ForConditionalGeneration]],
['mt5', ['MT5ForConditionalGeneration', MT5ForConditionalGeneration]],
['bart', ['BartForConditionalGeneration', BartForConditionalGeneration]],
Expand Down Expand Up @@ -8226,6 +8279,10 @@
['jina_clip', ['JinaCLIPVisionModel', JinaCLIPVisionModel]],
])

const MODEL_FOR_FORECASTING_MAPPING_NAMES = new Map([
['chronos2', ['Chronos2ForForecasting', Chronos2ForForecasting]],
])

const MODEL_CLASS_TYPE_MAPPING = [
// MODEL_MAPPING_NAMES:
[MODEL_MAPPING_NAMES_ENCODER_ONLY, MODEL_TYPES.EncoderOnly],
Expand Down Expand Up @@ -8263,6 +8320,7 @@
[MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_FORECASTING_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],

// Custom:
[MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
Expand Down Expand Up @@ -8565,6 +8623,30 @@
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_TEXT_TO_TEXT_MAPPING_NAMES];
}

/**
* Helper class which is used to instantiate time series forecasting models with the `from_pretrained` function.
*
* @example
* const model = await AutoModelForForecasting.from_pretrained('amazon/chronos-2-small');
*/
export class AutoModelForForecasting extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_FORECASTING_MAPPING_NAMES];

/** @type {typeof PreTrainedModel.from_pretrained} */
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
// First, load the config to check if it has chronos_config
const config = options.config || await AutoConfig.from_pretrained(pretrained_model_name_or_path, options);

// If model has chronos_config, route to Chronos2ForForecasting regardless of model_type
if (config.chronos_config) {

Check failure on line 8641 in src/models.js

View workflow job for this annotation

GitHub Actions / build (20)

Property 'chronos_config' does not exist on type 'PretrainedConfig'.

Check failure on line 8641 in src/models.js

View workflow job for this annotation

GitHub Actions / build (22)

Property 'chronos_config' does not exist on type 'PretrainedConfig'.

Check failure on line 8641 in src/models.js

View workflow job for this annotation

GitHub Actions / build (18)

Property 'chronos_config' does not exist on type 'PretrainedConfig'.
return await Chronos2ForForecasting.from_pretrained(pretrained_model_name_or_path, { ...options, config });
}

// Otherwise, use the standard mapping-based routing
return await super.from_pretrained(pretrained_model_name_or_path, { ...options, config });
}
}

//////////////////////////////////////////////////

//////////////////////////////////////////////////
Expand Down
Loading
Loading