Matplotlib subplots and axes objects
Subplots
The subplot function of the matplotlib.pyplot module is a tool for plotting several graphs on a single figure. By calling subplot(n,m,k) we subdivide the figure into n rows and m columns and specify that plotting should be done on subplot number k. Subplots are numbered row by row, from left to right, starting at 1.
import matplotlib.pyplot as plt
import numpy as np
from math import pi
plt.figure(figsize=(8,4)) # set dimensions of the figure
x = np.linspace(0,2*pi, 100)
for i in range(1,7):
    plt.subplot(2,3, i) # create subplots on a grid with 2 rows and 3 columns
    plt.xticks([]) # set no ticks on x-axis
    plt.yticks([]) # set no ticks on y-axis
    plt.plot(np.sin(x), np.cos(i*x))
    plt.title('subplot' + '(2,3,' + str(i) + ')')
plt.show()
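The row-by-row numbering can be spelled out with a little arithmetic. The sketch below uses a hypothetical helper, subplot_cell (it is not part of Matplotlib), to convert a subplot number k on a grid with m columns into a zero-based (row, column) pair.

```python
def subplot_cell(m, k):
    """Return the zero-based (row, column) of subplot number k on a grid with m columns.

    Hypothetical helper for illustration only; it is not part of Matplotlib.
    """
    # subplot numbers start at 1, so shift to zero-based before dividing
    return divmod(k - 1, m)


# the 2 x 3 grid used above: subplots 1-3 fill the first row, 4-6 the second
cells = {k: subplot_cell(3, k) for k in range(1, 7)}
print(cells)
```

The number of rows n does not enter the calculation: only the number of columns determines where a new row starts.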
Note. If the numbers n, m, k are all smaller than 10, we can specify a subplot by writing them as a single three-digit integer, subplot(nmk), instead of subplot(n,m,k):
plt.figure(figsize=(8,4))
x = np.linspace(0,2*pi, 100)
plt.subplot(231)
plt.xticks([])
plt.yticks([])
plt.plot(np.sin(x), np.cos(x))
plt.title('subplot(231)')
plt.subplot(233)
plt.xticks([])
plt.yticks([])
plt.plot(np.sin(x), np.cos(3*x))
plt.title('subplot(233)')
plt.subplot(235)
plt.xticks([])
plt.yticks([])
plt.plot(np.sin(x), np.cos(5*x))
plt.title('subplot(235)')
plt.show()
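One way to check that the two spellings describe the same grid cell is to compare the geometry Matplotlib records for each subplot. This is a minimal sketch, assuming a reasonably recent Matplotlib in which axes created by subplot provide get_subplotspec(); the variable names are only illustrative.

```python
import matplotlib.pyplot as plt

fig_a = plt.figure()
ax_a = plt.subplot(2, 3, 5)   # three separate arguments

fig_b = plt.figure()
ax_b = plt.subplot(235)       # the same subplot written as a three-digit integer

# get_geometry() returns (rows, columns, first cell, last cell);
# unlike subplot numbers, the cell indices are counted from 0
geom_a = ax_a.get_subplotspec().get_geometry()
geom_b = ax_b.get_subplotspec().get_geometry()
print(geom_a, geom_b)

plt.close('all')
```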
It is possible to combine subplots of different sizes as long as they do not overlap:
plt.figure(figsize=(8,4))
x = np.linspace(0,2*pi, 200)
for i in [1, 2, 4, 5]:
    plt.subplot(2,3,i) # create some subplots on a grid with 2 rows and 3 columns
    plt.xticks([])
    plt.yticks([])
    plt.plot(np.sin(3*x), np.cos(i*x))
    plt.title('subplot(2,3,' + str(i) + ')')
plt.subplot(1,3,3) # create a subplot on a grid with 1 row and 3 columns
plt.xticks([])
plt.yticks([])
plt.plot(np.sin(10*x), x)
plt.title('subplot(1,3,3)')
plt.show()
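A tall subplot like the one above can also be described on the 2 x 3 grid itself. As a sketch, assuming a Matplotlib version in which subplot accepts a (first, last) pair as its third argument, the following compares the two descriptions of the right-hand column.

```python
import numpy as np
import matplotlib.pyplot as plt

fig_a = plt.figure(figsize=(8, 4))
tall_a = plt.subplot(1, 3, 3)        # third column of a 1 x 3 grid, as in the text

fig_b = plt.figure(figsize=(8, 4))
# the block whose corners are cells 3 and 6 of the 2 x 3 grid,
# that is, the right-hand column across both rows
tall_b = plt.subplot(2, 3, (3, 6))

# bounds are (left, bottom, width, height) in figure coordinates
bounds_a = tall_a.get_position().bounds
bounds_b = tall_b.get_position().bounds
same_place = np.allclose(bounds_a, bounds_b)
print(same_place)

plt.close('all')
```

Both calls should place the axes in the same region of the figure, so the choice between them is mostly a matter of readability.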
Spacing between subplots can be controlled using the subplots_adjust function:
plt.figure(figsize=(8,4))
x = np.linspace(0,2*pi, 200)
plt.subplots_adjust(wspace=0.05, # wspace controls the width of space between subplots
                    hspace=0.5)  # hspace controls the height of space between subplots
for i in [1, 2, 4, 5]:
    plt.subplot(2,3,i)
    plt.xticks([])
    plt.yticks([])
    plt.plot(np.sin(3*x), np.cos(i*x))
    plt.title('subplot(2,3,' + str(i) + ')')
plt.subplot(1,3,3)
plt.xticks([])
plt.yticks([])
plt.plot(np.sin(10*x), x)
plt.title('subplot(1,3,3)')
plt.show()
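The effect of wspace can be measured directly from the positions of two neighbouring subplots. The following is a minimal sketch; horizontal_gap is a hypothetical helper introduced here for illustration, and the default layout (no constrained or tight layout) is assumed.

```python
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(8, 4))
left_ax = plt.subplot(1, 2, 1)
right_ax = plt.subplot(1, 2, 2)


def horizontal_gap():
    # distance between the right edge of the left subplot and the left edge
    # of the right subplot, as a fraction of the figure width
    return right_ax.get_position().x0 - left_ax.get_position().x1


gap_default = horizontal_gap()
plt.subplots_adjust(wspace=0.05)   # wspace is given as a fraction of the average subplot width
gap_narrow = horizontal_gap()
print(gap_default, gap_narrow)

plt.close('all')
```

Since wspace and hspace are relative to the average subplot size rather than to the figure, the same value produces a smaller absolute gap on a grid with more columns or rows.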
Axes objects
The subplot function returns an axes object. We can use it to specify which subplot is active at any time:
plt.figure(figsize=(8,4))
x = np.linspace(0,2*pi, 200)
plt.subplots_adjust(hspace=0.4)
ax1 = plt.subplot(2,1,1) # subplot(2,1,1) is active, plotting will be done there
plt.xlim(0, 2*pi)
plt.plot(x, np.sin(2*x))
plt.title('subplot(2,1,1)')
ax2 = plt.subplot(2,1,2) # subplot(2,1,2) is now active
plt.xlim(0, 2*pi)
plt.plot(x, np.sin(10*x), 'g')
plt.title('subplot(2,1,2)')
plt.sca(ax1) # we activate subplot(2,1,1) to do more plotting on this subplot (plt.axes(ax1) is deprecated since Matplotlib 3.0)
plt.plot(x, np.cos(2*x), 'r--')
plt.show()
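To see which subplot is active at a given moment, pyplot can be asked for its current axes. The sketch below relies on plt.gca() and plt.sca(), two pyplot functions that do not appear in the text above; the variable names are illustrative.

```python
import matplotlib.pyplot as plt

plt.figure(figsize=(8, 4))
ax1 = plt.subplot(2, 1, 1)
ax2 = plt.subplot(2, 1, 2)

active_after_creation = plt.gca()   # the most recently created subplot
plt.sca(ax1)                        # make ax1 the current axes again
active_after_switch = plt.gca()

print(active_after_creation is ax2, active_after_switch is ax1)

plt.close('all')
```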
Besides subplot, axes objects can also be created with the axes function. This is a useful alternative since it gives more flexibility in setting the layout of the figure: while the subplot function creates an evenly spaced grid, using the axes function we can place graphs within the figure any way we want:
plt.figure(figsize=(8,4))
# we use the axes function to create an axes object
# coordinates of the object within the picture are numbers between 0 and 1.
# The point (0,0) is the lower left corner of the figure, the point (1,1)
# is the upper right corner
ax1 = plt.axes([
    0.1, # x-coordinate of the lower left corner of the axes object
    0.1, # y-coordinate of the lower left corner of the axes object
    0.5, # width of the object
    0.4  # height of the object
])
# here we create another axes object
ax2 = plt.axes([0.5, 0.2, 0.4, 0.6])
x = np.linspace(0,2*pi, 300)
plt.sca(ax1) # select ax1 to do some plotting there (plt.axes(ax1) is deprecated since Matplotlib 3.0)
plt.title('This is ax1')
plt.xlim(0, 2*pi)
plt.plot(x, np.cos(20*x), 'g')
plt.xticks([])
plt.yticks([])
plt.sca(ax2) # switch to ax2
plt.title('This is ax2')
plt.xlim(0, 2*pi)
plt.plot(x, np.sin(2*x))
plt.plot(x, np.cos(2*x), 'r--')
plt.xticks([])
plt.yticks([])
plt.show()
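Because each axes object carries its own plotting methods, a layout like the one above can also be built without switching the active axes at all. This is a sketch of that alternative, not the approach used in the text: it assumes the Figure.add_axes method and the Axes methods plot, set_title and set_xlim, and it leaves the ticks in place for brevity.

```python
import numpy as np
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(8, 4))
x = np.linspace(0, 2 * np.pi, 300)

# [left, bottom, width, height] in figure coordinates, as in the example above
ax1 = fig.add_axes([0.1, 0.1, 0.5, 0.4])
ax2 = fig.add_axes([0.5, 0.2, 0.4, 0.6])

# each call names the axes it draws on, so the order of calls does not matter
ax1.plot(x, np.cos(20 * x), 'g')
ax1.set_title('This is ax1')
ax1.set_xlim(0, 2 * np.pi)

ax2.plot(x, np.sin(2 * x))
ax2.plot(x, np.cos(2 * x), 'r--')
ax2.set_title('This is ax2')
ax2.set_xlim(0, 2 * np.pi)

# the position recorded for ax1 matches the rectangle we asked for
ax1_bounds = ax1.get_position().bounds
print(ax1_bounds)
```

Calling plt.show() afterwards displays the figure as before.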