1313plt .switch_backend ("agg" )
1414
1515
16- class JupyterContainer :
17- def __init__ (
18- self ,
19- model_class ,
20- model_params ,
21- measures = None ,
22- name = "Mesa Model" ,
23- agent_portrayal = None ,
24- ):
25- self .model_class = model_class
26- self .split_model_params (model_params )
27- self .measures = measures
28- self .name = name
29- self .agent_portrayal = agent_portrayal
30- self .thread = None
31-
32- def split_model_params (self , model_params ):
33- self .model_params_input = {}
34- self .model_params_fixed = {}
35- for k , v in model_params .items ():
36- if self .check_param_is_fixed (v ):
37- self .model_params_fixed [k ] = v
16+ @solara .component
17+ def JupyterViz (
18+ model_class ,
19+ model_params ,
20+ measures = None ,
21+ name = "Mesa Model" ,
22+ agent_portrayal = None ,
23+ space_drawer = None ,
24+ play_interval = 400 ,
25+ ):
26+ current_step , set_current_step = solara .use_state (0 )
27+
28+ solara .Markdown (name )
29+
30+ # 0. Split model params
31+ model_params_input , model_params_fixed = split_model_params (model_params )
32+
33+ # 1. User inputs
34+ user_inputs = {}
35+ for k , v in model_params_input .items ():
36+ user_input = solara .use_reactive (v ["value" ])
37+ user_inputs [k ] = user_input .value
38+ make_user_input (user_input , k , v )
39+
40+ # 2. Model
41+ def make_model ():
42+ return model_class (** user_inputs , ** model_params_fixed )
43+
44+ model = solara .use_memo (make_model , dependencies = list (user_inputs .values ()))
45+
46+ # 3. Buttons
47+ ModelController (model , play_interval , current_step , set_current_step )
48+
49+ with solara .GridFixed (columns = 2 ):
50+ # 4. Space
51+ if space_drawer is None :
52+ make_space (model , agent_portrayal )
53+ else :
54+ space_drawer (model , agent_portrayal )
55+ # 5. Plots
56+ for measure in measures :
57+ if callable (measure ):
58+ # Is a custom object
59+ measure (model )
3860 else :
39- self .model_params_input [k ] = v
61+ make_plot (model , measure )
62+
4063
41- def check_param_is_fixed (self , param ):
42- if not isinstance (param , dict ):
43- return True
44- if "type" not in param :
45- return True
64+ @solara .component
65+ def ModelController (model , play_interval , current_step , set_current_step ):
66+ playing = solara .use_reactive (False )
67+ thread = solara .use_reactive (None )
68+
69+ def on_value_play (change ):
70+ if model .running :
71+ do_step ()
72+ else :
73+ playing .value = False
4674
47- def do_step (self ):
48- self . model .step ()
49- self . set_df ( self . model .datacollector . get_model_vars_dataframe () )
75+ def do_step ():
76+ model .step ()
77+ set_current_step ( model .schedule . steps )
5078
51- def do_play (self ):
52- self . model .running = True
53- while self . model .running :
54- self . do_step ()
79+ def do_play ():
80+ model .running = True
81+ while model .running :
82+ do_step ()
5583
56- def threaded_do_play (self ):
57- if self . thread is not None and self . thread .is_alive ():
84+ def threaded_do_play ():
85+ if thread is not None and thread .is_alive ():
5886 return
59- self . thread = threading .Thread (target = self . do_play )
60- self . thread .start ()
87+ thread . value = threading .Thread (target = do_play )
88+ thread .start ()
6189
62- def do_pause (self ):
63- if (self . thread is None ) or (not self . thread .is_alive ()):
90+ def do_pause ():
91+ if (thread is None ) or (not thread .is_alive ()):
6492 return
65- self . model .running = False
66- self . thread .join ()
93+ model .running = False
94+ thread .join ()
6795
68- def portray (self , g ):
96+ with solara .Row ():
97+ solara .Button (label = "Step" , color = "primary" , on_click = do_step )
98+ # This style is necessary so that the play widget has almost the same
99+ # height as typical Solara buttons.
100+ solara .Style (
101+ """
102+ .widget-play {
103+ height: 30px;
104+ }
105+ """
106+ )
107+ widgets .Play (
108+ value = 0 ,
109+ interval = play_interval ,
110+ repeat = True ,
111+ show_repeat = False ,
112+ on_value = on_value_play ,
113+ playing = playing .value ,
114+ on_playing = playing .set ,
115+ )
116+ solara .Markdown (md_text = f"**Step:** { current_step } " )
117+ # threaded_do_play is not used for now because it
118+ # doesn't work in Google colab. We use
119+ # ipywidgets.Play until it is fixed. The threading
120+ # version is definite a much better implementation,
121+ # if it works.
122+ # solara.Button(label="▶", color="primary", on_click=viz.threaded_do_play)
123+ # solara.Button(label="⏸︎", color="primary", on_click=viz.do_pause)
124+ # solara.Button(label="Reset", color="primary", on_click=do_reset)
125+
126+
127+ def split_model_params (model_params ):
128+ model_params_input = {}
129+ model_params_fixed = {}
130+ for k , v in model_params .items ():
131+ if check_param_is_fixed (v ):
132+ model_params_fixed [k ] = v
133+ else :
134+ model_params_input [k ] = v
135+ return model_params_input , model_params_fixed
136+
137+
138+ def check_param_is_fixed (param ):
139+ if not isinstance (param , dict ):
140+ return True
141+ if "type" not in param :
142+ return True
143+
144+
145+ def make_user_input (user_input , k , v ):
146+ if v ["type" ] == "SliderInt" :
147+ solara .SliderInt (
148+ v .get ("label" , "label" ),
149+ value = user_input ,
150+ min = v .get ("min" ),
151+ max = v .get ("max" ),
152+ step = v .get ("step" ),
153+ )
154+ elif v ["type" ] == "SliderFloat" :
155+ solara .SliderFloat (
156+ v .get ("label" , "label" ),
157+ value = user_input ,
158+ min = v .get ("min" ),
159+ max = v .get ("max" ),
160+ step = v .get ("step" ),
161+ )
162+ elif v ["type" ] == "Select" :
163+ solara .Select (
164+ v .get ("label" , "label" ),
165+ value = v .get ("value" ),
166+ values = v .get ("values" ),
167+ )
168+
169+
170+ def make_space (model , agent_portrayal ):
171+ def portray (g ):
69172 x = []
70173 y = []
71174 s = [] # size
@@ -79,7 +182,7 @@ def portray(self, g):
79182 # Is a single grid
80183 content = [content ]
81184 for agent in content :
82- data = self . agent_portrayal (agent )
185+ data = agent_portrayal (agent )
83186 x .append (i )
84187 y .append (j )
85188 if "size" in data :
@@ -93,159 +196,40 @@ def portray(self, g):
93196 out ["c" ] = c
94197 return out
95198
199+ space_fig = Figure ()
200+ space_ax = space_fig .subplots ()
201+ if isinstance (model .grid , mesa .space .NetworkGrid ):
202+ _draw_network_grid (model , space_ax , agent_portrayal )
203+ else :
204+ space_ax .scatter (** portray (model .grid ))
205+ space_ax .set_axis_off ()
206+ solara .FigureMatplotlib (space_fig )
207+
96208
97- def _draw_network_grid (viz , space_ax ):
98- graph = viz . model .grid .G
209+ def _draw_network_grid (model , space_ax , agent_portrayal ):
210+ graph = model .grid .G
99211 pos = nx .spring_layout (graph , seed = 0 )
100212 nx .draw (
101213 graph ,
102214 ax = space_ax ,
103215 pos = pos ,
104- ** viz . agent_portrayal (graph ),
216+ ** agent_portrayal (graph ),
105217 )
106218
107219
108- def make_space (viz ):
109- space_fig = Figure ()
110- space_ax = space_fig .subplots ()
111- if isinstance (viz .model .grid , mesa .space .NetworkGrid ):
112- _draw_network_grid (viz , space_ax )
113- else :
114- space_ax .scatter (** viz .portray (viz .model .grid ))
115- space_ax .set_axis_off ()
116- solara .FigureMatplotlib (space_fig , dependencies = [viz .model , viz .df ])
117-
118-
119- def make_plot (viz , measure ):
220+ def make_plot (model , measure ):
120221 fig = Figure ()
121222 ax = fig .subplots ()
122- ax .plot (viz .df .loc [:, measure ])
223+ df = model .datacollector .get_model_vars_dataframe ()
224+ ax .plot (df .loc [:, measure ])
123225 ax .set_ylabel (measure )
124226 # Set integer x axis
125227 ax .xaxis .set_major_locator (MaxNLocator (integer = True ))
126- solara .FigureMatplotlib (fig , dependencies = [ viz . model , viz . df ] )
228+ solara .FigureMatplotlib (fig )
127229
128230
129231def make_text (renderer ):
130- def function (viz ):
131- solara .Markdown (renderer (viz . model ))
232+ def function (model ):
233+ solara .Markdown (renderer (model ))
132234
133235 return function
134-
135-
136- def make_user_input (user_input , k , v ):
137- if v ["type" ] == "SliderInt" :
138- solara .SliderInt (
139- v .get ("label" , "label" ),
140- value = user_input ,
141- min = v .get ("min" ),
142- max = v .get ("max" ),
143- step = v .get ("step" ),
144- )
145- elif v ["type" ] == "SliderFloat" :
146- solara .SliderFloat (
147- v .get ("label" , "label" ),
148- value = user_input ,
149- min = v .get ("min" ),
150- max = v .get ("max" ),
151- step = v .get ("step" ),
152- )
153- elif v ["type" ] == "Select" :
154- solara .Select (
155- v .get ("label" , "label" ),
156- value = v .get ("value" ),
157- values = v .get ("values" ),
158- )
159-
160-
161- @solara .component
162- def MesaComponent (viz , space_drawer = None , play_interval = 400 ):
163- solara .Markdown (viz .name )
164-
165- # 1. User inputs
166- user_inputs = {}
167- for k , v in viz .model_params_input .items ():
168- user_input = solara .use_reactive (v ["value" ])
169- user_inputs [k ] = user_input .value
170- make_user_input (user_input , k , v )
171-
172- # 2. Model
173- def make_model ():
174- return viz .model_class (** user_inputs , ** viz .model_params_fixed )
175-
176- viz .model = solara .use_memo (make_model , dependencies = list (user_inputs .values ()))
177- viz .df , viz .set_df = solara .use_state (
178- viz .model .datacollector .get_model_vars_dataframe ()
179- )
180-
181- # 3. Buttons
182- playing = solara .use_reactive (False )
183-
184- def on_value_play (change ):
185- if viz .model .running :
186- viz .do_step ()
187- else :
188- playing .value = False
189-
190- with solara .Row ():
191- solara .Button (label = "Step" , color = "primary" , on_click = viz .do_step )
192- # This style is necessary so that the play widget has almost the same
193- # height as typical Solara buttons.
194- solara .Style (
195- """
196- .widget-play {
197- height: 30px;
198- }
199- """
200- )
201- widgets .Play (
202- value = 0 ,
203- interval = play_interval ,
204- repeat = True ,
205- show_repeat = False ,
206- on_value = on_value_play ,
207- playing = playing .value ,
208- on_playing = playing .set ,
209- )
210- solara .Markdown (md_text = f"**Step:** { viz .model .schedule .steps } " )
211- # threaded_do_play is not used for now because it
212- # doesn't work in Google colab. We use
213- # ipywidgets.Play until it is fixed. The threading
214- # version is definite a much better implementation,
215- # if it works.
216- # solara.Button(label="▶", color="primary", on_click=viz.threaded_do_play)
217- # solara.Button(label="⏸︎", color="primary", on_click=viz.do_pause)
218- # solara.Button(label="Reset", color="primary", on_click=do_reset)
219-
220- with solara .GridFixed (columns = 2 ):
221- # 4. Space
222- if space_drawer is None :
223- make_space (viz )
224- else :
225- space_drawer (viz )
226- # 5. Plots
227- for measure in viz .measures :
228- if callable (measure ):
229- # Is a custom object
230- measure (viz )
231- else :
232- make_plot (viz , measure )
233-
234-
235- # JupyterViz has to be a Solara component, so that each browser tabs runs in
236- # their own, separate simulation thread. See https://github.com/projectmesa/mesa/issues/856.
237- @solara .component
238- def JupyterViz (
239- model_class ,
240- model_params ,
241- measures = None ,
242- name = "Mesa Model" ,
243- agent_portrayal = None ,
244- space_drawer = None ,
245- play_interval = 400 ,
246- ):
247- return MesaComponent (
248- JupyterContainer (model_class , model_params , measures , name , agent_portrayal ),
249- space_drawer = space_drawer ,
250- play_interval = play_interval ,
251- )
0 commit comments