@@ -1779,7 +1779,7 @@ def infer_config(args, constructor, trace_patch, layout_patch):
17791779 else args ["geojson" ].__geo_interface__
17801780 )
17811781
1782- # Compute marginal attribute
1782+ # Compute marginal attribute: copy to appropriate marginal_*
17831783 if "marginal" in args :
17841784 position = "marginal_x" if args ["orientation" ] == "v" else "marginal_y"
17851785 other_position = "marginal_x" if args ["orientation" ] == "h" else "marginal_y"
@@ -1879,6 +1879,7 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
18791879
18801880 col_labels = []
18811881 row_labels = []
1882+ nrows = ncols = 1
18821883 for m in grouped_mappings :
18831884 if m .grouper not in sorted_group_values :
18841885 m .val_map ["" ] = m .sequence [0 ]
@@ -1887,9 +1888,11 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
18871888 if m .facet == "col" :
18881889 prefix = get_label (args , args ["facet_col" ]) + "="
18891890 col_labels = [prefix + str (s ) for s in sorted_values ]
1891+ ncols = len (col_labels )
18901892 if m .facet == "row" :
18911893 prefix = get_label (args , args ["facet_row" ]) + "="
18921894 row_labels = [prefix + str (s ) for s in sorted_values ]
1895+ nrows = len (row_labels )
18931896 for val in sorted_values :
18941897 if val not in m .val_map : # always False if it's an IdentityMap
18951898 m .val_map [val ] = m .sequence [len (m .val_map ) % len (m .sequence )]
@@ -1899,8 +1902,8 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
18991902 trace_names_by_frame = {}
19001903 frames = OrderedDict ()
19011904 trendline_rows = []
1902- nrows = ncols = 1
19031905 trace_name_labels = None
1906+ facet_col_wrap = args .get ("facet_col_wrap" , 0 )
19041907 for group_name in sorted_group_names :
19051908 group = grouped .get_group (group_name if len (group_name ) > 1 else group_name [0 ])
19061909 mapping_labels = OrderedDict ()
@@ -1981,14 +1984,13 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
19811984 row = m .val_map [val ]
19821985 else :
19831986 if (
1984- bool (args .get ("marginal_x" , False ))
1985- and trace_spec .marginal != "x"
1987+ bool (args .get ("marginal_x" , False )) # there is a marginal
1988+ and trace_spec .marginal != "x" # and we're not it
19861989 ):
19871990 row = 2
19881991 else :
19891992 row = 1
19901993
1991- facet_col_wrap = args .get ("facet_col_wrap" , 0 )
19921994 # Find col for trace, handling facet_col and marginal_y
19931995 if m .facet == "col" :
19941996 col = m .val_map [val ]
@@ -2001,11 +2003,9 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
20012003 else :
20022004 col = 1
20032005
2004- nrows = max (nrows , row )
20052006 if row > 1 :
20062007 trace ._subplot_row = row
20072008
2008- ncols = max (ncols , col )
20092009 if col > 1 :
20102010 trace ._subplot_col = col
20112011 if (
@@ -2064,6 +2064,16 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
20642064 ):
20652065 layout_patch ["legend" ]["itemsizing" ] = "constant"
20662066
2067+ if facet_col_wrap :
2068+ nrows = 1 + ncols // facet_col_wrap
2069+ ncols = ncols if ncols < facet_col_wrap else facet_col_wrap
2070+
2071+ if args .get ("marginal_x" ):
2072+ nrows += 1
2073+
2074+ if args .get ("marginal_y" ):
2075+ ncols += 1
2076+
20672077 fig = init_figure (
20682078 args , subplot_type , frame_list , nrows , ncols , col_labels , row_labels
20692079 )
0 commit comments