"""
===========
Legend Demo
===========

There are many ways to create and customize legends in Matplotlib. Below
we'll show a few examples for how to do so.

First we'll show off how to make a legend for specific lines.
"""
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.collections as mcol
from matplotlib.legend_handler import HandlerLineCollection, HandlerTuple
from matplotlib.lines import Line2D
# Two sample grids over [0, 2): a coarse one (step 0.1) and a fine one
# (step 0.01).
t1 = np.arange(0.0, 2.0, 0.1)
t2 = np.arange(0.0, 2.0, 0.01)

# The exponential envelope and the sine carrier are each drawn on their own
# and also multiplied together, so compute them once.
decay = np.exp(-t2)
wave = np.sin(2 * np.pi * t2)

fig, ax = plt.subplots()

# ``Axes.plot`` hands back a *list* of Line2D objects, one per (x, y[, fmt])
# group.  The trailing comma in ``l1, = ...`` unpacks the single-element list
# so that ``l1`` is the Line2D itself rather than a list containing it.
l1, = ax.plot(t2, decay)
l2, l3 = ax.plot(
    t2, wave, '--o',
    t1, np.log(1 + t1), '.',
)
l4, = ax.plot(t2, decay * wave, 's-.')

# Passing explicit handles restricts the legend to just these two lines;
# l1 and l3 are drawn but get no legend entry.
ax.legend(
    (l2, l4),
    ('oscillatory', 'damped'),
    loc='upper right',
    shadow=True,
)
ax.set(xlabel='time', ylabel='volts', title='Damped oscillation')
plt.show()
# %%
# Next we'll demonstrate plotting more complex labels.
x = np.linspace(0, 1)

fig, (ax0, ax1) = plt.subplots(2, 1)

# Plot the lines y=x**n for n=1..4.  The f-string ``f"{n=}"`` expands to
# "n=1", "n=2", ... so each curve is labelled with its exponent.
for n in range(1, 5):
    ax0.plot(x, x**n, label=f"{n=}")

# Two-column legend with a title, pinned by its upper-left corner to the
# point (0, 1) given in ``bbox_to_anchor``.
leg = ax0.legend(loc="upper left", bbox_to_anchor=[0, 1],
                 ncols=2, shadow=True, title="Legend", fancybox=True)
# The title is an ordinary Text artist, so it can be restyled after creation.
leg.get_title().set_color("red")

# Demonstrate some more complex labels: an embedded newline and mathtext.
ax1.plot(x, x**2, label="multi\nline")
half_pi = np.linspace(0, np.pi / 2)
ax1.plot(np.sin(half_pi), np.cos(half_pi), label=r"$\frac{1}{2}\pi$")
ax1.plot(x, 2**(x**2), label="$2^{x^2}$")
ax1.legend(shadow=True, fancybox=True)

plt.show()
# %%
# Here we attach legends to more complex plots.
fig, axs = plt.subplots(3, 1, layout="constrained")
top_ax, middle_ax, bottom_ax = axs

# Top panel: two interleaved bar series, each contributing one legend entry.
top_ax.bar(
    [0, 1, 2], [0.2, 0.3, 0.1],
    width=0.4, label="Bar 1", align="center",
)
top_ax.bar(
    [0.5, 1.5, 2.5], [0.3, 0.2, 0.2],
    color="red", width=0.4, label="Bar 2", align="center",
)

# Middle panel: three errorbar series sharing the same x positions, with
# x-only, y-only and combined error bars respectively.
errorbar_series = [
    ([2, 3, 1], {"xerr": 0.4, "fmt": "s", "label": "test 1"}),
    ([3, 2, 4], {"yerr": 0.3, "fmt": "o", "label": "test 2"}),
    ([1, 1, 3], {"xerr": 0.4, "yerr": 0.3, "fmt": "^", "label": "test 3"}),
]
for y_values, options in errorbar_series:
    middle_ax.errorbar([0, 1, 2], y_values, **options)

# Bottom panel: a stem plot.
bottom_ax.stem([0.3, 1.5, 2.7], [1, 3.6, 2.7], label="stem test")

# Each panel gets a default legend built from its own labelled artists.
for panel in (top_ax, middle_ax, bottom_ax):
    panel.legend()

plt.show()
# %%
# Now we'll showcase legend entries with more than one legend key.
fig, (ax1, ax2) = plt.subplots(2, 1, layout='constrained')

# --- Upper axes: one legend entry that shows two keys side by side ---------
p1 = ax1.scatter([1], [5], c='r', marker='s', s=100)
p2 = ax1.scatter([3], [2], c='b', marker='o', s=100)
# ``plot`` gives back a list; the trailing comma unpacks its only Line2D.
p3, = ax1.plot([1, 5], [4, 4], 'm-d')

