Skip to content
Merged
Show file tree
Hide file tree
Changes from 46 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
252598d
D1, D2 layer commit
phoeenniixx Apr 6, 2025
d0d1c3e
remove one comment
phoeenniixx Apr 6, 2025
80e64d2
model layer commit
phoeenniixx Apr 6, 2025
6364780
update docstring
phoeenniixx Apr 6, 2025
82b3dc7
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 6, 2025
257183c
update data_module.py
phoeenniixx Apr 10, 2025
9cdcb19
update data_module.py
phoeenniixx Apr 10, 2025
a83bf32
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 10, 2025
ac56d4f
Add disclaimer
phoeenniixx Apr 10, 2025
0e7e36f
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 10, 2025
4bfff21
update docstring
phoeenniixx Apr 11, 2025
ef98273
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 11, 2025
8a53ed6
Add tests for D1,D2 layer
phoeenniixx Apr 19, 2025
9f9df31
Merge branch 'main' into refactor-d1-d2
phoeenniixx Apr 19, 2025
cdecb77
Code quality
phoeenniixx Apr 19, 2025
86360fd
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 19, 2025
20aafb7
refactor file
fkiraly Apr 30, 2025
043820d
warning
fkiraly Apr 30, 2025
1720a15
linting
fkiraly May 1, 2025
af44474
move coercion to utils
fkiraly May 1, 2025
a3cb8b7
linting
fkiraly May 1, 2025
75d7fb5
Update _timeseries_v2.py
fkiraly May 1, 2025
1b946e6
Update __init__.py
fkiraly May 1, 2025
3edb08b
Update __init__.py
fkiraly May 1, 2025
a4bc9d8
Merge branch 'main' into pr/1811
fkiraly May 1, 2025
4c0d570
Merge branch 'pr/1811' into pr/1812
fkiraly May 1, 2025
e350291
update tests
phoeenniixx May 11, 2025
f90c94f
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx May 11, 2025
3099691
update tft_v2
phoeenniixx May 11, 2025
77cb979
warnings and init attr handling
fkiraly May 13, 2025
28df3c3
Merge branch 'refactor-d1-d2' of https://github.com/phoeenniixx/pytor…
fkiraly May 13, 2025
f8c94e6
simplify TimeSeries.__getitem__
fkiraly May 13, 2025
c289255
Update _timeseries_v2.py
fkiraly May 13, 2025
9467f38
Update data_module.py
fkiraly May 13, 2025
c3b40ad
backwards compat of private/public attrs
fkiraly May 13, 2025
c007310
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx May 13, 2025
2e25052
Merge branch 'main' into refactor-model
phoeenniixx May 13, 2025
38c28dc
add tests
phoeenniixx May 14, 2025
9d80eb8
add tests
phoeenniixx May 14, 2025
a8ccfe3
add tests
phoeenniixx May 14, 2025
f900ba5
add more docstrings
phoeenniixx May 14, 2025
ed1b799
add note about the commented out tests
phoeenniixx May 14, 2025
c947910
Merge branch 'main' into refactor-model
phoeenniixx May 16, 2025
c0ceb8a
add the commented out tests
phoeenniixx May 16, 2025
3828c26
remove note
phoeenniixx May 16, 2025
6d6d18e
Merge branch 'main' into refactor-model
phoeenniixx May 18, 2025
30b541b
make the modules private
phoeenniixx May 20, 2025
3f1e11f
Merge remote-tracking branch 'origin/refactor-model' into refactor-model
phoeenniixx May 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
283 changes: 283 additions & 0 deletions pytorch_forecasting/models/base/base_model_refactor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,283 @@
########################################################################################
# Disclaimer: This baseclass is still work in progress and experimental, please
# use with care. This class is a basic skeleton of how the base classes may look like
# in the version-2.
########################################################################################


from typing import Dict, List, Optional, Tuple, Union

from lightning.pytorch import LightningModule
from lightning.pytorch.utilities.types import STEP_OUTPUT
import torch
import torch.nn as nn
from torch.optim import Optimizer


