1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414from collections .abc import Iterable
15- from typing import TYPE_CHECKING , Literal
15+ from typing import TYPE_CHECKING , Literal , Protocol , cast
1616
1717from rich .box import SIMPLE_HEAD
1818from rich .console import Console
@@ -192,6 +192,132 @@ def callbacks(self, task: "Task"):
192192 self .finished_style = self .default_finished_style
193193
194194
195+ class ProgressBar (Protocol ):
196+ @property
197+ def tasks (self ):
198+ """Get the tasks in the progress bar."""
199+
200+ def add_task (self , * args , ** kwargs ):
201+ """Add a task to the progress bar."""
202+
203+ def update (self , task_id , ** kwargs ):
204+ """Update the task with the given ID with the provided keyword arguments."""
205+
206+ def __enter__ (self ):
207+ """Enter the context manager."""
208+
209+ def __exit__ (self , exc_type , exc_val , exc_tb ):
210+ """Exit the context manager."""
211+
212+
213+ def compute_draw_speed (elapsed , draws ):
214+ speed = draws / max (elapsed , 1e-6 )
215+
216+ if speed > 1 or speed == 0 :
217+ unit = "draws/s"
218+ else :
219+ unit = "s/draws"
220+ speed = 1 / speed
221+
222+ return speed , unit
223+
224+
225+ def create_rich_progress_bar (full_stats , step_columns , progressbar , progressbar_theme ):
226+ columns = [TextColumn ("{task.fields[draws]}" , table_column = Column ("Draws" , ratio = 1 ))]
227+
228+ if full_stats :
229+ columns += step_columns
230+
231+ columns += [
232+ TextColumn (
233+ "{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}" ,
234+ table_column = Column ("Sampling Speed" , ratio = 1 ),
235+ ),
236+ TimeElapsedColumn (table_column = Column ("Elapsed" , ratio = 1 )),
237+ TimeRemainingColumn (table_column = Column ("Remaining" , ratio = 1 )),
238+ ]
239+
240+ return CustomProgress (
241+ RecolorOnFailureBarColumn (
242+ table_column = Column ("Progress" , ratio = 2 ),
243+ failing_color = "tab:red" ,
244+ complete_style = Style .parse ("rgb(31,119,180)" ), # tab:blue
245+ finished_style = Style .parse ("rgb(31,119,180)" ), # tab:blue
246+ ),
247+ * columns ,
248+ console = Console (theme = progressbar_theme ),
249+ disable = not progressbar ,
250+ include_headers = True ,
251+ )
252+
253+
254+ class MarimoProgressTask :
255+ def __init__ (self , * args , ** kwargs ):
256+ self .args = args
257+ self .kwargs = kwargs
258+
259+ @property
260+ def chain_idx (self ) -> int :
261+ return self .kwargs .get ("chain_idx" , 0 )
262+
263+ @property
264+ def total (self ):
265+ return self .kwargs .get ("total" , 0 )
266+
267+ @property
268+ def elapsed (self ):
269+ return self .kwargs .get ("elapsed" , 0 )
270+
271+
272+ class MarimoProgressBar :
273+ def __init__ (self ) -> None :
274+ self .tasks = []
275+ self .divergences = {}
276+
277+ def __enter__ (self ):
278+ """Enter the context manager."""
279+ import marimo as mo
280+
281+ total_draws = (self .tasks [0 ].total + 1 ) * len (self .tasks )
282+
283+ self .bar = mo .status .progress_bar (total = total_draws , title = "Sampling PyMC model" )
284+
285+ def __exit__ (self , exc_type , exc_val , exc_tb ):
286+ """Exit the context manager."""
287+ self .bar ._finish ()
288+
289+ def add_task (self , * args , ** kwargs ):
290+ """Add a task to the progress bar."""
291+ task = MarimoProgressTask (* args , ** kwargs )
292+ self .tasks .append (task )
293+ return task
294+
295+ def update (self , task_id , ** kwargs ):
296+ """Update the task with the given ID with the provided keyword arguments."""
297+ if self .bar .progress .current >= cast (int , self .bar .progress .total ):
298+ return
299+
300+ self .divergences [task_id .chain_idx ] = kwargs .get ("divergences" , 0 )
301+
302+ total_divergences = sum (self .divergences .values ())
303+
304+ update_kwargs = {}
305+ if total_divergences > 0 :
306+ word = "draws" if total_divergences > 1 else "draw"
307+ update_kwargs ["subtitle" ] = f"{ total_divergences } diverging { word } "
308+
309+ self .bar .progress .update (** update_kwargs )
310+
311+
312+ def in_marimo_notebook () -> bool :
313+ try :
314+ import marimo as mo
315+
316+ return mo .running_in_notebook ()
317+ except ImportError :
318+ return False
319+
320+
195321class ProgressBarManager :
196322 """Manage progress bars displayed during sampling."""
197323
@@ -203,6 +329,7 @@ def __init__(
203329 tune : int ,
204330 progressbar : bool | ProgressBarType = True ,
205331 progressbar_theme : Theme | None = None ,
332+ progress : ProgressBar | None = None ,
206333 ):
207334 """
208335 Manage progress bars displayed during sampling.
@@ -275,11 +402,16 @@ def __init__(
275402
276403 progress_columns , progress_stats = step_method ._progressbar_config (chains )
277404
278- self ._progress = self .create_progress_bar (
279- progress_columns ,
280- progressbar = progressbar ,
281- progressbar_theme = progressbar_theme ,
282- )
405+ if in_marimo_notebook ():
406+ self .combined_progress = False
407+ self ._progress = MarimoProgressBar ()
408+ else :
409+ self ._progress = progress or create_rich_progress_bar (
410+ self .full_stats ,
411+ progress_columns ,
412+ progressbar = progressbar ,
413+ progressbar_theme = progressbar_theme ,
414+ )
283415 self .progress_stats = progress_stats
284416 self .update_stats_functions = step_method ._make_progressbar_update_functions ()
285417
@@ -331,18 +463,6 @@ def _initialize_tasks(self):
331463 for chain_idx in range (self .chains )
332464 ]
333465
334- @staticmethod
335- def compute_draw_speed (elapsed , draws ):
336- speed = draws / max (elapsed , 1e-6 )
337-
338- if speed > 1 or speed == 0 :
339- unit = "draws/s"
340- else :
341- unit = "s/draws"
342- speed = 1 / speed
343-
344- return speed , unit
345-
346466 def update (self , chain_idx , is_last , draw , tuning , stats ):
347467 if not self ._show_progress :
348468 return
@@ -353,7 +473,7 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
353473 chain_idx = 0
354474
355475 elapsed = self ._progress .tasks [chain_idx ].elapsed
356- speed , unit = self . compute_draw_speed (elapsed , draw )
476+ speed , unit = compute_draw_speed (elapsed , draw )
357477
358478 failing = False
359479 all_step_stats = {}
@@ -395,31 +515,3 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
395515 ** all_step_stats ,
396516 refresh = True ,
397517 )
398-
399- def create_progress_bar (self , step_columns , progressbar , progressbar_theme ):
400- columns = [TextColumn ("{task.fields[draws]}" , table_column = Column ("Draws" , ratio = 1 ))]
401-
402- if self .full_stats :
403- columns += step_columns
404-
405- columns += [
406- TextColumn (
407- "{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}" ,
408- table_column = Column ("Sampling Speed" , ratio = 1 ),
409- ),
410- TimeElapsedColumn (table_column = Column ("Elapsed" , ratio = 1 )),
411- TimeRemainingColumn (table_column = Column ("Remaining" , ratio = 1 )),
412- ]
413-
414- return CustomProgress (
415- RecolorOnFailureBarColumn (
416- table_column = Column ("Progress" , ratio = 2 ),
417- failing_color = "tab:red" ,
418- complete_style = Style .parse ("rgb(31,119,180)" ), # tab:blue
419- finished_style = Style .parse ("rgb(31,119,180)" ), # tab:blue
420- ),
421- * columns ,
422- console = Console (theme = progressbar_theme ),
423- disable = not progressbar ,
424- include_headers = True ,
425- )
0 commit comments