@@ -708,47 +708,82 @@ def transform(
708708 target_scale = self .get_parameters ().numpy ()[None , :]
709709 center = target_scale [..., 0 ]
710710 scale = target_scale [..., 1 ]
711+
712+ if not isinstance (y , torch .Tensor ):
713+ if isinstance (y , (pd .Series )):
714+ index = y .index
715+ pandas_dtype = y .dtype
716+ y = y .values
717+ y_was = "pandas"
718+ y = torch .as_tensor (y )
719+ elif isinstance (y , np .ndarray ):
720+ y_was = "numpy"
721+ np_dtype = y .dtype
722+ try :
723+ y = torch .from_numpy (y )
724+ except TypeError :
725+ y = torch .as_tensor (y .astype (np .float32 ))
726+ else :
727+ y_was = "torch"
728+ torch_dtype = y .dtype
729+ if isinstance (center , np .ndarray ):
730+ center = torch .from_numpy (center )
731+ if isinstance (scale , np .ndarray ):
732+ scale = torch .from_numpy (scale )
711733 if y .ndim > center .ndim : # multiple batches -> expand size
712734 center = center .view (* center .size (), * (1 ,) * (y .ndim - center .ndim ))
713735 scale = scale .view (* scale .size (), * (1 ,) * (y .ndim - scale .ndim ))
714736
715- # transform
716- dtype = y .dtype
717737 y = (y - center ) / scale
718- try :
719- y = y .astype (dtype )
720- except AttributeError : # torch.Tensor has `.type()` instead of `.astype()`
721- y = y .type (dtype )
738+
739+ if y_was == "numpy" :
740+ numpy_data = y .numpy ()
741+ if np_dtype .kind in "iu" and numpy_data .dtype .kind == "f" :
742+ # Original was integer, but normalized data is float
743+ y = numpy_data .astype (np .float64 )
744+ else :
745+ y = numpy_data .astype (np_dtype )
746+ elif y_was == "pandas" :
747+ numpy_data = y .numpy ()
748+ if pandas_dtype .kind in "iu" and numpy_data .dtype .kind == "f" :
749+ pandas_dtype = np .float64
750+ y = pd .Series (numpy_data , index = index , dtype = pandas_dtype )
751+ else :
752+ y = y .type (torch_dtype )
722753
723754 # return with center and scale or without
724755 if return_norm :
725756 return y , target_scale
726757 else :
727758 return y
728759
729- def inverse_transform (self , y : torch .Tensor ) -> torch .Tensor :
760+ def inverse_transform (self , y : Union [ torch .Tensor , np . ndarray ] ) -> torch .Tensor :
730761 """
731762 Inverse scale.
732763
733764 Parameters
734765 ----------
735- y: torch.Tensor
766+ y: Union[ torch.Tensor, np.ndarray])
736767 scaled data
737768
738769 Returns
739770 -------
740771 torch.Tensor
741772 de-scaled data
742773 """
774+ if isinstance (y , np .ndarray ):
775+ y = torch .from_numpy (y )
743776 return self (dict (prediction = y , target_scale = self .get_parameters ().unsqueeze (0 )))
744777
745- def __call__ (self , data : dict [str , torch .Tensor ]) -> torch .Tensor :
778+ def __call__ (
779+ self , data : dict [str , Union [torch .Tensor , np .ndarray ]]
780+ ) -> torch .Tensor :
746781 """
747782 Inverse transformation but with network output as input.
748783
749784 Parameters
750785 ----------
751- data: Dict [str, torch.Tensor]
786+ data: dict [str, Union[ torch.Tensor, np.ndarray] ]
752787 Dictionary with entries
753788
754789 * prediction: data to de-scale
@@ -761,23 +796,29 @@ def __call__(self, data: dict[str, torch.Tensor]) -> torch.Tensor:
761796 """
762797 # ensure output dtype matches input dtype
763798 dtype = data ["prediction" ].dtype
799+ if isinstance (dtype , np .dtype ):
800+ # convert the array into tensor if it is a numpy array
801+ data ["prediction" ] = torch .as_tensor (data ["prediction" ])
802+
803+ prediction = data ["prediction" ]
764804
765805 # inverse transformation with tensors
766806 norm = data ["target_scale" ]
767-
807+ if isinstance (norm , np .ndarray ):
808+ norm = torch .from_numpy (norm )
768809 # use correct shape for norm
769- if data [ " prediction" ] .ndim > norm .ndim :
810+ if prediction .ndim > norm .ndim :
770811 norm = norm .unsqueeze (- 1 )
771812
772813 # transform
773- y = data [ " prediction" ] * norm [:, 1 , None ] + norm [:, 0 , None ]
814+ y = prediction * norm [:, 1 , None ] + norm [:, 0 , None ]
774815
775816 y = self .inverse_preprocess (y )
776817
777818 # return correct shape
778- if data [ " prediction" ] .ndim == 1 and y .ndim > 1 :
819+ if prediction .ndim == 1 and y .ndim > 1 :
779820 y = y .squeeze (0 )
780- return y .type (dtype )
821+ return y .type (prediction . dtype )
781822
782823
783824class EncoderNormalizer (TorchNormalizer ):
0 commit comments