Plotting (Good)

The various parts of a Matplotlib figure. (From matplotlib.org)

What to expect in this chapter

In the previous chapter, we saw how you could use Matplotlib to produce simple, decent-looking plots. However, we have barely tapped the full power of what Matplotlib can do. For this, I need to introduce you to a different way of speaking to Matplotlib. So far, the ‘dialect’ we have used to talk to Matplotlib is called the MATLAB-like pyplot (plt) interface. From here onward, I will show you how to use the other, more powerful ‘dialect’ called the Object Oriented (OO) interface. This way of talking to Matplotlib gives us more nuanced control over what is going on by allowing us to manipulate the various axes easily.

1 Some nomenclature

Before going ahead, let’s distinguish between a Matplotlib figure and an axis.

A figure is simple; it is the full canvas you use to draw stuff on. An axis is an individual set of mathematical axes we use for plotting. So, one figure can have multiple axes, as shown below, where we have a (single) figure with four axes.

By the way, you already encountered a situation with multiple axes in the last chapter when we used twinx(). It is not uncommon to struggle with the concept of axes, but don’t worry; it will become clearer as we work through this chapter.

2 Comparing the two ‘dialects’

Let me create the same plot using both pyplot and OO ‘dialects’ so that you can see how the latter works. The plot I will be creating is shown below.

First, let’s generate some data to plot.

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-np.pi, np.pi, num=100)
cos_x = np.cos(x)
sin_x = np.sin(x)

Now the comparison; remember that both sets of code will produce the same plot.

pyplot Interface

plt.plot(x, cos_x, label='cos x')
plt.plot(x, sin_x, label='sin x')
plt.legend()

OO Interface

fig, ax = plt.subplots(nrows=1, ncols=1)
ax.plot(x, cos_x, label='cos x')
ax.plot(x, sin_x, label='sin x')
ax.legend()

For the OO interface, we have to start by using subplots() to ask Matplotlib to create a figure and an axis. Matplotlib obliges and gives us a figure (fig) and an axis (ax).

Although I have used the variable names fig and ax, you are free to call them whatever you like; these are just the names commonly used in the documentation. In this example, I need only one row and one column. But, if I want, I can ask for a grid like the one in the plot right at the top.

Yes, the OO version looks more complicated than the pyplot version. But it offers so much freedom that it is worth learning for more demanding, complex plots. You will see this soon.

Remember

Remember to use the pyplot interface for quick and dirty plots and the OO interface for more complex plots that demand control and finesse.

3 What is OO ax, really?

The code below creates the (crude) plot shown on the right. Let’s go through it step by step to understand what is going on.

fig, ax = plt.subplots(nrows=2, ncols=1)

ax[0].plot(x, cos_x, label='cos x')
ax[1].plot(x, sin_x, label='sin x')

ax[0].legend()
ax[1].legend()

To get the above plot, we must ask for two rows (nrows=2) and one column (ncols=1). We do this by using subplots():

fig, ax = plt.subplots(ncols=1, nrows=2)

This should give me two axes so that I can plot in both panes. Let’s quickly check a few more details about ax.

  • What is ax?

    type(ax)
    <class 'numpy.ndarray'>

    So ax is a NumPy array!

  • What size is ax?

    ax.shape
    (2,)

    As expected, ax has two ‘things’.

  • What is contained in ax?

    type(ax[0])
    <class 'matplotlib.axes._axes.Axes'>

    This is a Matplotlib axis.
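The type and shape of ax depend on the grid you ask for. The sketch below checks this for a few grids (the names ax1 to ax4 are just for illustration). squeeze is a parameter of subplots() that defaults to True, which is why a single plot comes back as a lone axis rather than an array:

```python
import matplotlib.pyplot as plt
from matplotlib.axes import Axes

# A 1 x 1 grid: subplots() returns a single Axes, not an array
fig1, ax1 = plt.subplots(nrows=1, ncols=1)
print(isinstance(ax1, Axes))   # True

# A single column: a 1D NumPy array of Axes
fig2, ax2 = plt.subplots(nrows=2, ncols=1)
print(ax2.shape)               # (2,)

# A full grid: a 2D NumPy array of Axes
fig3, ax3 = plt.subplots(nrows=2, ncols=2)
print(ax3.shape)               # (2, 2)

# squeeze=False always gives a 2D array, even for a single plot
fig4, ax4 = plt.subplots(nrows=1, ncols=1, squeeze=False)
print(ax4.shape)               # (1, 1)
```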

4 A complete OO example

The following is a simple example that creates a nicer, tweaked version of the previous plot.

fig, ax = plt.subplots(nrows=2, ncols=1,
                       figsize=(5, 5),
                       sharex=True)
ax[0].plot(x, cos_x, label=r'$\cos(x)$')
ax[0].fill_between(x, 0, cos_x, alpha=.25)
ax[1].plot(x, sin_x, label=r'$\sin(x)$')
ax[1].fill_between(x, 0, sin_x, alpha=.25)

for a in ax:
    a.legend()
    a.grid(alpha=.25)
    a.set_ylabel('$y$', rotation=0)

ax[1].set_xlabel('$x$')

fig.suptitle(r'$\sin(x)$ and $\cos(x)$')
fig.tight_layout()

Now, let me take you through this code and explain what is happening.

  • Set up the axes and figure
    • Get two axes ready to plot in two rows in one column.
    • Change the size of the figure by specifying a figure size (figsize).
    • Ask that the plots share the \(x\)-axis using sharex.
  • Since ax is a NumPy array with two axes, we can index them using 0 and 1. Then we just plot to that axis.
  • Use fill_between() to shade the area between each curve and \(y = 0\).
    • I have again just indexed the NumPy array to access the axes.
  • Draw the legends
    • As in the previous example, we can do this one axis at a time. However, a more sensible way to do this is with a for loop that iterates through the items in ax.
    • Let’s also add a grid to each plot and set the \(y\) label.
  • We are sharing the \(x\)-axis. So, we only need to label the lowest plot. This has the index 1.
  • Let’s add a super title to the whole figure (not to an individual plot) using suptitle().
  • Finally, let’s ask Matplotlib to make any necessary adjustments to the layout to make our plot look nice by calling tight_layout(). Convince yourself of its usefulness by producing the plot with and without it.
Notice

Unfortunately, the pyplot and OO interfaces don’t use identical function names. For example, pyplot uses xlabel() to set the \(x\) label, but OO uses set_xlabel(). Annoying!
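To make the naming difference concrete, here is a sketch of the same plot in both dialects. The pairs shown (xlabel()/set_xlabel(), ylabel()/set_ylabel(), title()/set_title(), xlim()/set_xlim()) are standard Matplotlib calls, but this is only a small sample, not a complete list:

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-np.pi, np.pi, num=100)

# pyplot 'dialect': functions act on the "current" figure and axis
fig_pyplot = plt.figure()
plt.plot(x, np.cos(x))
plt.xlabel('$x$')
plt.ylabel('$y$')
plt.title('pyplot')
plt.xlim(-np.pi, np.pi)

# OO 'dialect': methods called on an explicit axis, mostly with a set_ prefix
fig, ax = plt.subplots()
ax.plot(x, np.cos(x))
ax.set_xlabel('$x$')
ax.set_ylabel('$y$')
ax.set_title('OO')
ax.set_xlim(-np.pi, np.pi)
```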

5 Multiple rows and columns

Now, I will show you how to work with multiple rows and columns that form a grid of plots like the one on the left. However, let’s start with the simpler figure on the right.

I have intentionally kept it simple (only setting the title and plotting some lines) so that we can identify how ax works.

fig, ax = plt.subplots(nrows=2, ncols=2,
                       figsize=(5, 5),
                       sharex='col', sharey='row')

# Some variables to access the axes and improve readability
top_left, top_right, bottom_left, bottom_right = ax.flatten()

top_left.set_title('Top Left')
top_right.set_title('Top Right')
bottom_left.set_title('Bottom Left')
bottom_right.set_title('Bottom Right')

top_left.hlines(y=1, xmin=0, xmax=4)
top_right.hlines(y=2, xmin=0, xmax=5)
bottom_left.hlines(y=3, xmin=0, xmax=4)
bottom_right.hlines(y=4, xmin=0, xmax=5)

for a in ax.flatten():
    a.grid(alpha=.25)

plt.tight_layout()

Using ax

I create a figure and axes using:

fig, ax = plt.subplots(nrows=2, ncols=2,
                       figsize=(5, 5),
                       sharex='col', sharey='row')

The most important thing you must understand is how to use ax.

We know there must be four axes, but how is ax structured? Let’s look at its shape.

ax.shape
(2, 2)

So, ax is organised as a 2 x 2 array, just like the panels in the figure. This means I can access each of the axes as follows:

ax[0, 0].set_title('Top Left')
ax[0, 1].set_title('Top Right')
ax[1, 0].set_title('Bottom Left')
ax[1, 1].set_title('Bottom Right')

This is a perfectly valid way to use ax. However, when I have to tweak each axis separately, I find it easier to use descriptive variable names. I can do this by:

top_left = ax[0, 0]
top_right = ax[0, 1]
bottom_left = ax[1, 0]
bottom_right = ax[1, 1]

You can also use:

top_left, top_right, bottom_left, bottom_right = ax.flatten()

flatten() takes the 2D array and ‘flattens’ it (duh!) into a 1D array; unpacking then takes care of the assignments.
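The unpacking gives the right names because flatten() reads the 2 x 2 array row by row (NumPy's default order), so the order is top left, top right, bottom left, bottom right. A quick check:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots(nrows=2, ncols=2)
flat = ax.flatten()

# flatten() walks the grid row by row; the elements are the same Axes objects
print(flat[0] is ax[0, 0])   # True  (top left)
print(flat[1] is ax[0, 1])   # True  (top right)
print(flat[2] is ax[1, 0])   # True  (bottom left)
print(flat[3] is ax[1, 1])   # True  (bottom right)
```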

Oh, I forgot to show you how I drew the lines. I used the very useful Matplotlib function hlines():

top_left.hlines(y=1, xmin=0, xmax=4)
top_right.hlines(y=2, xmin=0, xmax=5)
bottom_left.hlines(y=3, xmin=0, xmax=4)
bottom_right.hlines(y=4, xmin=0, xmax=5)
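hlines() draws horizontal lines at the given \(y\) value(s) from xmin to xmax (its sibling vlines() does the same vertically). As a small aside, it also accepts a list of \(y\) values, which draws several lines in one call. A sketch:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# One call, three horizontal lines at y = 1, 2 and 3, each from x = 0 to x = 4
lines = ax.hlines(y=[1, 2, 3], xmin=0, xmax=4,
                  colors='grey', linestyles='dashed')

# The return value is a single LineCollection holding all three segments
print(len(lines.get_segments()))   # 3
```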

5.1 Sharing axes

Notice I have asked Matplotlib to make the plots more compact by sharing the \(x\) and \(y\) axes using sharex and sharey.

Let’s first see what happens if I do not specify how to share.

fig, ax = plt.subplots(
    nrows=2, ncols=2,
    figsize=(5, 5)
)

You see that Matplotlib has auto-scaled both axes. In particular, the plots on the left go from 0 to 4, and those on the right go from 0 to 5, as these are the lengths I used for the lines.

Now, let me specify how to share the axes. I can do this in three ways:

Option   Result
True     All the axes use the same range.
'col'    Axes in the same column use the same range.
'row'    Axes in the same row use the same range.

Let’s try the following:

fig, ax = plt.subplots(
    nrows=2, ncols=2,
    figsize=(5, 5),
    sharex=True, sharey='row'
)

Notice how all the plots have the same range for the \(x\)-axis.

However, sharex='col' is more suited for the data we are plotting, so let’s use that instead.

fig, ax = plt.subplots(
    nrows=2, ncols=2,
    figsize=(5, 5),
    sharex='col', sharey='row'
)

By the way, how you decide on the most appropriate option depends on the story you are trying to tell with your plot.
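If you want to confirm what 'col' and 'row' actually did, you can compare the limits Matplotlib chose for each axis. A minimal sketch, reusing the four hlines() calls from earlier; get_xlim() and get_ylim() return the current (lower, upper) limits of an axis:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots(nrows=2, ncols=2, sharex='col', sharey='row')
top_left, top_right, bottom_left, bottom_right = ax.flatten()

# Left column lines end at x = 4, right column lines end at x = 5
top_left.hlines(y=1, xmin=0, xmax=4)
top_right.hlines(y=2, xmin=0, xmax=5)
bottom_left.hlines(y=3, xmin=0, xmax=4)
bottom_right.hlines(y=4, xmin=0, xmax=5)

# sharex='col': axes in the same column have identical x-limits
print(top_left.get_xlim() == bottom_left.get_xlim())   # True
print(top_left.get_xlim() == top_right.get_xlim())     # False

# sharey='row': axes in the same row have identical y-limits
print(top_left.get_ylim() == top_right.get_ylim())     # True
print(top_left.get_ylim() == bottom_left.get_ylim())   # False
```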

5.2 Accessing all axes

You will often want to apply changes to all the axes, as in the case of the grid. You could do this one axis at a time:

top_left.grid(alpha=.25)
top_right.grid(alpha=.25)
bottom_left.grid(alpha=.25)
bottom_right.grid(alpha=.25)

But this is repetitive and quickly becomes tedious as the number of axes grows. It is much nicer to use a for loop:

for a in ax.flatten():
    a.grid(alpha=.25)
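Another option, assuming you only need to loop over the axes, is the array's flat attribute. It is a NumPy iterator over the elements, so unlike flatten() it does not create a new copy of the array first. A small sketch:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots(nrows=2, ncols=2)

# ax.flat iterates over all four axes without building a flattened copy
for a in ax.flat:
    a.grid(alpha=.25)

# Every axis now has visible grid lines
print(all(line.get_visible() for a in ax.flat for line in a.get_xgridlines()))   # True
```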

6 Other useful plots

In this section, I will quickly show you some useful plots we can generate with Matplotlib. I will also switch between a few plotting styles I commonly use so that you can get a feel for how to change styles.

6.1 Histograms

A histogram is a valuable tool for showing distributions of data. For this example, I have extracted some actual data from sg.gov related to the mean monthly earnings of graduates from the various universities in Singapore.

Data

Here are the links to my data files:

A quick helper function

I will need to read the data from these files several times. So, I will create a function called get_plot_data() that I can call. Examine the file structure (shown below) to understand the data and why I am skipping the first line.

def get_plot_data():
    data = {}
    filename = 'sg-gov-graduate-employment-survey_basic_monthly_mean_all.csv'
    data['All'] = np.loadtxt(filename, skiprows=1)

    filename = 'sg-gov-graduate-employment-survey_basic_monthly_mean_nus.csv'
    data['NUS'] = np.loadtxt(filename, skiprows=1)

    return data

sg-gov-graduate-employment-survey_basic_monthly_mean_all.csv

basic_monthly_mean
3701
2850
3053
3557
3494
2952
3235
3326
3091

sg-gov-graduate-employment-survey_basic_monthly_mean_nus.csv

basic_monthly_mean
2741
3057
3098
2960
3404
2740
3065
3350
3933
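To see what skiprows=1 is doing, here is a small sketch that feeds np.loadtxt() an in-memory stand-in for one of the files (io.StringIO plays the role of the file, and only the first three values from the ‘All’ listing are included):

```python
import io
import numpy as np

# A stand-in for the CSV file: a header line followed by one number per line
fake_file = io.StringIO('basic_monthly_mean\n3701\n2850\n3053\n')

# skiprows=1 skips the header; without it, loadtxt would raise a ValueError
# when trying to convert the text 'basic_monthly_mean' into a float
values = np.loadtxt(fake_file, skiprows=1)
print(values)         # [3701. 2850. 3053.]
print(values.dtype)   # float64
```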

The histogram

plt.style.use('bmh')
data = get_plot_data()

# bins specifies how many bins to split the data into
plt.hist([data['All'], data['NUS']], bins=50, label=['All', 'NUS'])
plt.xlabel('Mean of Basic Monthly Earning (S$)')
plt.ylabel('Number of Students')
plt.legend()
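For comparison, here is roughly the same histogram written in the OO dialect. This sketch uses only the nine values listed above for each file as stand-in data (so the result will look different from the full plot) and fewer bins; it also shows what hist() hands back:

```python
import numpy as np
import matplotlib.pyplot as plt

# Stand-in data: only the values listed above, not the full files
data = {
    'All': np.array([3701, 2850, 3053, 3557, 3494, 2952, 3235, 3326, 3091]),
    'NUS': np.array([2741, 3057, 3098, 2960, 3404, 2740, 3065, 3350, 3933]),
}

fig, ax = plt.subplots()

# With a list of datasets, hist() returns one row of counts per dataset,
# the shared bin edges (one more edge than bins) and the drawn bars
counts, edges, patches = ax.hist([data['All'], data['NUS']], bins=5,
                                 label=['All', 'NUS'])
ax.set_xlabel('Mean of Basic Monthly Earning (S$)')
ax.set_ylabel('Number of Students')
ax.legend()

print(counts.shape)   # (2, 5)
print(len(edges))     # 6
```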

6.2 Scatter plots

Scatter plots are created by putting a marker at an \((x,y)\) point you specify. They are simple yet powerful.

I will be lazy and use the same data as in the previous example. But, since I need some values for \(x\), I am going to use range() along with len() to generate the sequence 0, 1, 2, ... matching the length of each dataset.

plt.style.use("seaborn-v0_8-darkgrid")

data = get_plot_data()

for label, numbers in data.items():
    x = range(len(numbers))
    y = numbers
    plt.scatter(x, y, label=label, alpha=.5)

plt.xlabel('Position in the list')
plt.ylabel('Mean of Basic Monthly Earning (S$)')
plt.legend()
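Here is a sketch of an OO version that also uses two optional scatter() arguments: c, to colour each marker by a value, and s, to set the marker size. The five numbers are stand-in data taken from the ‘All’ listing above, and np.arange() replaces range() since both give 0, 1, 2, ...:

```python
import numpy as np
import matplotlib.pyplot as plt

# Stand-in data: the first five 'All' values listed earlier
numbers = np.array([3701, 2850, 3053, 3557, 3494])
x = np.arange(len(numbers))   # 0, 1, 2, 3, 4

fig, ax = plt.subplots()

# c= colours each marker by its y value; s= sets the marker size (points squared)
points = ax.scatter(x, numbers, c=numbers, s=60, cmap='viridis')
fig.colorbar(points, ax=ax, label='Mean of Basic Monthly Earning (S$)')
ax.set_xlabel('Position in the list')

# One (x, y) pair per marker
print(points.get_offsets().shape)   # (5, 2)
```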

6.3 Bar charts

I am using some dummy data for a hypothetical class for this example. I extract the keys and values from the dictionary and convert them to lists to pass to bar(). Use barh() if you want horizontal bars.

student_numbers = {'Life Sciences': 14,
                   'Physics': 12,
                   'Chemistry': 8,
                   'Comp. Biology': 1}
majors = list(student_numbers.keys())
numbers = list(student_numbers.values())

plt.style.use('ggplot')
plt.bar(majors, numbers)
plt.xlabel('Majors')
plt.ylabel('Number of Students')
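As mentioned, barh() draws horizontal bars. A sketch of the same dummy data in the OO dialect; notice that the roles of the axis labels swap:

```python
import matplotlib.pyplot as plt

student_numbers = {'Life Sciences': 14,
                   'Physics': 12,
                   'Chemistry': 8,
                   'Comp. Biology': 1}
majors = list(student_numbers.keys())
numbers = list(student_numbers.values())

fig, ax = plt.subplots()

# barh() takes the categories first and the bar lengths second
bars = ax.barh(majors, numbers)
ax.set_xlabel('Number of Students')
ax.set_ylabel('Majors')

# For horizontal bars, the width of each bar is its value
print([float(bar.get_width()) for bar in bars])   # [14.0, 12.0, 8.0, 1.0]
```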

6.4 Pie charts

I am not a big fan of pie charts, but they have their uses. Let me reuse the previous data from the dummy class.

student_numbers = {'Life Sciences': 14,
                   'Physics': 12,
                   'Chemistry': 8,
                   'Comp. Biology': 1}
majors = list(student_numbers.keys())
numbers = list(student_numbers.values())

plt.style.use('fivethirtyeight')
plt.pie(numbers,
        labels=majors,
        autopct='%1.1f%%',   # How to format the percentages
        startangle=-90
        )
plt.title('Percentage of each major')
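When you pass autopct, pie() also returns the text objects holding the percentages, which is handy if you want to restyle them. A sketch using the same dummy data in the OO dialect:

```python
import matplotlib.pyplot as plt

student_numbers = {'Life Sciences': 14,
                   'Physics': 12,
                   'Chemistry': 8,
                   'Comp. Biology': 1}
majors = list(student_numbers.keys())
numbers = list(student_numbers.values())

fig, ax = plt.subplots()

# With autopct given, pie() returns the wedges, the label texts
# and the percentage texts
wedges, texts, autotexts = ax.pie(numbers,
                                  labels=majors,
                                  autopct='%1.1f%%',
                                  startangle=-90)
ax.set_title('Percentage of each major')

# Each percentage is 100 * value / total, formatted by autopct (total = 35)
print([t.get_text() for t in autotexts])   # ['40.0%', '34.3%', '22.9%', '2.9%']
```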