2121from matplotlib .collections import LineCollection
2222
2323
24- def colored_line (x , y , c , ax , ** lc_kwargs ):
24+ def colored_line (x , y , c , ax = None , scalex = True , scaley = True , ** lc_kwargs ):
2525 """
2626 Plot a line with a color specified along the line by a third value.
2727
@@ -36,9 +36,12 @@ def colored_line(x, y, c, ax, **lc_kwargs):
3636 The horizontal and vertical coordinates of the data points.
3737 c : array-like
3838 The color values, which should be the same size as x and y.
39- ax : Axes
40- Axis object on which to plot the colored line.
41- **lc_kwargs
39+ ax : matplotlib.axes.Axes, optional
40+ The axes to plot on. If not provided, the current axes will be used.
41+ scalex, scaley : bool
42+ These parameters determine if the view limits are adapted to the data limits.
43+ The values are passed on to autoscale_view.
44+ **lc_kwargs : Any
4245 Any additional arguments to pass to matplotlib.collections.LineCollection
4346 constructor. This should not include the array keyword argument because
4447 that is set to the color argument. If provided, it will be overridden.
@@ -49,36 +52,35 @@ def colored_line(x, y, c, ax, **lc_kwargs):
4952 The generated line collection representing the colored line.
5053 """
5154 if "array" in lc_kwargs :
52- warnings .warn ('The provided "array" keyword argument will be overridden' )
55+ warnings .warn (
56+ 'The provided "array" keyword argument will be overridden' ,
57+ UserWarning ,
58+ stacklevel = 2 ,
59+ )
5360
54- # Default the capstyle to butt so that the line segments smoothly line up
55- default_kwargs = {"capstyle" : "butt" }
56- default_kwargs .update (lc_kwargs )
57-
58- # Compute the midpoints of the line segments. Include the first and last points
59- # twice so we don't need any special syntax later to handle them.
60- x = np .asarray (x )
61- y = np .asarray (y )
62- x_midpts = np .hstack ((x [0 ], 0.5 * (x [1 :] + x [:- 1 ]), x [- 1 ]))
63- y_midpts = np .hstack ((y [0 ], 0.5 * (y [1 :] + y [:- 1 ]), y [- 1 ]))
64-
65- # Determine the start, middle, and end coordinate pair of each line segment.
66- # Use the reshape to add an extra dimension so each pair of points is in its
67- # own list. Then concatenate them to create:
68- # [
69- # [(x1_start, y1_start), (x1_mid, y1_mid), (x1_end, y1_end)],
70- # [(x2_start, y2_start), (x2_mid, y2_mid), (x2_end, y2_end)],
61+ xy = np .stack ((x , y ), axis = - 1 )
62+ xy_mid = np .concat (
63+ (xy [0 , :][None , :], (xy [:- 1 , :] + xy [1 :, :]) / 2 , xy [- 1 , :][None , :]), axis = 0
64+ )
65+ segments = np .stack ((xy_mid [:- 1 , :], xy , xy_mid [1 :, :]), axis = - 2 )
66+ # Note that segments is [
67+ # [[x[0], y[0]], [x[0], y[0]], [mean(x[0], x[1]), mean(y[0], y[1])]],
68+ # [[mean(x[0], x[1]), mean(y[0], y[1])], [x[1], y[1]],
69+ # [mean(x[1], x[2]), mean(y[1], y[2])]],
7170 # ...
71+ # [[mean(x[-2], x[-1]), mean(y[-2], y[-1])], [x[-1], y[-1]], [x[-1], y[-1]]]
7272 # ]
73- coord_start = np .column_stack ((x_midpts [:- 1 ], y_midpts [:- 1 ]))[:, np .newaxis , :]
74- coord_mid = np .column_stack ((x , y ))[:, np .newaxis , :]
75- coord_end = np .column_stack ((x_midpts [1 :], y_midpts [1 :]))[:, np .newaxis , :]
76- segments = np .concatenate ((coord_start , coord_mid , coord_end ), axis = 1 )
7773
78- lc = LineCollection (segments , ** default_kwargs )
79- lc .set_array (c ) # set the colors of each segment
74+ lc_kwargs ["array" ] = c
75+ lc = LineCollection (segments , ** lc_kwargs )
76+
77+ # Plot the line collection to the axes
78+ ax = ax or plt .gca ()
79+ ax .add_collection (lc )
80+ ax .autoscale_view (scalex = scalex , scaley = scaley )
8081
81- return ax .add_collection (lc )
82+ # Return the LineCollection object
83+ return lc
8284
8385
8486# -------------- Create and show plot --------------
@@ -93,11 +95,6 @@ def colored_line(x, y, c, ax, **lc_kwargs):
9395lines = colored_line (x , y , color , ax1 , linewidth = 10 , cmap = "plasma" )
9496fig1 .colorbar (lines ) # add a color legend
9597
96- # Set the axis limits and tick positions
97- ax1 .set_xlim (- 1 , 1 )
98- ax1 .set_ylim (- 1 , 1 )
99- ax1 .set_xticks ((- 1 , 0 , 1 ))
100- ax1 .set_yticks ((- 1 , 0 , 1 ))
10198ax1 .set_title ("Color at each point" )
10299
103100plt .show ()
0 commit comments