168168
169169function solve_batch (prob,alg,ensemblealg:: EnsembleThreads ,II,pmap_batch_size;kwargs... )
170170
171+ if length (II) == 1 || Threads. nthreads () == 1
172+ return solve_batch (prob,alg,EnsembleSerial (),II,pmap_batch_size;kwargs... )
173+ end
174+
171175 if typeof (prob. prob) <: AbstractJumpProblem && length (II) != 1
172176 probs = [deepcopy (prob. prob) for i in 1 : Threads. nthreads ()]
173177 else
@@ -176,26 +180,19 @@ function solve_batch(prob,alg,ensemblealg::EnsembleThreads,II,pmap_batch_size;kw
176180
177181 # batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
178182
179- local batch_data
183+ batch_data = Vector {Any} (undef,Threads. nthreads ())
184+ batch_size = length (II)÷ Threads. nthreads ()
185+
180186 let
181- if length (II) == 1 || Threads. nthreads () == 1
182- batch_data = Vector {Any} (undef,length (II))
183- for batch_idx in axes (batch_data, 1 )
184- batch_data[batch_idx] = multithreaded_batch (batch_idx,probs,alg,II)
185- end
186- else
187- batch_data = Vector {Any} (undef,Threads. nthreads ())
188- batch_size = length (II)÷ Threads. nthreads ()
189- Threads. @threads for i in 1 : Threads. nthreads ()
190- if i == Threads. nthreads ()
191- I_local = II[(batch_size* (i- 1 )+ 1 ): end ]
192- else
193- I_local = II[(batch_size* (i- 1 )+ 1 ): (batch_size* i)]
194- end
195- batch_data[i] = solve_batch (prob,alg,EnsembleSerial (),I_local,pmap_batch_size;kwargs... )
196- end
197- batch_data = reduce (vcat,batch_data)
187+ Threads. @threads for i in 1 : Threads. nthreads ()
188+ if i == Threads. nthreads ()
189+ I_local = II[(batch_size* (i- 1 )+ 1 ): end ]
190+ else
191+ I_local = II[(batch_size* (i- 1 )+ 1 ): (batch_size* i)]
192+ end
193+ batch_data[i] = solve_batch (prob,alg,EnsembleSerial (),I_local,pmap_batch_size;kwargs... )
198194 end
195+ batch_data = reduce (vcat,batch_data)
199196 end
200197 tighten_container_eltype (batch_data)
201198end
0 commit comments