88from itertools import cycle
99
1010from specparam .data .utils import get_periodic_labels , get_band_labels
11+ from specparam .utils .data import compute_presence
1112from specparam .plts .utils import savefig
1213from specparam .plts .templates import plot_param_over_time_yshade
1314from specparam .plts .settings import PARAM_COLORS
@@ -45,13 +46,13 @@ def plot_event_model(event_model, **plot_kwargs):
4546 n_bands = len (pe_labels ['cf' ])
4647
4748 has_knee = 'knee' in event_model .event_time_results .keys ()
48- height_ratios = [1 ] * (3 if has_knee else 2 ) + [0.25 , 1 , 1 , 1 ] * n_bands + [0.25 ] + [1 , 1 ]
49+ height_ratios = [1 ] * (3 if has_knee else 2 ) + [0.25 , 1 , 1 , 1 , 1 ] * n_bands + [0.25 ] + [1 , 1 ]
4950
5051 axes = plot_kwargs .pop ('axes' , None )
5152 if axes is None :
52- _ , axes = plt .subplots ((4 if has_knee else 3 ) + (n_bands * 4 ) + 2 , 1 ,
53+ _ , axes = plt .subplots ((4 if has_knee else 3 ) + (n_bands * 5 ) + 2 , 1 ,
5354 gridspec_kw = {'hspace' : 0.1 , 'height_ratios' : height_ratios },
54- figsize = plot_kwargs .pop ('figsize' , [10 , 4 + 4 * n_bands ]))
55+ figsize = plot_kwargs .pop ('figsize' , [10 , 4 + 5 * n_bands ]))
5556 axes = cycle (axes )
5657
5758 xlim = [0 , event_model .n_time_windows - 1 ]
@@ -74,6 +75,10 @@ def plot_event_model(event_model, **plot_kwargs):
7475 label = plabel .upper (), drop_xticks = True , add_xlabel = False , xlim = xlim ,
7576 title = 'Periodic Parameters - ' + band_labels [band_ind ] if plabel == 'cf' else None ,
7677 color = PARAM_COLORS [plabel ], ax = next (axes ))
78+ plot_param_over_time_yshade (\
79+ None , compute_presence (event_model .event_time_results [pe_labels [plabel ][band_ind ]]),
80+ label = 'Presence' , drop_xticks = True , add_xlabel = False , xlim = xlim ,
81+ color = PARAM_COLORS ['presence' ], ax = next (axes ))
7782 next (axes ).axis ('off' )
7883
7984 # 03: goodness of fit
0 commit comments