@@ -65,8 +65,8 @@ class NNPE(ElementwiseTransform):
6565 def __init__ (
6666 self ,
6767 * ,
68- spike_scale : float | np .ndarray | None = None ,
69- slab_scale : float | np .ndarray | None = None ,
68+ spike_scale : np .typing . ArrayLike | None = None ,
69+ slab_scale : np .typing . ArrayLike | None = None ,
7070 per_dimension : bool = True ,
7171 seed : int | None = None ,
7272 ):
@@ -80,14 +80,14 @@ def __init__(
8080 def _resolve_scale (
8181 self ,
8282 name : str ,
83- passed : float | np .ndarray | None ,
83+ passed : np .typing . ArrayLike | None ,
8484 default : float ,
8585 data : np .ndarray ,
8686 ) -> np .ndarray | float :
8787 """
8888 Determine spike/slab scale:
89- - If passed is None: Automatic determination via default * std(data) (per‐dimension or global).
90- - Else: validate & cast passed to the correct shape/type.
89+ - If ` passed` is None: Automatic determination via default * std(data) (per‐dimension or global).
90+ - Else: Validate & cast ` passed` to the correct shape/type.
9191
9292 Parameters
9393 ----------
@@ -103,8 +103,8 @@ def _resolve_scale(
103103
104104 Returns
105105 -------
106- float or np.ndarray
107- The resolved scale, either as a scalar (if per_dimension=False) or an 1D array of length data.shape[-1]
106+ np.ndarray
107+ The resolved scale, either as a 0D array (if per_dimension=False) or an 1D array of length data.shape[-1]
108108 (if per_dimension=True).
109109 """
110110
@@ -119,22 +119,22 @@ def _resolve_scale(
119119
120120 # If no scale is passed, determine scale automatically given the dimensionwise or global std
121121 if passed is None :
122- return default * std
122+ return np . array ( default * std )
123123 # If a scale is passed, check if the passed shape matches the expected shape
124124 else :
125- if self . per_dimension :
125+ try :
126126 arr = np .asarray (passed , dtype = float )
127- if arr .shape != expected_shape or arr .ndim != 1 :
127+ except Exception as e :
128+ raise TypeError (f"{ name } : expected values convertible to float, got { type (passed ).__name__ } " ) from e
129+
130+ if self .per_dimension :
131+ if arr .ndim != 1 or arr .shape != expected_shape :
128132 raise ValueError (f"{ name } : expected array of shape { expected_shape } , got { arr .shape } " )
129133 return arr
130134 else :
131- try :
132- scalar = float (passed )
133- except TypeError :
134- raise TypeError (f"{ name } : expected a scalar convertible to float, got type { type (passed ).__name__ } " )
135- except ValueError :
136- raise ValueError (f"{ name } : expected a scalar convertible to float, got value { passed !r} " )
137- return scalar
135+ if arr .ndim != 0 :
136+ raise ValueError (f"{ name } : expected scalar, got array of shape { arr .shape } " )
137+ return arr
138138
139139 def forward (self , data : np .ndarray , stage : str = "inference" , ** kwargs ) -> np .ndarray :
140140 """
@@ -173,7 +173,7 @@ def forward(self, data: np.ndarray, stage: str = "inference", **kwargs) -> np.nd
173173 return data + noise
174174
175175 def inverse (self , data : np .ndarray , ** kwargs ) -> np .ndarray :
176- """ Non-invertible transform."""
176+ # Non-invertible transform
177177 return data
178178
179179 def get_config (self ) -> dict :
0 commit comments