Skip to content

Commit b75ea42

Browse files
standardize output format of xLSTMTime estimator for point predictions (#1978)
This PR standardizes the output format of the xLSTMTime estimator for point predictions. This is in accordance with issue #1976 --------- Co-authored-by: Aryan Saini <starktony.2032@gmail.com>
1 parent dfd2bd6 commit b75ea42

File tree

2 files changed

+1
-1
lines changed

2 files changed

+1
-1
lines changed

pytorch_forecasting/models/xlstm/_xlstm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def forward(
151151
output = output.transpose(1, 2)
152152

153153
output = output[0, ..., : self.hparams.output_size]
154+
output = output.unsqueeze(-1)
154155
return self.to_network_output(prediction=output)
155156

156157
@classmethod

pytorch_forecasting/models/xlstm/_xlstm_pkg.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ class xLSTMTime_pkg(_BasePtForecaster):
1717
"capability:pred_int": False,
1818
"capability:flexible_history_length": True,
1919
"capability:cold_start": False,
20-
"tests:skip_by_name": "test_integration",
2120
}
2221

2322
@classmethod

0 commit comments

Comments
 (0)