@@ -99,30 +99,77 @@ class AlphaSampling(
9999): # pylint: disable=too-few-public-methods
100100 """as in [Matsushima et al. 2023](https://doi.org/10.5194/gmd-16-6211-2023)"""
101101
102- def __init__ (self , spectrum , alpha , size_range = None ):
102+ def __init__ (self , spectrum , alpha , size_range = None , dist_0 = None , dist_1 = None ):
103103 super ().__init__ (spectrum , size_range )
104104 self .alpha = alpha
105+ if dist_0 is None :
106+ dist_0 = self .spectrum
107+ if dist_1 is None :
108+ def dist_1_inv (y ):
109+ return (self .size_range [1 ] - self .size_range [0 ]) * y
110+ else :
111+ dist_1_inv = dist_1 .percentiles
112+ self .dist_0_cdf = dist_0 .cdf
113+ self .dist_1_inv = dist_1_inv
105114
106- def sample (self , n_sd , * , backend = None ): # pylint: disable=unused-argument
107- x_prime = np .linspace (
108- self .size_range [0 ], self .size_range [1 ], num = n_sd + 1
109- ) # maybe doesnt need to be so many, just for interpolation
110- sd_cdf = self .spectrum .cdf (x_prime )
111-
112- def Fb2_inv (y ):
113- return (x_prime [- 1 ] - x_prime [0 ]) * y
115+ def sample (self , n_sd , * , backend = None ,xprime = None ): # pylint: disable=unused-argument
116+ if xprime is None :
117+ x_prime = np .linspace (
118+ self .size_range [0 ], self .size_range [1 ], num = 2 * n_sd
119+ ) # maybe doesnt need to be so many, just for interpolation
120+ sd_cdf = self .dist_0_cdf (x_prime )
114121
115- x_sd_cdf = (1 - self .alpha ) * x_prime + self .alpha * Fb2_inv (sd_cdf )
122+ x_sd_cdf = (1 - self .alpha ) * x_prime + self .alpha * self . dist_1_inv (sd_cdf )
116123
117124 inv_cdf = interp1d (sd_cdf , x_sd_cdf )
118125
126+ percent_values = self ._find_percentiles (n_sd , backend )
127+ percentiles = inv_cdf (percent_values )
128+
129+ return self ._sample (percentiles , self .spectrum )
130+
131+ def _find_percentiles (self , n_sd , backend ):
119132 percent_values = np .linspace (
120133 default_cdf_range [0 ], default_cdf_range [1 ], num = 2 * n_sd + 1
121134 )
122- percentiles = inv_cdf (percent_values )
135+ return percent_values
136+
137+ class AlphaSamplingPseudoRandom (
138+ AlphaSampling
139+ ): # pylint: disable=too-few-public-methods
140+ """Alpha sampling with pseudo-random values within deterministic percentile bins"""
141+
142+ def _find_percentiles (self , n_sd , backend ):
143+ num_elements = n_sd
144+ storage = backend .Storage .empty (num_elements , dtype = float )
145+ backend .Random (seed = backend .formulae .seed , size = num_elements )(storage )
146+ u01 = storage .to_ndarray ()
123147
124- return self ._sample (percentiles , self .spectrum )
148+ percent_values = np .linspace (
149+ default_cdf_range [0 ], default_cdf_range [1 ], num = 2 * n_sd + 1
150+ )
151+
152+ for i in range (1 , len (percent_values ) - 1 , 2 ):
153+ percent_values [i ] = (
154+ percent_values [i - 1 ] + u01 [i // 2 ] * (percent_values [i + 1 ] - percent_values [i - 1 ])
155+ )
156+
157+ return percent_values
158+
159+ class AlphaSamplingRandom (
160+ AlphaSampling
161+ ): # pylint: disable=too-few-public-methods
162+ """Alpha sampling with uniform random percentile bins"""
163+
164+ def _find_percentiles (self , n_sd , backend ):
165+ num_elements = 2 * n_sd + 1
166+ storage = backend .Storage .empty (num_elements , dtype = float )
167+ backend .Random (seed = backend .formulae .seed , size = num_elements )(storage )
168+ u01 = storage .to_ndarray ()
125169
170+ percent_values = np .sort (default_cdf_range [0 ] + u01 * (default_cdf_range [1 ] - default_cdf_range [0 ]))
171+ return percent_values
172+
126173
127174class ConstantMultiplicity (AlphaSampling ): # pylint: disable=too-few-public-methods
128175 def __init__ (self , spectrum , size_range = None ):
0 commit comments