@@ -1141,6 +1141,7 @@ def get_orderings(args, grouper, grouped):
11411141 """
11421142 orders = {} if "category_orders" not in args else args ["category_orders" ].copy ()
11431143 group_names = []
1144+ group_values = {}
11441145 for group_name in grouped .groups :
11451146 if len (grouper ) == 1 :
11461147 group_name = (group_name ,)
@@ -1154,6 +1155,7 @@ def get_orderings(args, grouper, grouped):
11541155 for val in uniques :
11551156 if val not in orders [col ]:
11561157 orders [col ].append (val )
1158+ group_values [col ] = sorted (uniques , key = orders [col ].index )
11571159
11581160 for i , col in reversed (list (enumerate (grouper ))):
11591161 if col != one_group :
@@ -1162,7 +1164,7 @@ def get_orderings(args, grouper, grouped):
11621164 key = lambda g : orders [col ].index (g [i ]) if g [i ] in orders [col ] else - 1 ,
11631165 )
11641166
1165- return orders , group_names
1167+ return orders , group_names , group_values
11661168
11671169
11681170def make_figure (args , constructor , trace_patch = {}, layout_patch = {}):
@@ -1174,16 +1176,31 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
11741176 grouper = [x .grouper or one_group for x in grouped_mappings ] or [one_group ]
11751177 grouped = args ["data_frame" ].groupby (grouper , sort = False )
11761178
1177- orders , sorted_group_names = get_orderings (args , grouper , grouped )
1179+ orders , sorted_group_names , sorted_group_values = get_orderings (
1180+ args , grouper , grouped
1181+ )
1182+
1183+ col_labels = []
1184+ row_labels = []
1185+
1186+ for m in grouped_mappings :
1187+ if m .grouper :
1188+ if m .facet == "col" :
1189+ prefix = get_label (args , args ["facet_col" ]) + "="
1190+ col_labels = [prefix + str (s ) for s in sorted_group_values [m .grouper ]]
1191+ if m .facet == "row" :
1192+ prefix = get_label (args , args ["facet_row" ]) + "="
1193+ row_labels = [prefix + str (s ) for s in sorted_group_values [m .grouper ]]
1194+ for val in sorted_group_values [m .grouper ]:
1195+ if val not in m .val_map :
1196+ m .val_map [val ] = m .sequence [len (m .val_map ) % len (m .sequence )]
11781197
11791198 subplot_type = _subplot_type_for_trace_type (constructor ().type )
11801199
11811200 trace_names_by_frame = {}
11821201 frames = OrderedDict ()
11831202 trendline_rows = []
11841203 nrows = ncols = 1
1185- col_labels = []
1186- row_labels = []
11871204 trace_name_labels = None
11881205 for group_name in sorted_group_names :
11891206 group = grouped .get_group (group_name if len (group_name ) > 1 else group_name [0 ])
@@ -1281,10 +1298,6 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
12811298 # Find row for trace, handling facet_row and marginal_x
12821299 if m .facet == "row" :
12831300 row = m .val_map [val ]
1284- if args ["facet_row" ] and len (row_labels ) < row :
1285- row_labels .append (
1286- get_label (args , args ["facet_row" ]) + "=" + str (val )
1287- )
12881301 else :
12891302 if (
12901303 bool (args .get ("marginal_x" , False ))
@@ -1298,10 +1311,6 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
12981311 # Find col for trace, handling facet_col and marginal_y
12991312 if m .facet == "col" :
13001313 col = m .val_map [val ]
1301- if args ["facet_col" ] and len (col_labels ) < col :
1302- col_labels .append (
1303- get_label (args , args ["facet_col" ]) + "=" + str (val )
1304- )
13051314 if facet_col_wrap : # assumes no facet_row, no marginals
13061315 row = 1 + ((col - 1 ) // facet_col_wrap )
13071316 col = 1 + ((col - 1 ) % facet_col_wrap )
0 commit comments