@@ -174,27 +174,56 @@ function solve_batch(prob,alg,ensemblealg::EnsembleThreads,II,pmap_batch_size;kw
174174 probs = prob. prob
175175 end
176176
177+ function multithreaded_batch (batch_idx)
178+ i = II[batch_idx]
179+ iter = 1
180+ _prob = if prob. safetycopy
181+ probs isa Vector ? deepcopy (probs[Threads. threadid ()]) : probs
182+ else
183+ probs isa Vector ? probs[Threads. threadid ()] : probs
184+ end
185+ new_prob = prob. prob_func (_prob,i,iter)
186+ x = prob. output_func (solve (new_prob,alg;kwargs... ),i)
187+ if ! (typeof (x) <: Tuple )
188+ @warn (" output_func should return (out,rerun). See docs for updated details" )
189+ _x = (x,false )
190+ else
191+ _x = x
192+ end
193+ rerun = _x[2 ]
194+
195+ while rerun
196+ iter += 1
197+ _prob = if prob. safetycopy
198+ probs isa Vector ? deepcopy (probs[Threads. threadid ()]) : probs
199+ else
200+ probs isa Vector ? probs[Threads. threadid ()] : probs
201+ end
202+ new_prob = prob. prob_func (_prob,i,iter)
203+ x = prob. output_func (solve (new_prob,alg;alias_jumps= true ,kwargs... ),i)
204+ if ! (typeof (x) <: Tuple )
205+ @warn (" output_func should return (out,rerun). See docs for updated details" )
206+ _x = (x,false )
207+ else
208+ _x = x
209+ end
210+ rerun = _x[2 ]
211+ end
212+ _x[1 ]
213+ end
214+
177215 # batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
216+ batch_data = Vector {Any} (undef,length (II))
178217
179- local batch_data
180218 let
181219 if length (II) == 1 || Threads. nthreads () == 1
182- batch_data = Vector {Any} (undef,length (II))
183220 for batch_idx in axes (batch_data, 1 )
184- batch_data[batch_idx] = multithreaded_batch (batch_idx,probs,alg,II )
221+ batch_data[batch_idx] = multithreaded_batch (batch_idx)
185222 end
186223 else
187- batch_data = Vector {Any} (undef,Threads. nthreads ())
188- batch_size = length (II)÷ Threads. nthreads ()
189- Threads. @threads for i in 1 : Threads. nthreads ()
190- if i == Threads. nthreads ()
191- I_local = II[(batch_size* (i- 1 )+ 1 ): end ]
192- else
193- I_local = II[(batch_size* (i- 1 )+ 1 ): (batch_size* i)]
194- end
195- batch_data[i] = solve_batch (prob,alg,EnsembleSerial (),I_local,pmap_batch_size;kwargs... )
224+ Threads. @threads for batch_idx in axes (batch_data, 1 )
225+ batch_data[batch_idx] = multithreaded_batch (batch_idx)
196226 end
197- batch_data = reduce (vcat,batch_data)
198227 end
199228 end
200229 tighten_container_eltype (batch_data)
@@ -211,8 +240,71 @@ function solve_batch(prob,alg,::EnsembleSplitThreads,II,pmap_batch_size;kwargs..
211240 else
212241 I_local = II[(batch_size* (i- 1 )+ 1 ): (batch_size* i)]
213242 end
214- solve_batch (prob,alg, EnsembleThreads (), I_local,pmap_batch_size ;kwargs... )
243+ thread_monte (prob,I_local,alg,i ;kwargs... )
215244 end
216245 end
217246 reduce (vcat,batch_data)
218247end
248+
249+ function thread_monte (prob,II,alg,procid;kwargs... )
250+
251+ if typeof (prob. prob) <: AbstractJumpProblem && length (II) != 1
252+ probs = [deepcopy (prob. prob) for i in 1 : Threads. nthreads ()]
253+ else
254+ probs = prob. prob
255+ end
256+
257+ function multithreaded_batch (j)
258+ i = II[j]
259+ iter = 1
260+ _prob = if prob. safetycopy
261+ probs isa Vector ? deepcopy (probs[Threads. threadid ()]) : probs
262+ else
263+ probs isa Vector ? probs[Threads. threadid ()] : probs
264+ end
265+ new_prob = prob. prob_func (_prob,i,iter)
266+ rerun = true
267+ x = prob. output_func (solve (new_prob,alg;alias_jumps= true ,kwargs... ),i)
268+ if ! (typeof (x) <: Tuple )
269+ @warn (" output_func should return (out,rerun). See docs for updated details" )
270+ _x = (x,false )
271+ else
272+ _x = x
273+ end
274+ rerun = _x[2 ]
275+ while rerun
276+ iter += 1
277+ _prob = if prob. safetycopy
278+ probs isa Vector ? deepcopy (probs[Threads. threadid ()]) : probs
279+ else
280+ probs isa Vector ? probs[Threads. threadid ()] : probs
281+ end
282+ new_prob = prob. prob_func (_prob,i,iter)
283+ x = prob. output_func (solve (new_prob,alg;alias_jumps= true ,kwargs... ),i)
284+ if ! (typeof (x) <: Tuple )
285+ @warn (" output_func should return (out,rerun). See docs for updated details" )
286+ _x = (x,false )
287+ else
288+ _x = x
289+ end
290+ rerun = _x[2 ]
291+ end
292+ _x[1 ]
293+ end
294+
295+ # batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
296+ batch_data = Vector {Any} (undef,length (II))
297+
298+ let
299+ if length (II) == 1 || Threads. nthreads () == 1
300+ for batch_idx in axes (batch_data, 1 )
301+ batch_data[batch_idx] = multithreaded_batch (batch_idx)
302+ end
303+ else
304+ Threads. @threads for batch_idx in axes (batch_data, 1 )
305+ batch_data[batch_idx] = multithreaded_batch (batch_idx)
306+ end
307+ end
308+ end
309+ tighten_container_eltype (batch_data)
310+ end
0 commit comments