Skip to content

Commit 3beeb0f

Browse files
committed
add random,pseudorandom, and endpoint dist options
1 parent ad154f7 commit 3beeb0f

File tree

1 file changed

+59
-12
lines changed

1 file changed

+59
-12
lines changed

PySDM/initialisation/sampling/spectral_sampling.py

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -99,30 +99,77 @@ class AlphaSampling(
9999
): # pylint: disable=too-few-public-methods
100100
"""as in [Matsushima et al. 2023](https://doi.org/10.5194/gmd-16-6211-2023)"""
101101

102-
def __init__(self, spectrum, alpha, size_range=None):
102+
def __init__(self, spectrum, alpha, size_range=None,dist_0=None,dist_1=None):
103103
super().__init__(spectrum, size_range)
104104
self.alpha = alpha
105+
if dist_0 is None:
106+
dist_0 = self.spectrum
107+
if dist_1 is None:
108+
def dist_1_inv(y):
109+
return (self.size_range[1] - self.size_range[0]) * y
110+
else:
111+
dist_1_inv = dist_1.percentiles
112+
self.dist_0_cdf = dist_0.cdf
113+
self.dist_1_inv = dist_1_inv
105114

106-
def sample(self, n_sd, *, backend=None): # pylint: disable=unused-argument
107-
x_prime = np.linspace(
108-
self.size_range[0], self.size_range[1], num=n_sd + 1
109-
) # maybe doesnt need to be so many, just for interpolation
110-
sd_cdf = self.spectrum.cdf(x_prime)
111-
112-
def Fb2_inv(y):
113-
return (x_prime[-1] - x_prime[0]) * y
115+
def sample(self, n_sd, *, backend=None,xprime=None): # pylint: disable=unused-argument
116+
if xprime is None:
117+
x_prime = np.linspace(
118+
self.size_range[0], self.size_range[1], num=2*n_sd
119+
) # maybe doesnt need to be so many, just for interpolation
120+
sd_cdf = self.dist_0_cdf(x_prime)
114121

115-
x_sd_cdf = (1 - self.alpha) * x_prime + self.alpha * Fb2_inv(sd_cdf)
122+
x_sd_cdf = (1 - self.alpha) * x_prime + self.alpha * self.dist_1_inv(sd_cdf)
116123

117124
inv_cdf = interp1d(sd_cdf, x_sd_cdf)
118125

126+
percent_values = self._find_percentiles(n_sd, backend)
127+
percentiles = inv_cdf(percent_values)
128+
129+
return self._sample(percentiles, self.spectrum)
130+
131+
def _find_percentiles(self, n_sd, backend):
119132
percent_values = np.linspace(
120133
default_cdf_range[0], default_cdf_range[1], num=2 * n_sd + 1
121134
)
122-
percentiles = inv_cdf(percent_values)
135+
return percent_values
136+
137+
class AlphaSamplingPseudoRandom(
138+
AlphaSampling
139+
): # pylint: disable=too-few-public-methods
140+
"""Alpha sampling with pseudo-random values within deterministic percentile bins"""
141+
142+
def _find_percentiles(self, n_sd, backend):
143+
num_elements = n_sd
144+
storage = backend.Storage.empty(num_elements, dtype=float)
145+
backend.Random(seed=backend.formulae.seed, size=num_elements)(storage)
146+
u01 = storage.to_ndarray()
123147

124-
return self._sample(percentiles, self.spectrum)
148+
percent_values = np.linspace(
149+
default_cdf_range[0], default_cdf_range[1], num=2 * n_sd + 1
150+
)
151+
152+
for i in range(1, len(percent_values) - 1, 2):
153+
percent_values[i] = (
154+
percent_values[i - 1] + u01[i // 2] * (percent_values[i + 1] - percent_values[i - 1])
155+
)
156+
157+
return percent_values
158+
159+
class AlphaSamplingRandom(
160+
AlphaSampling
161+
): # pylint: disable=too-few-public-methods
162+
"""Alpha sampling with uniform random percentile bins"""
163+
164+
def _find_percentiles(self, n_sd, backend):
165+
num_elements = 2 * n_sd + 1
166+
storage = backend.Storage.empty(num_elements, dtype=float)
167+
backend.Random(seed=backend.formulae.seed, size=num_elements)(storage)
168+
u01 = storage.to_ndarray()
125169

170+
percent_values = np.sort(default_cdf_range[0] + u01 * (default_cdf_range[1] - default_cdf_range[0]))
171+
return percent_values
172+
126173

127174
class ConstantMultiplicity(AlphaSampling): # pylint: disable=too-few-public-methods
128175
def __init__(self, spectrum, size_range=None):

0 commit comments

Comments
 (0)