@@ -293,30 +293,32 @@ import pandas as pd
293293import plotly.express as px
294294import collections
295295
296- def negative_1_if_count_is_odd (count ):
297- # if this is an odd numbered entry in its bin, make its y coordinate negative
298- # the y coordinate of the first entry is 0, so entries 3, 5, and 7 get negative y coordinates
299- if count% 2 == 1 :
300- return - 1
301- else :
302- return 1
303-
304296
297+ def negative_1_if_count_is_odd (count ):
298+ # if this is an odd numbered entry in its bin, make its y coordinate negative
299+ # the y coordinate of the first entry is 0, so entries 3, 5, and 7 get
300+ # negative y coordinates
301+ if count % 2 == 1 :
302+ return - 1
303+ else :
304+ return 1
305305
306306
307307def swarm (
308308 X_series ,
309309 point_size = 16 ,
310- fig_width = 800 ,
310+ fig_width = 800 ,
311311 gap_multiplier = 1.2 ,
312- bin_fraction = 0.95 , # bin fraction slightly undersizes the bins to avoid collisions
313- ):
314- # sorting will align columns in attractive arcs rather than having columns the vary unpredicatbly in the x-dimension
315- X_series= X_series.copy().sort_values()
316-
312+ bin_fraction = 0.95 , # slightly undersizes the bins to avoid collisions
313+ ):
314+ # sorting will align columns in attractive c-shaped arcs rather than having
315+ # columns that vary unpredictably in the x-dimension.
316+ # We also exploit the fact that sorting means we see bins sequentially when
317+ # we add collision prevention offsets.
318+ X_series = X_series.copy().sort_values()
317319
318320 # we need to reason in terms of the marker size that is measured in px
319- # so we need to think about each x-coordinate as being a fraction of the way from the
321+ # so we need to think about each x-coordinate as being a fraction of the way from the
320322 # minimum X value to the maximum X value
321323 min_x = min (X_series)
322324 max_x = max (X_series)
@@ -329,79 +331,93 @@ def swarm(
329331 for x_val in X_series:
330332 # assign this x_value to bin number
331333 # each bin is a vertical strip slightly narrower than one marker
332-
333- bin = (((fig_width* bin_fraction* (x_val- min_x))/ (max_x- min_x)) // point_size)
334+ bin = (((fig_width* bin_fraction* (x_val- min_x))/ (max_x- min_x)) // point_size)
334335
335- # update the count of dots in that strip
336+ # update the count of dots in that strip
336337 bin_counter.update([bin ])
337338
339+ # remember the "y-slot" which tells us the number of points in this bin and is sufficient to compute the y coordinate unless there's a collision with the point to its left
340+ list_of_rows.append(
341+ {" x" : x_val, " y_slot" : bin_counter[bin ], " bin" : bin })
338342
339- # the collision free y coordinate gives the items in a vertical bin
340- # coordinates: 0, 1, -1, 2, -2, 3, -3 ... and so on to evenly spread
341- # their locations above and below the y-axis (we'll make a correction below to deal with even numbers of entries)
342- # we then scale this by the point_size*gap_multiplier to get a y coordinate in px
343-
344- collision_free_y_coordinate= (bin_counter[bin ]// 2 )* negative_1_if_count_is_odd(bin_counter[bin ])* point_size* gap_multiplier
345- list_of_rows.append({" x" :x_val," y" :collision_free_y_coordinate," bin" :bin })
346-
347-
348-
343+ # iterate through the points and "offset" any that are colliding with a
344+ # point to their left apply the offsets to all subsequent points in the same bin.
345+ # this arranges points in an attractive swarm c-curve where the points
346+ # toward the edges are (weakly) further right.
347+ bin = 0
348+ offset = 0
349349 for row in list_of_rows:
350- bin = row[" bin" ]
351- # see if we need to "look left" to avoid a possible collision
350+ if bin != row[" bin" ]:
351+ # we have moved to a new bin, so we need to reset the offset
352+ bin = row[" bin" ]
353+ offset = 0
354+ # see if we need to "look left" to avoid a possible collision
352355 for other_row in list_of_rows:
353- if (other_row[" bin" ]== bin - 1 ):
354- # "bubble" the entry up until we find a slot that avoids a collision
355- while ((other_row[" y" ]== row[" y" ])
356- and (((fig_width* (row[" x" ]- other_row[" x" ]))/ (max_x- min_x) // point_size) < 1 )):
357- print (row)
358- print (other_row)
359- print (((fig_width* (row[" x" ]- other_row[" x" ] ))/ (max_x- min_x) // point_size))
360-
361- print (" updating to fix collision" )
356+ if (other_row[" bin" ] == bin - 1 ):
357+ # "bubble" the entry up until we find a slot that avoids a collision
358+ while ((other_row[" y_slot" ] == row[" y_slot" ]+ offset)
359+ and (((fig_width* (row[" x" ]- other_row[" x" ]))/ (max_x- min_x)
360+ // point_size) < 1 )):
361+ offset += 1
362+ # update the bin count so we know whether the number of
363+ # *used* slots is even or odd
362364 bin_counter.update([bin ])
363- print (bin_counter[bin ])
364- row[" y" ]= (bin_counter[bin ]// 2 )* negative_1_if_count_is_odd(bin_counter[bin ])* point_size* gap_multiplier
365- print (row[" y" ])
366365
367- # if the number of points is even,
368- # move y-coordinates down to put an equal number of entries above and below the axis
366+ row[" y_slot" ] += offset
367+ # The collision free y coordinate gives the items in a vertical bin
368+ # y-coordinates to evenly spread their locations above and below the
369+ # y-axis (we'll make a correction below to deal with even numbers of
370+ # entries). For now, we'll assign 0, 1, -1, 2, -2, 3, -3 ... and so on.
371+ # We scale this by the point_size*gap_multiplier to get a y coordinate
372+ # in px.
373+ row[" y" ] = (row[" y_slot" ]// 2 ) * \
374+ negative_1_if_count_is_odd(row[" y_slot" ])* point_size* gap_multiplier
375+ print (row[" y" ])
376+
377+ # if the number of points is even, move y-coordinates down to put an equal
378+ # number of entries above and below the axis
369379 for row in list_of_rows:
370- if bin_counter[row[" bin" ]]% 2 == 0 :
371- row[" y" ]-= point_size* gap_multiplier/ 2
372-
380+ if bin_counter[row[" bin" ]] % 2 == 0 :
381+ row[" y" ] -= point_size* gap_multiplier/ 2
373382
374383 df = pd.DataFrame(list_of_rows)
375- # one way to make this code more flexible to e.g. handle multiple categories would be to return a list of "swarmified" y coordinates here
376- # you could then generate "swarmified" y coordinates for each category and add category specific offsets before scatterplotting them
384+ # One way to make this code more flexible to e.g. handle multiple categories
385+ # would be to return a list of "swarmified" y coordinates here and then plot
386+ # outside the function.
387+ # That generalization would let you "swarmify" y coordinates for each
388+ # category and add category specific offsets to put the each category in its
389+ # own row
377390
378391 fig = px.scatter(
379392 df,
380393 x = " x" ,
381394 y = " y" ,
382395 )
383- # we want to suppress the y coordinate in the hover value because the y-coordinate is irrelevant/misleading
396+ # we want to suppress the y coordinate in the hover value because the
397+ # y-coordinate is irrelevant/misleading
384398 fig.update_traces(
385399 marker_size = point_size,
386- # suppress the y coordinate because the y-coordinate is irrelevant
400+ # suppress the y coordinate because the y-coordinate is irrelevant
387401 hovertemplate = " <b>value</b>: %{x} " ,
388402 )
389- # we have to set the width and height because we aim to avoid icon collisions and we specify the icon size
390- # in the same units as the width and height
391- fig.update_layout(width = fig_width, height = (point_size* max (bin_counter.values())+ 200 ))
403+ # we have to set the width and height because we aim to avoid icon collisions
404+ # and we specify the icon size in the same units as the width and height
405+ fig.update_layout(width = fig_width, height = (
406+ point_size* max (bin_counter.values())+ 200 ))
392407 fig.update_yaxes(
393- showticklabels = False , # Turn off y-axis labels
394- ticks = ' ' , # Remove the ticks
395- title = " "
408+ showticklabels = False , # Turn off y-axis labels
409+ ticks = ' ' , # Remove the ticks
410+ title = " "
396411 )
397412 return fig
398413
399414
400-
401- df_iris = px.data.iris() # iris is a pandas DataFrame
402- x = df_iris[" sepal_length" ]
403- fig = swarm(x)
404- fig.show()
415+ df = px.data.iris() # iris is a pandas DataFrame
416+ fig = swarm(df[" sepal_length" ])
417+ # here's a more interesting test case for collision avoidance:
418+ # fig = swarm(pd.Series([1, 1.5, 1.78, 1.79, 1.85, 2,
419+ # 2, 2, 2, 3, 3, 2.05, 2.1, 2.2, 2.5, 12]))
420+ fig.show()
405421```
406422
407423## Scatter and line plots with go.Scatter
0 commit comments