2121from matplotlib .collections import LineCollection
2222
2323
24- def colored_line (x , y , c , ax , ** lc_kwargs ):
24+ def colored_line (x , y , c , ax = None , ** lc_kwargs ):
2525 """
2626 Plot a line with a color specified along the line by a third value.
2727
@@ -36,8 +36,8 @@ def colored_line(x, y, c, ax, **lc_kwargs):
3636 The horizontal and vertical coordinates of the data points.
3737 c : array-like
3838 The color values, which should be the same size as x and y.
39- ax : Axes
40- Axis object on which to plot the colored line .
39+ ax : matplotlib.axes. Axes, optional
40+ The axes to plot on. If not provided, the current axes will be used .
4141 **lc_kwargs
4242 Any additional arguments to pass to matplotlib.collections.LineCollection
4343 constructor. This should not include the array keyword argument because
@@ -49,36 +49,32 @@ def colored_line(x, y, c, ax, **lc_kwargs):
4949 The generated line collection representing the colored line.
5050 """
5151 if "array" in lc_kwargs :
52- warnings .warn ('The provided "array" keyword argument will be overridden' )
52+ warnings .warn (
53+ 'The provided "array" keyword argument will be overridden' ,
54+ UserWarning ,
55+ stacklevel = 2 ,
56+ )
57+
58+ xy = np .stack ((x , y ), axis = - 1 )
59+ xy_mid = np .concat (
60+ (xy [0 , :][None , :], (xy [:- 1 , :] + xy [1 :, :]) / 2 , xy [- 1 , :][None , :]), axis = 0
61+ )
62+ segments = np .stack ((xy_mid [:- 1 , :], xy , xy_mid [1 :, :]), axis = - 2 )
63+ # Note that
64+ # segments[0, :, :] is [xy[0, :], xy[0, :], (xy[0, :] + xy[1, :]) / 2]
65+ # segments[i, :, :] is [(xy[i - 1, :] + xy[i, :]) / 2, xy[i, :],
66+ # (xy[i, :] + xy[i + 1, :]) / 2] if i not in {0, len(x) - 1}
67+ # segments[-1, :, :] is [(xy[-2, :] + xy[-1, :]) / 2, xy[-1, :], xy[-1, :]]
68+
69+ lc_kwargs ["array" ] = c
70+ lc = LineCollection (segments , ** lc_kwargs )
5371
54- # Default the capstyle to butt so that the line segments smoothly line up
55- default_kwargs = {"capstyle" : "butt" }
56- default_kwargs .update (lc_kwargs )
57-
58- # Compute the midpoints of the line segments. Include the first and last points
59- # twice so we don't need any special syntax later to handle them.
60- x = np .asarray (x )
61- y = np .asarray (y )
62- x_midpts = np .hstack ((x [0 ], 0.5 * (x [1 :] + x [:- 1 ]), x [- 1 ]))
63- y_midpts = np .hstack ((y [0 ], 0.5 * (y [1 :] + y [:- 1 ]), y [- 1 ]))
64-
65- # Determine the start, middle, and end coordinate pair of each line segment.
66- # Use the reshape to add an extra dimension so each pair of points is in its
67- # own list. Then concatenate them to create:
68- # [
69- # [(x1_start, y1_start), (x1_mid, y1_mid), (x1_end, y1_end)],
70- # [(x2_start, y2_start), (x2_mid, y2_mid), (x2_end, y2_end)],
71- # ...
72- # ]
73- coord_start = np .column_stack ((x_midpts [:- 1 ], y_midpts [:- 1 ]))[:, np .newaxis , :]
74- coord_mid = np .column_stack ((x , y ))[:, np .newaxis , :]
75- coord_end = np .column_stack ((x_midpts [1 :], y_midpts [1 :]))[:, np .newaxis , :]
76- segments = np .concatenate ((coord_start , coord_mid , coord_end ), axis = 1 )
77-
78- lc = LineCollection (segments , ** default_kwargs )
79- lc .set_array (c ) # set the colors of each segment
72+ # Plot the line collection to the axes
73+ ax = ax or plt .gca ()
74+ ax .add_collection (lc )
75+ ax .autoscale_view ()
8076
81- return ax . add_collection ( lc )
77+ return lc
8278
8379
8480# -------------- Create and show plot --------------
@@ -93,11 +89,6 @@ def colored_line(x, y, c, ax, **lc_kwargs):
9389lines = colored_line (x , y , color , ax1 , linewidth = 10 , cmap = "plasma" )
9490fig1 .colorbar (lines ) # add a color legend
9591
96- # Set the axis limits and tick positions
97- ax1 .set_xlim (- 1 , 1 )
98- ax1 .set_ylim (- 1 , 1 )
99- ax1 .set_xticks ((- 1 , 0 , 1 ))
100- ax1 .set_yticks ((- 1 , 0 , 1 ))
10192ax1 .set_title ("Color at each point" )
10293
10394plt .show ()
0 commit comments