Making better plots with matplotlib.pyplot in Python3

The default plots made by Python’s matplotlib.pyplot module are almost always insufficient for publication. With about 20 extra lines of code, however, you can generate high-quality plots suitable for inclusion in your next article.

Let’s start with the code for a plot that uses only the defaults:

import matplotlib.pyplot as plt
import numpy as np

np.random.seed(1)
d1 = np.random.normal(1.0, 0.1, 1000)
d2 = np.random.normal(3.0, 0.1, 1000)
xvals = np.arange(1, 1000+1, 1)

plt.plot(xvals, d1, label='data1')
plt.plot(xvals, d2, label='data2')
plt.legend(loc='best')
plt.xlabel('Time, ns')
plt.ylabel('RMSD, Angstroms')
plt.savefig('bad.png', dpi=300)

The result of this will be:

Plot generated with matplotlib.pyplot defaults

The fake data I generated for the plot look something like Root Mean Square Deviation (RMSD) versus time for a converged molecular dynamics simulation, so let’s pretend that’s what they are. The plot has several problems: it looks unpolished overall, the default color scheme is not very attractive and may not be color-blind friendly, and the data extend beyond the range covered by the y-axis tick labels, to name a few.
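If you would rather not pick the y ticks by hand (the improved version later in this post lists them explicitly with plt.yticks), a small helper can compute evenly spaced ticks that enclose all of the data. This is a minimal sketch: spanning_ticks is a hypothetical name, not a matplotlib function, and it assumes a fixed tick spacing step that you choose yourself.

```python
import numpy as np


def spanning_ticks(*arrays, step=1.0):
    """Hypothetical helper: evenly spaced ticks that enclose every value in `arrays`.

    The smallest value is rounded down, and the largest rounded up,
    to the nearest multiple of `step`.
    """
    values = np.concatenate([np.ravel(a) for a in arrays])
    lo = np.floor(values.min() / step) * step
    hi = np.ceil(values.max() / step) * step
    # Add half a step to the stop value so `hi` is included despite floating-point rounding
    return np.arange(lo, hi + step / 2, step)


# Same fake RMSD data as in the example above
np.random.seed(1)
d1 = np.random.normal(1.0, 0.1, 1000)
d2 = np.random.normal(3.0, 0.1, 1000)
ticks = spanning_ticks(d1, d2, step=1.0)
```

For the two data sets above, with step=1.0, this reproduces the hand-picked ticks 0 through 4; passing the result to plt.yticks(ticks, fontsize=fontsize) would then play the same role as the hard-coded list.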

We can easily convert this to a much better plot:

Updated plot with better color scheme, more legible labels, and better use of white space

We just need to add some extra lines to the code as follows:

import matplotlib.pyplot as plt
import numpy as np

np.random.seed(1)
d1 = np.random.normal(1.0, 0.1, 1000)
d2 = np.random.normal(3.0, 0.1, 1000)
xvals = np.arange(1, 1000+1, 1)

fontsize = 12 # set variable for fontsize
linewidth = 2 # set variable for line width
colors = ['#4477AA', '#AA3377'] # "bright" color-blind-friendly colors from Paul Tol's notes

ax = plt.subplot(111) # get axes object for subplot
ax.spines['right'].set_visible(False) # remove right plot boundary
ax.spines['top'].set_visible(False) # remove top plot boundary
ax.spines['left'].set_linewidth(linewidth) # make left axis thicker
ax.spines['bottom'].set_linewidth(linewidth) # make bottom axis thicker
ax.xaxis.set_tick_params(width=linewidth) # make x-axis tick marks thicker
ax.yaxis.set_tick_params(width=linewidth) # make y-axis tick marks thicker

plt.plot(xvals, d1, label='data1', color=colors[0]) # use better colors
plt.plot(xvals, d2, label='data2', color=colors[1]) # use better colors

plt.xlabel('Time, ns', fontsize=fontsize+2) # update fontsize
plt.ylabel('RMSD, '+r'$\AA$', fontsize=fontsize+2) # use Angstrom symbol in axis label, update fontsize
plt.yticks([0.0, 1.0, 2.0, 3.0, 4.0], fontsize=fontsize) # make sure tick labels span the y range; update fontsize
plt.xticks(fontsize=fontsize) # update xtick fontsize

plt.legend(loc='upper center', ncol=2, fancybox=True, framealpha=1.0) # put legend in convenient location and make 2 columns

plt.savefig('better.png', dpi=300)
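If you make many figures, the spine and tick adjustments above can be wrapped in a function and applied to any Axes object. This is a minimal sketch rather than anything built into matplotlib: style_axes is a hypothetical helper, it gets the Axes from plt.subplots() instead of plt.subplot(111), it sets both axes’ tick width and label size in a single ax.tick_params call, and it selects the non-interactive Agg backend only so the script runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: render to files only, no window needed
import matplotlib.pyplot as plt
import numpy as np


def style_axes(ax, linewidth=2, fontsize=12):
    """Hypothetical helper: apply the spine and tick tweaks from the example above."""
    for side in ("right", "top"):
        ax.spines[side].set_visible(False)        # remove right and top plot boundaries
    for side in ("left", "bottom"):
        ax.spines[side].set_linewidth(linewidth)  # make the remaining axes thicker
    # Thicker tick marks and larger tick labels on both axes in one call
    ax.tick_params(axis="both", width=linewidth, labelsize=fontsize)
    return ax


np.random.seed(1)
xvals = np.arange(1, 1000 + 1, 1)
fig, ax = plt.subplots()
style_axes(ax, linewidth=2, fontsize=12)
ax.plot(xvals, np.random.normal(1.0, 0.1, 1000), color="#4477AA", label="data1")
ax.plot(xvals, np.random.normal(3.0, 0.1, 1000), color="#AA3377", label="data2")
ax.set_xlabel("Time, ns", fontsize=14)
ax.legend(loc="upper center", ncol=2)
```

Calling fig.savefig('better.png', dpi=300) afterwards saves the figure just as in the example above.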

And that’s it: read through the comments to see what each line of code does. One last thing: take a look at Paul Tol’s excellent notes on color theory, which explain how to choose colors that are maximally distinct from one another and that also work for color-blind readers.
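Rather than passing color= to every plot call, you can set a palette and the other style choices once through matplotlib’s rcParams, so every figure created afterwards picks them up. This is a sketch assuming a reasonably recent matplotlib; the hex codes are Paul Tol’s "bright" scheme as I have transcribed it (the first and sixth entries are the two colors used above), so check them against his notes before relying on them.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
from cycler import cycler  # installed as a matplotlib dependency

# Paul Tol's "bright" qualitative scheme, as transcribed here (verify against his notes)
TOL_BRIGHT = ["#4477AA", "#EE6677", "#228833", "#CCBB44", "#66CCEE", "#AA3377", "#BBBBBB"]

plt.rcParams.update({
    "axes.prop_cycle": cycler(color=TOL_BRIGHT),  # default line colors, in this order
    "axes.spines.top": False,                     # hide top plot boundary
    "axes.spines.right": False,                   # hide right plot boundary
    "axes.linewidth": 2,                          # thicker left and bottom axes
    "xtick.major.width": 2,                       # thicker x-axis tick marks
    "ytick.major.width": 2,                       # thicker y-axis tick marks
    "xtick.labelsize": 12,                        # tick label font sizes
    "ytick.labelsize": 12,
    "axes.labelsize": 14,                         # axis label font size
})
```

Any figure created after this call, for example with fig, ax = plt.subplots() followed by ax.plot(...), draws its first line in #4477AA without a color= argument. Putting these settings in a matplotlibrc file would make them apply to every script.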
