@@ -293,12 +293,23 @@ import pandas as pd
293293import plotly.express as px
294294import collections
295295
296+ def negative_1_if_count_is_odd (count ):
297+ # if this is an odd numbered entry in its bin, make its y coordinate negative
298+ # the y coordinate of the first entry is 0, so entries 3, 5, and 7 get negative y coordinates
299+ if count% 2 == 1 :
300+ return - 1
301+ else :
302+ return 1
303+
304+
305+
296306
297307def swarm (
298308 X_series ,
299309 point_size = 16 ,
300310 fig_width = 800 ,
301311 gap_multiplier = 1.2 ,
312+ center_even_groups = False
302313):
303314 # sorting will align columns in attractive arcs rather than having columns the vary unpredicatbly in the x-dimension
304315 X_series= X_series.copy().sort_values()
@@ -309,7 +320,7 @@ def swarm(
309320 # minimum X value to the maximum X value
310321 min_x = min (X_series)
311322 max_x = max (X_series)
312-
323+
313324 list_of_rows = []
314325 # we will count the number of points in each "bin" / vertical strip of the graph
315326 # to be able to assign a y-coordinate that avoids overlapping
@@ -319,33 +330,43 @@ def swarm(
319330 # assign this x_value to bin number
320331 # each bin is a vertical strip wide enough for one marker
321332 bin = (((fig_width* (x_val- min_x))/ (max_x- min_x)) // point_size)
322-
333+
323334 # update the count of dots in that strip
324335 bin_counter.update([bin ])
325-
326- # if this is an odd numbered entry in its bin, make its y coordinate negative
327- # the y coordinate of the first entry is 0, so entries 3, 5, and 7 get negative y coordinates
328- if bin_counter[bin ]% 2 == 1 :
329- negative_1_if_count_is_odd = - 1
330- else :
331- negative_1_if_count_is_odd = 1
336+
332337
333338 # the collision free y coordinate gives the items in a vertical bin
334339 # coordinates: 0, 1, -1, 2, -2, 3, -3 ... and so on to evenly spread
335340 # their locations above and below the y-axis (we'll make a correction below to deal with even numbers of entries)
336341 # we then scale this by the point_size*gap_multiplier to get a y coordinate in px
337342
338- collision_free_y_coordinate= (bin_counter[bin ]// 2 )* negative_1_if_count_is_odd* point_size* gap_multiplier
339- list_of_rows.append({" x" :x_val," y" :collision_free_y_coordinate," bin" :bin })
343+ collision_free_y_coordinate= (bin_counter[bin ]// 2 )* negative_1_if_count_is_odd(bin_counter[ bin ]) * point_size* gap_multiplier
344+ list_of_rows.append({" x" :x_val," y" :collision_free_y_coordinate," bin" :bin , " adj " : 0 })
340345
341346 # if the number of points is even,
342347 # move y-coordinates down to put an equal number of entries above and below the axis
348+ # this can sometimes break the collision avoidance routine, but makes small N outputs look better otherwise
349+ if center_even_groups:
350+ for row in list_of_rows:
351+ if bin_counter[row[" bin" ]]% 2 == 0 :
352+ row[" y" ]-= point_size* gap_multiplier/ 2
353+ row[" adj" ]= - point_size* gap_multiplier/ 2
354+
355+
343356 for row in list_of_rows:
344- if bin_counter[row[" bin" ]]% 2 == 0 :
345- row[" y" ]-= point_size* gap_multiplier/ 2
357+ bin = row[" bin" ]
358+ # see if we need to "look left" to avoid a possible collision
359+ for other_row in list_of_rows:
360+ if (other_row[" bin" ]== bin - 1 ):
361+ if (((other_row[" y" ]== row[" y" ]) or (other_row[" y" ]== row[" y" ]+ row[" adj" ]))
362+ and (((fig_width* (row[" x" ]- other_row[" x" ]))/ (max_x- min_x) // point_size) < 1 )):
363+ bin_counter.update([bin ])
364+ row[" y" ]= (bin_counter[bin ]// 2 )* negative_1_if_count_is_odd(bin_counter[bin ])* point_size* gap_multiplier+ row[" adj" ]
365+
366+
346367
347368 df = pd.DataFrame(list_of_rows)
348-
369+
349370 fig = px.scatter(
350371 df,
351372 x = " x" ,
@@ -370,9 +391,12 @@ def swarm(
370391
371392
372393df_iris = px.data.iris() # iris is a pandas DataFrame
373- fig = swarm(df_iris[" sepal_length" ])
394+ x = df_iris[" sepal_length" ]
395+ x2 = pd.Series([5.05 ])
396+ x = pd.concat([x,x2], ignore_index = True )
397+ fig = swarm(x)
398+ # fig = swarm(pd.Series([1,1.5, 1.78, 1.79,2,2,12]))
374399fig.show()
375-
376400```
377401
378402## Scatter and line plots with go.Scatter
0 commit comments