class BaseModel(LightningModule):
def __init__(
self,
loss: nn.Module,
logging_metrics: Optional[List[nn.Module]] = None,
optimizer: Optional[Union[Optimizer, str]] = "adam",
optimizer_params: Optional[Dict] = None,
lr_scheduler: Optional[str] = None,
lr_scheduler_params: Optional[Dict] = None,
):
"""
Base model for time series forecasting.

Parameters
----------
loss : nn.Module
Loss function to use for training.
logging_metrics : Optional[List[nn.Module]], optional
List of metrics to log during training, validation, and testing.
optimizer : Optional[Union[Optimizer, str]], optional
Optimizer to use for training.
Can be a string ("adam", "sgd") or an instance of `torch.optim.Optimizer`.
optimizer_params : Optional[Dict], optional
Parameters for the optimizer.
lr_scheduler : Optional[str], optional
Learning rate scheduler to use.
Supported values: "reduce_lr_on_plateau", "step_lr".
lr_scheduler_params : Optional[Dict], optional
Parameters for the learning rate scheduler.
"""
super().__init__()
self.loss = loss
self.logging_metrics = logging_metrics if logging_metrics is not None else []
self.optimizer = optimizer
self.optimizer_params = optimizer_params if optimizer_params is not None else {}
self.lr_scheduler = lr_scheduler
self.lr_scheduler_params = (
lr_scheduler_params if lr_scheduler_params is not None else {}
)

def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Forward pass of the model.

Parameters
----------
x : Dict[str, torch.Tensor]
Dictionary containing input tensors

Returns
-------
Dict[str, torch.Tensor]
Dictionary containing output tensors
"""
raise NotImplementedError("Forward method must be implemented by subclass.")

Check warning on line 71 in pytorch_forecasting/models/base/base_model_refactor.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/base_model_refactor.py#L71

Added line #L71 was not covered by tests

def training_step(
self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int
) -> STEP_OUTPUT:
"""
Training step for the model.

Parameters
----------
batch : Tuple[Dict[str, torch.Tensor]]
Batch of data containing input and target tensors.
batch_idx : int
Index of the batch.

Returns
-------
STEP_OUTPUT
Dictionary containing the loss and other metrics.
"""
x, y = batch
y_hat_dict = self(x)
y_hat = y_hat_dict["prediction"]
loss = self.loss(y_hat, y)
self.log(

Check warning on line 95 in pytorch_forecasting/models/base/base_model_refactor.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/base_model_refactor.py#L91-L95

Added lines #L91 - L95 were not covered by tests
"train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
)
self.log_metrics(y_hat, y, prefix="train")
return {"loss": loss}

Check warning on line 99 in pytorch_forecasting/models/base/base_model_refactor.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/base_model_refactor.py#L98-L99

Added lines #L98 - L99 were not covered by tests

def validation_step(
self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int
) -> STEP_OUTPUT:
"""
Validation step for the model.

Parameters
----------
batch : Tuple[Dict[str, torch.Tensor]]
Batch of data containing input and target tensors.
batch_idx : int
Index of the batch.

Returns
-------
STEP_OUTPUT
Dictionary containing the loss and other metrics.
"""
x, y = batch
y_hat_dict = self(x)
y_hat = y_hat_dict["prediction"]
loss = self.loss(y_hat, y)
self.log(

Check warning on line 123 in pytorch_forecasting/models/base/base_model_refactor.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/base_model_refactor.py#L119-L123

Added lines #L119 - L123 were not covered by tests
"val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
)
self.log_metrics(y_hat, y, prefix="val")
return {"val_loss": loss}

Check warning on line 127 in pytorch_forecasting/models/base/base_model_refactor.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/base_model_refactor.py#L126-L127

Added lines #L126 - L127 were not covered by tests

def test_step(
self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int
) -> STEP_OUTPUT:
"""
Test step for the model.

Parameters
----------
batch : Tuple[Dict[str, torch.Tensor]]
Batch of data containing input and target tensors.
batch_idx : int
Index of the batch.

Returns
-------
STEP_OUTPUT
Dictionary containing the loss and other metrics.
"""
x, y = batch
y_hat_dict = self(x)
y_hat = y_hat_dict["prediction"]
loss = self.loss(y_hat, y)
self.log(

Check warning on line 151 in pytorch_forecasting/models/base/base_model_refactor.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/base_model_refactor.py#L147-L151

Added lines #L147 - L151 were not covered by tests
"test_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
)
self.log_metrics(y_hat, y, prefix="test")
return {"test_loss": loss}

Check warning on line 155 in pytorch_forecasting/models/base/base_model_refactor.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/base_model_refactor.py#L154-L155

Added lines #L154 - L155 were not covered by tests

def predict_step(
self,
batch: Tuple[Dict[str, torch.Tensor]],
batch_idx: int,
dataloader_idx: int = 0,
) -> torch.Tensor:
"""
Prediction step for the model.

Parameters
----------
batch : Tuple[Dict[str, torch.Tensor]]
Batch of data containing input tensors.
batch_idx : int
Index of the batch.
dataloader_idx : int
Index of the dataloader.

