Skip to content

Commit 35df16f

Browse files
committed
chore(tests): standardize mixture and parameter tests got dtype
1 parent e8bed63 commit 35df16f

File tree

4 files changed

+152
-199
lines changed

4 files changed

+152
-199
lines changed

rework_pysatl_mpest/core/mixture.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def _validate_weights(self, n_components: int, weights: NDArray[DType]):
134134
@property
135135
def dtype(self) -> type[DType]:
136136
"""type[DType]: The numpy data type of the mixture's outputs."""
137+
137138
return self._dtype
138139

139140
@property
@@ -396,7 +397,8 @@ def _get_sorted_pairs(self, for_hashing: bool = False) -> list[tuple["Continuous
396397
if self._sorted_pairs_cache is None or for_hashing:
397398
weights_to_use = self.weights
398399
if for_hashing:
399-
weights_to_use = np.round(weights_to_use, 8)
400+
decimals = np.finfo(self.dtype).precision
401+
weights_to_use = np.round(weights_to_use, decimals)
400402

401403
pairs = sorted(zip(self.components, weights_to_use), key=lambda p: hash(p[0]))
402404
if not for_hashing:

rework_pysatl_mpest/distributions/uniform.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def pdf(self, X):
9797
9898
Returns
9999
-------
100-
NDArray[np.float64]
100+
NDArray[DType]
101101
The PDF values corresponding to each point in :attr:`X`.
102102
"""
103103
X = np.asarray(X, dtype=self.dtype)
@@ -128,7 +128,7 @@ def ppf(self, P):
128128
129129
Returns
130130
-------
131-
NDArray[np.float64]
131+
NDArray[DType]
132132
The PPF values corresponding to each probability in :attr:`P`.
133133
"""
134134
P = np.asarray(P, dtype=self.dtype)
@@ -158,7 +158,7 @@ def lpdf(self, X):
158158
159159
Returns
160160
-------
161-
NDArray[np.float64]
161+
NDArray[DType]
162162
The log-PDF values corresponding to each point in :attr:`X`.
163163
"""
164164
X = np.asarray(X, dtype=self.dtype)
@@ -214,7 +214,7 @@ def log_gradients(self, X):
214214
215215
Returns
216216
-------
217-
NDArray[np.float64]
217+
NDArray[DType]
218218
An array where each row corresponds to a data point in :attr:`X`
219219
and each column corresponds to the gradient with respect to a
220220
specific optimizable parameter. The order of columns corresponds
@@ -247,7 +247,7 @@ def generate(self, size: int):
247247
248248
Returns
249249
-------
250-
NDArray[np.float64]
250+
NDArray[DType]
251251
A NumPy array containing the generated samples.
252252
"""
253253

0 commit comments

Comments
 (0)