@@ -168,63 +168,31 @@ end
168168
169169function solve_batch (prob,alg,ensemblealg:: EnsembleThreads ,II,pmap_batch_size;kwargs... )
170170
171+ if length (II) == 1 || Threads. nthreads () == 1
172+ return solve_batch (prob,alg,EnsembleSerial (),II,pmap_batch_size;kwargs... )
173+ end
174+
171175 if typeof (prob. prob) <: AbstractJumpProblem && length (II) != 1
172176 probs = [deepcopy (prob. prob) for i in 1 : Threads. nthreads ()]
173177 else
174178 probs = prob. prob
175179 end
176180
177- function multithreaded_batch (batch_idx)
178- i = II[batch_idx]
179- iter = 1
180- _prob = if prob. safetycopy
181- probs isa Vector ? deepcopy (probs[Threads. threadid ()]) : probs
182- else
183- probs isa Vector ? probs[Threads. threadid ()] : probs
184- end
185- new_prob = prob. prob_func (_prob,i,iter)
186- x = prob. output_func (solve (new_prob,alg;kwargs... ),i)
187- if ! (typeof (x) <: Tuple )
188- @warn (" output_func should return (out,rerun). See docs for updated details" )
189- _x = (x,false )
190- else
191- _x = x
192- end
193- rerun = _x[2 ]
194-
195- while rerun
196- iter += 1
197- _prob = if prob. safetycopy
198- probs isa Vector ? deepcopy (probs[Threads. threadid ()]) : probs
199- else
200- probs isa Vector ? probs[Threads. threadid ()] : probs
201- end
202- new_prob = prob. prob_func (_prob,i,iter)
203- x = prob. output_func (solve (new_prob,alg;alias_jumps= true ,kwargs... ),i)
204- if ! (typeof (x) <: Tuple )
205- @warn (" output_func should return (out,rerun). See docs for updated details" )
206- _x = (x,false )
207- else
208- _x = x
209- end
210- rerun = _x[2 ]
211- end
212- _x[1 ]
213- end
214-
215181 # batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
216- batch_data = Vector {Any} (undef,length (II))
182+
183+ batch_data = Vector {Any} (undef,Threads. nthreads ())
184+ batch_size = length (II)÷ Threads. nthreads ()
217185
218186 let
219- if length (II) == 1 || Threads. nthreads () == 1
220- for batch_idx in axes (batch_data, 1 )
221- batch_data[batch_idx] = multithreaded_batch (batch_idx)
222- end
223- else
224- Threads. @threads for batch_idx in axes (batch_data, 1 )
225- batch_data[batch_idx] = multithreaded_batch (batch_idx)
226- end
187+ Threads. @threads for i in 1 : Threads. nthreads ()
188+ if i == Threads. nthreads ()
189+ I_local = II[(batch_size* (i- 1 )+ 1 ): end ]
190+ else
191+ I_local = II[(batch_size* (i- 1 )+ 1 ): (batch_size* i)]
192+ end
193+ batch_data[i] = solve_batch (prob,alg,EnsembleSerial (),I_local,pmap_batch_size;kwargs... )
227194 end
195+ batch_data = reduce (vcat,batch_data)
228196 end
229197 tighten_container_eltype (batch_data)
230198end
@@ -240,71 +208,8 @@ function solve_batch(prob,alg,::EnsembleSplitThreads,II,pmap_batch_size;kwargs..
240208 else
241209 I_local = II[(batch_size* (i- 1 )+ 1 ): (batch_size* i)]
242210 end
243- thread_monte (prob,I_local, alg,i ;kwargs... )
211+ solve_batch (prob,alg,EnsembleThreads (),I_local,pmap_batch_size ;kwargs... )
244212 end
245213 end
246214 reduce (vcat,batch_data)
247215end
248-
249- function thread_monte (prob,II,alg,procid;kwargs... )
250-
251- if typeof (prob. prob) <: AbstractJumpProblem && length (II) != 1
252- probs = [deepcopy (prob. prob) for i in 1 : Threads. nthreads ()]
253- else
254- probs = prob. prob
255- end
256-
257- function multithreaded_batch (j)
258- i = II[j]
259- iter = 1
260- _prob = if prob. safetycopy
261- probs isa Vector ? deepcopy (probs[Threads. threadid ()]) : probs
262- else
263- probs isa Vector ? probs[Threads. threadid ()] : probs
264- end
265- new_prob = prob. prob_func (_prob,i,iter)
266- rerun = true
267- x = prob. output_func (solve (new_prob,alg;alias_jumps= true ,kwargs... ),i)
268- if ! (typeof (x) <: Tuple )
269- @warn (" output_func should return (out,rerun). See docs for updated details" )
270- _x = (x,false )
271- else
272- _x = x
273- end
274- rerun = _x[2 ]
275- while rerun
276- iter += 1
277- _prob = if prob. safetycopy
278- probs isa Vector ? deepcopy (probs[Threads. threadid ()]) : probs
279- else
280- probs isa Vector ? probs[Threads. threadid ()] : probs
281- end
282- new_prob = prob. prob_func (_prob,i,iter)
283- x = prob. output_func (solve (new_prob,alg;alias_jumps= true ,kwargs... ),i)
284- if ! (typeof (x) <: Tuple )
285- @warn (" output_func should return (out,rerun). See docs for updated details" )
286- _x = (x,false )
287- else
288- _x = x
289- end
290- rerun = _x[2 ]
291- end
292- _x[1 ]
293- end
294-
295- # batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
296- batch_data = Vector {Any} (undef,length (II))
297-
298- let
299- if length (II) == 1 || Threads. nthreads () == 1
300- for batch_idx in axes (batch_data, 1 )
301- batch_data[batch_idx] = multithreaded_batch (batch_idx)
302- end
303- else
304- Threads. @threads for batch_idx in axes (batch_data, 1 )
305- batch_data[batch_idx] = multithreaded_batch (batch_idx)
306- end
307- end
308- end
309- tighten_container_eltype (batch_data)
310- end
0 commit comments