# | include: false
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
np.random.seed(256852)
Plotting (Good)
What to expect in this chapter
In the previous chapter, we saw how you could use Matplotlib to produce simple, decent-looking plots. However, we have barely tapped the full power of what Matplotlib can do. For this, I need to introduce you to a different way of speaking to Matplotlib. So far, the ‘dialect’ we have used to talk to Matplotlib is called the Matlab-like pyplot (plt) interface. From here onward, I will show you how to use the other, more powerful ‘dialect’ called the Object Oriented (OO) interface. This way of talking to Matplotlib gives us more nuanced control over what is going on by allowing us to manipulate the various axes easily.
1 Some nomenclature
Before going ahead, let’s distinguish between a Matplotlib figure and an axis.
A figure is simple; it is the full canvas you use to draw stuff on. An axis is one individual set of mathematical axes that we use for plotting. So, one figure can have multiple axes, as shown below, where we have a (single) figure with four axes.
By the way, you had already encountered a situation with multiple axes in the last chapter when we used twinx(). It is not uncommon to struggle with the concept of axes, but don’t worry; it will become clearer as we work through this chapter.
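To make the distinction concrete, here is a minimal sketch that asks for one figure holding a 2 x 2 grid of axes and then counts them. The variable names are my own choice, subplots() is explained properly in the next section, and fig.axes is simply the list of axes that Matplotlib keeps for a figure.

```python
import matplotlib.pyplot as plt

# One figure (the canvas) holding a 2 x 2 grid of axes
fig, ax = plt.subplots(nrows=2, ncols=2)

# The figure keeps a list of every axis drawn on it
n_axes = len(fig.axes)
print(n_axes)
```

So there is one canvas, but several places to plot on it.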
2 Comparing the two ‘dialects’
Let me create the same plot using both the pyplot and OO ‘dialects’ so that you can see how the latter works. The plot I will be creating is shown below.
First, let’s generate some data to plot.
x = np.linspace(-np.pi, np.pi, num=100)
cos_x = np.cos(x)
sin_x = np.sin(x)
Now the comparison; remember that both sets of code will produce the same plot.
pyplot Interface
plt.plot(x, cos_x, label='cos x')
plt.plot(x, sin_x, label='sin x')
plt.legend()
OO Interface
fig, ax = plt.subplots(nrows=1, ncols=1)
ax.plot(x, cos_x, label='cos x')
ax.plot(x, sin_x, label='sin x')
ax.legend()
For the OO interface, we have to start by using subplots() to ask Matplotlib to create a figure and an axis. Matplotlib obliges and gives us a figure (fig) and an axis (ax).
Although I have used the variables fig and ax, you are free to call them what you like. But this is what is commonly used in the documentation. In this example, I need only one column and one row. But, if I want, I can ask for a grid like in the plot right at the top.
Yes, the OO version looks more complicated than the pyplot version. But it offers so much freedom that it is worth learning for more demanding, complex plots. You will see this soon.
Remember to use the pyplot interface for quick and dirty plots and the OO interface for more complex plots that demand control and finesse.
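The two dialects are not separate worlds. As a small sketch (the variable names are only illustrative), the pyplot functions act on whichever axis is ‘current’, and plt.gca() hands that axis back to you:

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-np.pi, np.pi, num=100)

fig, ax = plt.subplots(nrows=1, ncols=1)

# pyplot draws on the 'current' axis, which here is the one we just created
plt.plot(x, np.cos(x), label='cos x')

# gca() stands for 'get current axes'; it returns the very same object as ax
same_axis = plt.gca() is ax
n_lines = len(ax.lines)
```

This is why you can start a quick plot with pyplot and still get hold of the axis later if you need finer control.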
3 What is OO ax, really?
The code below creates the (crude) plot shown on the right. Let’s look at what is happening.
fig, ax = plt.subplots(nrows=2, ncols=1)

ax[0].plot(x, cos_x, label='cos x')
ax[1].plot(x, sin_x, label='sin x')

ax[0].legend()
ax[1].legend()
To get the above plot, we must ask for two rows (nrows=2) and one column (ncols=1). We do this by using subplots():
fig, ax = plt.subplots(ncols=1, nrows=2)
This should give me two axes so that I can plot in both panes. Let’s quickly check a few more details about ax.
What is ax?
type(ax)
<class 'numpy.ndarray'>
So ax is a NumPy array!
What size is ax?
ax.shape
(2,)
As expected, ax has two ‘things’.
What is contained in ax?
type(ax[0])
<class 'matplotlib.axes._axes.Axes'>
This is a Matplotlib axis.
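One thing worth checking for yourself: by default, subplots() ‘squeezes’ its result, so asking for a single axis gives you the axis itself rather than an array. The sketch below uses the squeeze keyword of subplots() to show the difference; the variable names are my own.

```python
import matplotlib.pyplot as plt
import numpy as np

# Default: a 1 x 1 request is squeezed down to a single axis
fig_a, ax_a = plt.subplots(nrows=1, ncols=1)
is_array_a = isinstance(ax_a, np.ndarray)

# squeeze=False always returns a 2D array, even for one axis
fig_b, ax_b = plt.subplots(nrows=1, ncols=1, squeeze=False)
shape_b = ax_b.shape
```

This explains why the earlier single-plot example used ax.plot() directly, while the two-row example needs ax[0] and ax[1].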
4 A complete OO example
The following is a simple example that creates a nicer, tweaked version of the previous plot.
fig, ax = plt.subplots(nrows=2, ncols=1,
                       figsize=(5, 5),
                       sharex=True)
ax[0].plot(x, cos_x, label=r'$\cos(x)$')
ax[0].fill_between(x, 0, cos_x, alpha=.25)
ax[1].plot(x, sin_x, label=r'$\sin(x)$')
ax[1].fill_between(x, 0, sin_x, alpha=.25)

for a in ax:
    a.legend()
    a.grid(alpha=.25)
    a.set_ylabel('$y$', rotation=0)

ax[1].set_xlabel('$x$')

fig.suptitle(r'$\sin(x)$ and $\cos(x)$')
fig.tight_layout()
Now, let me take you through this code and explain what is happening.
fig, ax = plt.subplots(nrows=2, ncols=1,
                       figsize=(5, 5),
                       sharex=True)

ax[0].plot(x, cos_x, label=r'$\cos(x)$')
ax[0].fill_between(x, 0, cos_x, alpha=.25)
ax[1].plot(x, sin_x, label=r'$\sin(x)$')
ax[1].fill_between(x, 0, sin_x, alpha=.25)

for a in ax:
    a.legend()
    a.grid(alpha=.25)
    a.set_ylabel('$y$', rotation=0)

ax[1].set_xlabel('$x$')

fig.suptitle(r'$\sin(x)$ and $\cos(x)$')
fig.tight_layout()
- Set up the axes and figure
  - Get two axes ready to plot in two rows in one column.
  - Change the size of the figure by specifying a figure size (figsize).
  - Ask that the plots share the \(x\)-axis using sharex.
- Since ax is a NumPy array with two axes, we can index them using 0 and 1. Then we just plot to that axis.
- Use fill_between() to fill the plots.
  - I have again just indexed the NumPy array to access the axes.
- Draw the legends
  - As in the previous example, we can do this one axis at a time. However, a more sensible way to do this is with a for loop that iterates through the items in ax.
  - Let’s also add a grid to each plot and set the label.
- We are sharing the \(x\)-axis. So, we only need to label the lowest plot. This has the index 1.
- Let’s add a super title to the figure (not plot).
- Finally, let’s ask Matplotlib to make any necessary adjustments to the layout to make our plot look nice by calling tight_layout(). It would help if you convinced yourself of the utility of tight_layout() by producing the plot with and without it.
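If you are curious what sharex actually does, here is a minimal sketch (a toy example of my own, not part of the plot above): changing the \(x\) limits on one axis changes them on the other too.

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots(nrows=2, ncols=1, sharex=True)

# Change the x limits of the top axis only...
ax[0].set_xlim(0, 1)

# ...and the bottom axis follows because the x-axis is shared
bottom_limits = tuple(float(v) for v in ax[1].get_xlim())
print(bottom_limits)
```

Without sharex=True, the bottom axis would keep its own default limits.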
Unfortunately, the pyplot and OO interfaces don’t use identical function names. For example, pyplot uses xlabel() to set the \(x\) label, but OO uses set_xlabel(). Annoying!
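One way to soften this, sketched below, is the set() method of an axis, which accepts several of these properties as keyword arguments in a single call. The label and title text here are placeholders.

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# Equivalent to calling set_xlabel(), set_ylabel() and set_title() in turn
ax.set(xlabel='$x$', ylabel='$y$', title='A placeholder title')

labels = (ax.get_xlabel(), ax.get_ylabel(), ax.get_title())
print(labels)
```

Each keyword maps to the matching set_...() method, so you only have to remember the OO names once.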
5 Multiple rows and columns
Now, I will show you how to work with multiple rows and columns that form a grid of plots like the one on the left. However, let’s start by using the figure on the right.
I have intentionally kept it simple (only setting the title and plotting some lines) so that we can identify how ax works.
fig, ax = plt.subplots(nrows=2, ncols=2,
                       figsize=(5, 5),
                       sharex='col', sharey='col')

# Some variables to access the axes and improve readability
top_left, top_right, bottom_left, bottom_right = ax.flatten()

top_left.set_title('Top Left')
top_right.set_title('Top Right')
bottom_left.set_title('Bottom Left')
bottom_right.set_title('Bottom Right')

top_left.hlines(y=1, xmin=0, xmax=4)
top_right.hlines(y=2, xmin=0, xmax=5)
bottom_left.hlines(y=3, xmin=0, xmax=4)
bottom_right.hlines(y=4, xmin=0, xmax=5)

for a in ax.flatten():
    a.grid(alpha=.25)

plt.tight_layout()
5.1 Using ax
I create a figure and axes using:
fig, ax = plt.subplots(nrows=2, ncols=2,
                       figsize=(5, 5),
                       sharex='col', sharey='row')
The most important thing you must understand is how to use ax.
We know there must be four axes, but how is ax structured? Let’s look at its shape.
ax.shape
(2, 2)
So, ax is organised as you see in the figure, as a 2 x 2 array. So, I can access each of the axes as follows:
ax[0, 0].set_title('Top Left')
ax[0, 1].set_title('Top Right')
ax[1, 0].set_title('Bottom Left')
ax[1, 1].set_title('Bottom Right')
This is a perfectly valid way to use ax. However, when you have to tweak each axis separately, I find it easier to use a familiar variable. I can do this by:
top_left = ax[0, 0]
top_right = ax[0, 1]
bottom_left = ax[1, 0]
bottom_right = ax[1, 1]
You can also use:
top_left, top_right, bottom_left, bottom_right = ax.flatten()
flatten() takes the 2D array and ‘flattens’ (duh) it into a 1D array; unpacking takes care of the assignments.
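The order in which flatten() hands back the items matters for the unpacking. Here is a minimal sketch that uses a 2 x 2 array of strings as a stand-in for the array of axes (the strings and single-letter names are purely illustrative):

```python
import numpy as np

# A stand-in for the 2 x 2 array of axes, using strings instead
grid = np.array([['top left', 'top right'],
                 ['bottom left', 'bottom right']])

# flatten() reads the array row by row
flat = grid.flatten().tolist()

# Unpacking assigns the four items to four names, in that order
a, b, c, d = flat
print(flat)
```

Because the array is read row by row, the names must be listed as top row first, then bottom row.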
Oh, I forgot to show you how I drew the lines. I used the very useful Matplotlib function hlines():
top_left.hlines(y=1, xmin=0, xmax=5)
top_right.hlines(y=2, xmin=0, xmax=5)
bottom_left.hlines(y=3, xmin=0, xmax=5)
bottom_right.hlines(y=4, xmin=0, xmax=5)
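hlines() also accepts a list of \(y\) values, so several lines can be drawn on one axis in a single call. A minimal sketch with made-up values:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# Three horizontal lines at y = 1, 2 and 3, each running from x = 0 to x = 5
lines = ax.hlines(y=[1, 2, 3], xmin=0, xmax=5)

# hlines() returns a single collection holding one segment per line
n_segments = len(lines.get_segments())
print(n_segments)
```

There is a matching vlines() for vertical lines, which takes x, ymin and ymax instead.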
5.2 Accessing all axes
You will often want to apply changes to all the axes, like in the case of the grid. You can do this by:
top_left.grid(alpha=.25)
top_right.grid(alpha=.25)
bottom_left.grid(alpha=.25)
bottom_right.grid(alpha=.25)
But this is inefficient and requires a lot of work. It is much nicer to use a for loop.
for a in ax.flatten():
    a.grid(alpha=.25)
6 Other useful plots
In this section, I will quickly show you some useful plots we can generate with Matplotlib. I will also use a few different plotting styles I commonly use so that you can get a feel for how to change styles.
6.1 Histograms
A histogram is a valuable tool for showing distributions of data. For this example, I have extracted some actual data from sg.gov related to the mean monthly earnings of graduates from the various universities in Singapore.
Data
Here are the links to my data files:
| Mean basic monthly earnings by graduates | File |
|---|---|
| All | sg-gov-graduate-employment-survey_basic_monthly_mean_all.csv |
| NUS Only | sg-gov-graduate-employment-survey_basic_monthly_mean_nus.csv |
A quick helper function
I will need to read the data from these files several times. So, I will create a function called get_plot_data() that I can call. You must examine the file structure to understand the data and why I am skipping the first line.
def get_plot_data():
    data = {}
    filename = 'sg-gov-graduate-employment-survey_basic_monthly_mean_all.csv'
    data['All'] = np.loadtxt(filename, skiprows=1)

    filename = 'sg-gov-graduate-employment-survey_basic_monthly_mean_nus.csv'
    data['NUS'] = np.loadtxt(filename, skiprows=1)

    return data
sg-gov-graduate-employment-survey_basic_monthly_mean_all.csv
basic_monthly_mean
3701
2850
3053
3557
3494
2952
3235
3326
3091
sg-gov-graduate-employment-survey_basic_monthly_mean_nus.csv
basic_monthly_mean
2741
3057
3098
2960
3404
2740
3065
3350
3933
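To see why the first line has to be skipped, here is a minimal sketch that reads from an in-memory stand-in for one of the files instead of the real CSV. It holds only the header and the first three values of the preview above; the name fake_file is my own.

```python
import io
import numpy as np

# A small in-memory stand-in for one of the CSV files:
# a header line followed by one number per line
fake_file = io.StringIO("basic_monthly_mean\n3701\n2850\n3053\n")

# skiprows=1 skips the header, which is text and cannot be read as a number
values = np.loadtxt(fake_file, skiprows=1)
print(values)
```

np.loadtxt() converts everything to floats by default, which is why a text header would trip it up.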
The histogram
plt.style.use('bmh')
data = get_plot_data()

# bins specifies how many bins to split the data
plt.hist([data['All'], data['NUS']], bins=50, label=['All', 'NUS'])
plt.xlabel('Mean of Basic Monthly Earnings (S$)')
plt.ylabel('Number of Students')
plt.legend()
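hist() also returns what it computed, which is handy when you want the numbers and not just the picture. In the sketch below, I use only the handful of values shown in the ‘All’ file preview (not the full dataset) and five bins to keep things readable.

```python
import matplotlib.pyplot as plt

# The first few values from the 'All' file preview, not the full dataset
values = [3701, 2850, 3053, 3557, 3494, 2952, 3235, 3326, 3091]

fig, ax = plt.subplots()

# counts: how many values fell in each bin; edges: the bin boundaries
counts, edges, patches = ax.hist(values, bins=5)

total = int(counts.sum())
n_edges = len(edges)
print(total, n_edges)
```

Note that five bins need six edges, and the counts always add up to the number of values you passed in.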
6.2 Scatter plots
Scatter plots are created by putting a marker at an \((x,y)\) point you specify. They are simple yet powerful.
I will be lazy and use the same data as the previous example. But, since I need some values for \(x\), I am going to use range() along with len() to generate a list [0,1,2...] appropriate to the dataset.
plt.style.use("seaborn-v0_8-darkgrid")

data = get_plot_data()

for label, numbers in data.items():
    x = range(len(numbers))
    y = numbers
    plt.scatter(x, y, label=label, alpha=.5)

plt.xlabel('Position in the list')
plt.ylabel('Mean of Basic Monthly Earnings (S$)')
plt.legend()
6.3 Bar charts
I am using some dummy data for a hypothetical class for this example. I extract the data and typecast to pass two lists to bar(). Use barh() if you want horizontal bars.
student_numbers = {'Life Sciences': 14,
                   'Physics': 12,
                   'Chemistry': 8,
                   'Comp. Biology': 1}
majors = list(student_numbers.keys())
numbers = list(student_numbers.values())

plt.style.use('ggplot')
plt.bar(majors, numbers)
plt.xlabel('Majors')
plt.ylabel('Number of Students')
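For completeness, here is a sketch of the horizontal version using barh() with the same dummy data, written in the OO dialect. The only real change is that the two axis labels swap places.

```python
import matplotlib.pyplot as plt

student_numbers = {'Life Sciences': 14,
                   'Physics': 12,
                   'Chemistry': 8,
                   'Comp. Biology': 1}
majors = list(student_numbers.keys())
numbers = list(student_numbers.values())

fig, ax = plt.subplots()

# Same arguments as bar(); the values now set the bar widths
bars = ax.barh(majors, numbers)
ax.set_xlabel('Number of Students')
ax.set_ylabel('Majors')

widths = [bar.get_width() for bar in bars]
print(widths)
```

Horizontal bars tend to work better when the category names are long, since the labels no longer overlap.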
6.4 Pie charts
I am not a big fan of pie charts, but they have their uses. Let me reuse the previous data from the dummy class.
student_numbers = {'Life Sciences': 14,
                   'Physics': 12,
                   'Chemistry': 8,
                   'Comp. Biology': 1}
majors = list(student_numbers.keys())
numbers = list(student_numbers.values())

plt.style.use('fivethirtyeight')
plt.pie(numbers,
        labels=majors,
        autopct='%1.1f%%',   # How to format the percentages
        startangle=-90
        )
plt.title('Percentage of each major')
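The autopct format string can look cryptic. The sketch below applies the same format to the dummy class by hand, in plain Python and without Matplotlib, so it only illustrates the formatting and is not a copy of what pie() does internally.

```python
numbers = [14, 12, 8, 1]
total = sum(numbers)

# Each slice's share of the whole, formatted the same way as autopct:
# '%1.1f' gives one decimal place and '%%' prints a literal percent sign
labels = ['%1.1f%%' % (100 * n / total) for n in numbers]
print(labels)
```

Changing the format to '%1.0f%%' would round each share to a whole number instead.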