Skip to content

Commit b8dbacc

Browse files
authored
[BUG] Fix issue with plot_prediction_actual_by_variable unsupported operand type(s) for *: 'numpy.ndarray' and 'Tensor' (#1903)
fixes #1822 This PR fixes the issue with `plot_prediction_actual_by_variable` where the input `np.ndarray` was being multiplied to `torch.Tensor` in `TorchNormalizer` `__call__`
1 parent e7a9c5f commit b8dbacc

File tree

2 files changed

+41
-25
lines changed

2 files changed

+41
-25
lines changed

pytorch_forecasting/data/encoders.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -757,29 +757,33 @@ def transform(
757757
else:
758758
return y
759759

760-
def inverse_transform(self, y: torch.Tensor) -> torch.Tensor:
760+
def inverse_transform(self, y: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
761761
"""
762762
Inverse scale.
763763
764764
Parameters
765765
----------
766-
y: torch.Tensor
766+
y: Union[torch.Tensor, np.ndarray])
767767
scaled data
768768
769769
Returns
770770
-------
771771
torch.Tensor
772772
de-scaled data
773773
"""
774+
if isinstance(y, np.ndarray):
775+
y = torch.from_numpy(y)
774776
return self(dict(prediction=y, target_scale=self.get_parameters().unsqueeze(0)))
775777

776-
def __call__(self, data: dict[str, torch.Tensor]) -> torch.Tensor:
778+
def __call__(
779+
self, data: dict[str, Union[torch.Tensor, np.ndarray]]
780+
) -> torch.Tensor:
777781
"""
778782
Inverse transformation but with network output as input.
779783
780784
Parameters
781785
----------
782-
data: Dict[str, torch.Tensor]
786+
data: dict[str, Union[torch.Tensor, np.ndarray]]
783787
Dictionary with entries
784788
785789
* prediction: data to de-scale
@@ -795,24 +799,26 @@ def __call__(self, data: dict[str, torch.Tensor]) -> torch.Tensor:
795799
if isinstance(dtype, np.dtype):
796800
# convert the array into tensor if it is a numpy array
797801
data["prediction"] = torch.as_tensor(data["prediction"])
798-
dtype = data["prediction"].dtype
802+
803+
prediction = data["prediction"]
799804

800805
# inverse transformation with tensors
801806
norm = data["target_scale"]
802-
807+
if isinstance(norm, np.ndarray):
808+
norm = torch.from_numpy(norm)
803809
# use correct shape for norm
804-
if data["prediction"].ndim > norm.ndim:
810+
if prediction.ndim > norm.ndim:
805811
norm = norm.unsqueeze(-1)
806812

807813
# transform
808-
y = data["prediction"] * norm[:, 1, None] + norm[:, 0, None]
814+
y = prediction * norm[:, 1, None] + norm[:, 0, None]
809815

810816
y = self.inverse_preprocess(y)
811817

812818
# return correct shape
813-
if data["prediction"].ndim == 1 and y.ndim > 1:
819+
if prediction.ndim == 1 and y.ndim > 1:
814820
y = y.squeeze(0)
815-
return y.type(dtype)
821+
return y.type(prediction.dtype)
816822

817823

818824
class EncoderNormalizer(TorchNormalizer):

tests/test_data/test_encoders.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -62,41 +62,51 @@ def test_NaNLabelEncoder_add():
6262
[
6363
dict(method="robust"),
6464
dict(method="robust", method_kwargs=dict(upper=1.0, lower=0.0)),
65-
dict(method="robust", data=np.random.randn(100)),
66-
dict(data=np.random.randn(100)),
65+
dict(method="robust"),
6766
dict(transformation="log"),
6867
dict(transformation="softplus"),
6968
dict(transformation="log1p"),
7069
dict(transformation="relu"),
7170
dict(method="identity"),
72-
dict(method="identity", data=np.random.randn(100)),
71+
dict(
72+
method="identity",
73+
),
7374
dict(center=False),
7475
dict(max_length=5),
75-
dict(data=pd.Series(np.random.randn(100))),
7676
dict(max_length=[1, 2]),
7777
],
7878
)
79-
def test_EncoderNormalizer(kwargs):
79+
@pytest.mark.parametrize("data_type", ["torch", "numpy", "pandas"])
80+
def test_EncoderNormalizer(kwargs, data_type):
81+
transformation = kwargs.get("transformation")
82+
83+
if transformation in ["log", "log1p", "softplus", "relu"]:
84+
base_data = np.random.uniform(0.1, 10, size=100) # strictly positive
85+
else:
86+
base_data = np.random.randn(100)
87+
88+
if data_type == "torch":
89+
data = torch.tensor(base_data, dtype=torch.float32)
90+
elif data_type == "numpy":
91+
data = base_data.astype(np.float32)
92+
elif data_type == "pandas":
93+
data = pd.Series(base_data.astype(np.float32))
8094
kwargs.setdefault("method", "standard")
8195
kwargs.setdefault("center", True)
82-
kwargs.setdefault("data", torch.rand(100))
83-
data = kwargs.pop("data")
8496

8597
normalizer = EncoderNormalizer(**kwargs)
98+
transformed = normalizer.fit_transform(data)
99+
inverse = normalizer.inverse_transform(torch.as_tensor(transformed))
86100

87101
if kwargs.get("transformation") in ["relu", "softplus", "log1p"]:
88102
assert (
89-
normalizer.inverse_transform(
90-
torch.as_tensor(normalizer.fit_transform(data))
91-
)
92-
>= 0
103+
inverse >= 0
93104
).all(), "Inverse transform should yield only positive values"
94105
else:
106+
expected = torch.as_tensor(data)
95107
assert torch.isclose(
96-
normalizer.inverse_transform(
97-
torch.as_tensor(normalizer.fit_transform(data))
98-
),
99-
torch.as_tensor(data),
108+
inverse,
109+
expected,
100110
atol=1e-5,
101111
).all(), "Inverse transform should reverse transform"
102112

0 commit comments

Comments
 (0)