Returns
-------
torch.Tensor
Predicted output tensor.
"""
x, _ = batch
y_hat = self(x)
return y_hat

Check warning on line 182 in pytorch_forecasting/models/base/base_model_refactor.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/base_model_refactor.py#L180-L182

Added lines #L180 - L182 were not covered by tests

def configure_optimizers(self) -> Dict:
"""
Configure the optimizer and learning rate scheduler.

Returns
-------
Dict
Dictionary containing the optimizer and scheduler configuration.
"""
optimizer = self._get_optimizer()
if self.lr_scheduler is not None:
scheduler = self._get_scheduler(optimizer)
if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
return {

Check warning on line 197 in pytorch_forecasting/models/base/base_model_refactor.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/base_model_refactor.py#L193-L197

Added lines #L193 - L197 were not covered by tests
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"monitor": "val_loss",
},
}
else:
return {"optimizer": optimizer, "lr_scheduler": scheduler}
return {"optimizer": optimizer}

Check warning on line 206 in pytorch_forecasting/models/base/base_model_refactor.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/base_model_refactor.py#L205-L206

Added lines #L205 - L206 were not covered by tests

def _get_optimizer(self) -> Optimizer:
"""
Get the optimizer based on the specified optimizer name and parameters.

Returns
-------
Optimizer
The optimizer instance.
"""
if isinstance(self.optimizer, str):
if self.optimizer.lower() == "adam":
return torch.optim.Adam(self.parameters(), **self.optimizer_params)
elif self.optimizer.lower() == "sgd":
return torch.optim.SGD(self.parameters(), **self.optimizer_params)

Check warning on line 221 in pytorch_forecasting/models/base/base_model_refactor.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/base_model_refactor.py#L217-L221

Added lines #L217 - L221 were not covered by tests
else:
raise ValueError(f"Optimizer {self.optimizer} not supported.")
elif isinstance(self.optimizer, Optimizer):
return self.optimizer

Check warning on line 225 in pytorch_forecasting/models/base/base_model_refactor.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/base_model_refactor.py#L223-L225

Added lines #L223 - L225 were not covered by tests
else:
raise ValueError(

Check warning on line 227 in pytorch_forecasting/models/base/base_model_refactor.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/base_model_refactor.py#L227

Added line #L227 was not covered by tests
"Optimizer must be either a string or "
"an instance of torch.optim.Optimizer."
)

def _get_scheduler(
self, optimizer: Optimizer
) -> torch.optim.lr_scheduler._LRScheduler:
"""
Get the lr scheduler based on the specified scheduler name and params.

Parameters
----------
optimizer : Optimizer
The optimizer instance.

Returns
-------
torch.optim.lr_scheduler._LRScheduler
The learning rate scheduler instance.
"""
if self.lr_scheduler.lower() == "reduce_lr_on_plateau":
return torch.optim.lr_scheduler.ReduceLROnPlateau(

Check warning on line 249 in pytorch_forecasting/models/base/base_model_refactor.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/base_model_refactor.py#L248-L249

Added lines #L248 - L249 were not covered by tests
optimizer, **self.lr_scheduler_params
)
elif self.lr_scheduler.lower() == "step_lr":
return torch.optim.lr_scheduler.StepLR(

Check warning on line 253 in pytorch_forecasting/models/base/base_model_refactor.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/base_model_refactor.py#L252-L253

Added lines #L252 - L253 were not covered by tests
optimizer, **self.lr_scheduler_params
)
else:
raise ValueError(f"Scheduler {self.lr_scheduler} not supported.")

Check warning on line 257 in pytorch_forecasting/models/base/base_model_refactor.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/base_model_refactor.py#L257

Added line #L257 was not covered by tests

def log_metrics(
self, y_hat: torch.Tensor, y: torch.Tensor, prefix: str = "val"
) -> None:
"""
Log additional metrics during training, validation, or testing.

Parameters
----------
y_hat : torch.Tensor
Predicted output tensor.
y : torch.Tensor
Target output tensor.
prefix : str
Prefix for the logged metrics (e.g., "train", "val", "test").
"""
for metric in self.logging_metrics:
metric_value = metric(y_hat, y)
self.log(

Check warning on line 276 in pytorch_forecasting/models/base/base_model_refactor.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/base_model_refactor.py#L274-L276

Added lines #L274 - L276 were not covered by tests
f"{prefix}_{metric.__class__.__name__}",
metric_value,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
)
Loading
Loading