# A tuple of handles becomes a single legend entry.  Registering the handler
# under the ``tuple`` type makes it apply to every tuple entry, so any other
# grouping (e.g. (p1, p3)) would be rendered the same way.
tuple_handler_map = {tuple: HandlerTuple(ndivide=None)}
upper_handles = [(p1, p3), p2]
upper_labels = ['two keys', 'one key']
l = ax1.legend(upper_handles, upper_labels, scatterpoints=1,
               numpoints=1, handler_map=tuple_handler_map)

# --- Lower axes: two overlaid bar charts, varying the key padding ----------
x_left = [1, 2, 3]
y_pos = [1, 3, 2]
y_neg = [2, 1, 4]

rneg = ax2.bar(x_left, y_neg, width=0.5, color='w', hatch='///', label='-1')
rpos = ax2.bar(x_left, y_pos, width=0.5, color='k', label='+1')

# Here the handler map is keyed by the specific handle tuples rather than by
# type, so each entry gets its own ``HandlerTuple`` configuration.
padded_entry = (rpos, rneg)
unpadded_entry = (rneg, rpos)
per_entry_handlers = {
    padded_entry: HandlerTuple(ndivide=None),
    unpadded_entry: HandlerTuple(ndivide=None, pad=0.),
}
l = ax2.legend([padded_entry, unpadded_entry], ['pad!=0', 'pad=0'],
               handler_map=per_entry_handlers)

plt.show()
# %%
# Finally, it is also possible to write custom classes that define
# how to stylize legends.
class HandlerDashedLines(HandlerLineCollection):
    """
    Custom Handler for LineCollection instances.

    Instead of a single legend key, this handler draws one short line per
    segment of the collection, stacked vertically inside the key area, each
    carrying the color, dash pattern and linewidth of its segment.
    """

    def create_artists(self, legend, orig_handle,
                       xdescent, ydescent, width, height, fontsize, trans):
        """
        Build the list of Line2D artists that make up the legend key.

        Parameters
        ----------
        legend
            The legend being drawn; forwarded to ``get_xdata`` and
            ``update_prop``.
        orig_handle
            The collection to represent.  Its ``get_segments``,
            ``get_colors``, ``get_dashes`` and ``get_linewidths`` methods
            are queried.
        xdescent, ydescent, width, height, fontsize
            Key-area geometry; forwarded to ``get_xdata``.  *height* is also
            used to space the lines and *ydescent* to offset them vertically.
        trans
            Transform applied to every created line.

        Returns
        -------
        list of Line2D
            One line per segment of *orig_handle*, first segment on top.
        """
        def pick(values, index):
            # Collections may hold fewer property values than segments;
            # fall back to the first value when *index* is out of range.
            try:
                return values[index]
            except IndexError:
                return values[0]

        # figure out how many lines there are
        numlines = len(orig_handle.get_segments())
        # only the line x-coordinates are needed, not the marker ones
        xdata, _ = self.get_xdata(legend, xdescent, ydescent,
                                  width, height, fontsize)

        # divide the vertical space where the lines will go
        # into equal parts based on the number of lines
        ydata = np.full_like(xdata, height / (numlines + 1))

        # the per-segment properties do not change inside the loop,
        # so query the collection once
        colors = orig_handle.get_colors()
        dash_patterns = orig_handle.get_dashes()
        linewidths = orig_handle.get_linewidths()

        leglines = []
        # for each line, create the line at the proper location
        # and set the dash pattern
        for i in range(numlines):
            # line 0 sits at the top slot, the last line at the bottom slot
            legline = Line2D(xdata, ydata * (numlines - i) - ydescent)
            self.update_prop(legline, orig_handle, legend)

            # set color, dash pattern, and linewidth to that
            # of the lines in linecollection
            dashes = pick(dash_patterns, i)
            # each entry is indexed as (offset, pattern); a pattern of None
            # is skipped, leaving whatever update_prop applied
            if dashes[1] is not None:
                legline.set_dashes(dashes[1])
            legline.set_color(pick(colors, i))
            legline.set_transform(trans)
            legline.set_linewidth(pick(linewidths, i))
            leglines.append(legline)
        return leglines
x = np.linspace(0, 5, 100)

fig, ax = plt.subplots()
# Take the first five colors of the active property cycle and pair each with
# a line style.
colors = plt.rcParams['axes.prop_cycle'].by_key()['color'][:5]
styles = ['solid', 'dashed', 'dashed', 'dashed', 'solid']
# Draw the sine curves, each shifted down by 0.1 relative to the previous one.
for i, (color, style) in enumerate(zip(colors, styles)):
    ax.plot(x, np.sin(x) - .1 * i, c=color, ls=style)

# make proxy artists
# make list of one line -- doesn't matter what the coordinates are
line = [[(0, 0)]]
# set up the proxy artist: one dummy segment per style, so the custom
# handler draws one legend line for each plotted curve
lc = mcol.LineCollection(len(styles) * line, linestyles=styles, colors=colors)
# create the legend, routing this collection type to the custom handler
ax.legend([lc], ['multi-line'], handler_map={type(lc): HandlerDashedLines()},
          handlelength=2.5, handleheight=3)

plt.show()