Python-Matplotlib視覺化(10)——一文詳解3D統計圖的繪製

2021-07-12 22:00:21

Python-Matplotlib視覺化(10)——一文詳解3D統計圖的繪製

前言

Matplotlib 是 Python 的繪相簿,它提供了一整套和 matlab 相似的命令 API,可以生成你所需的出版品質級別的圖形,而製作3D圖形的API與2D API非常相似。我們已經學習了一系列2D統計圖的繪製,而在統計圖中再新增一個維度可以展示更多資訊。而且,在進行常規彙報或演講時,3D圖形也可以吸引更多的注意力。在本系列的最後一篇中,我們將探討利用 Matplotlib 繪製三維統計圖。

3D散點圖

3D散點圖的繪製方式與2D散點圖基本相同。

import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
# Dataset generation
a, b, c = 10., 28., 8. / 3.
def lorenz_map(x, dt = 1e-2):
    x_dt = np.array([a * (x[1] - x[0]), x[0] * (b - x[2]) - x[1], x[0] * x[1] - c * x[2]])
    return x + dt * x_dt
points = np.zeros((2000, 3))
x = np.array([.1, .0, .0])
for i in range(points.shape[0]):
    points[i], x = x, lorenz_map(x)
# Plotting
fig = plt.figure()
ax = fig.gca(projection = '3d')
ax.set_xlabel('X axis')
ax.set_ylabel('Y axis')
ax.set_zlabel('Z axis')
ax.set_title('Lorenz Attractor a=%0.2f b=%0.2f c=%0.2f' % (a, b, c))
ax.scatter(points[:, 0], points[:, 1],points[:, 2], zdir = 'z', c = 'c')
plt.show()

3D散點圖

Tips:按住滑鼠左鍵移動滑鼠可以旋轉檢視三維圖形將旋轉。

為了使用 Matplotlib 進行三維操作,我們首先需要匯入 Matplotlib 的三維擴充套件:

from mpl_toolkits.mplot3d import Axes3D

對於三維繪圖,需要建立一個Figure範例並附加一個 Axes3D 範例:

fig = plt.figure()
ax = fig.gca(projection='3d')

之後,三維散點圖的繪製方式與二維散點圖完全相同:

ax.scatter(points[:, 0], points[:, 1],points[:, 2], zdir = 'z', c = 'c')

Tips:需要呼叫 Axes3D 範例的 scatter() 方法,而非plt中的 scatter 方法。只有 Axes3D 中的 scatter() 方法才能解釋三維資料。同時2D統計圖中的註釋也可以在3D圖中使用,例如 set_title()、set_xlabel()、set_ylabel() 和 set_zlabel() 等。
同時可以通過使用 Axes3D.scatter() 的可選引數更改統計通的形狀和顏色:

ax.scatter(points[:, 0], points[:, 1],points[:, 2], zdir = 'z', c = 'c', marker='s', edgecolor='0.5', facecolor='m')

修改樣式

3D曲線圖

與在3D空間中繪製散點圖類似,繪製3D曲線圖同樣需要設定一個 Axes3D 範例,然後呼叫其plot()方法:

import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
# Dataset generation
a, b, c = 10., 28., 8. / 3.
def lorenz_map(x, dt = 1e-2):
    x_dt = np.array([a * (x[1] - x[0]), x[0] * (b - x[2]) - x[1], x[0] * x[1] - c * x[2]])
    return x + dt * x_dt
points = np.zeros((8000, 3))
x = np.array([.1, .0, .0])
for i in range(points.shape[0]):
    points[i], x = x, lorenz_map(x)
# Plotting
fig = plt.figure()
ax = fig.gca(projection = '3d')
ax.set_xlabel('X axis')
ax.set_ylabel('Y axis')
ax.set_zlabel('Z axis')
ax.set_title('Lorenz Attractor a=%0.2f b=%0.2f c=%0.2f' % (a, b, c))
ax.plot(points[:, 0], points[:, 1], points[:, 2], c = 'c')
plt.show()

3D曲線圖

3D標量場

到目前為止,我們看到的3D繪圖方式類似與相應的2D繪圖方式,但也有許多特有的三維繪圖功能,例如將二維標量場繪製為3D曲面:

import numpy as np
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
x = np.linspace(-3, 3, 256)
y = np.linspace(-3, 3, 256)
x_grid, y_grid = np.meshgrid(x, y)
z = np.sinc(np.sqrt(x_grid ** 2 + y_grid ** 2))
fig = plt.figure()
ax = fig.gca(projection = '3d')
ax.plot_surface(x_grid, y_grid, z, cmap=cm.viridis)
plt.show()

3D標量場Tips: plot_surface() 方法使用 x、y 和 z 將標量場顯示為三維曲面。
可以看到曲面上線條帶有顯著色彩,如果不希望看到三維曲面上顯示的曲線色彩,可以使用 plot_surface() 附加可選引數:

ax.plot_surface(x_grid, y_grid, z, cmap=cm.viridis, linewidth=0, antialiased=False)

3D標量場同樣,我們也可以僅保持曲線色彩,而曲面不使用其他顏色,這也可以通過 plot_surface() 的可選引數來完成:

import numpy as np
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
x = np.linspace(-3, 3, 256)
y = np.linspace(-3, 3, 256)
x_grid, y_grid = np.meshgrid(x, y)
z = np.sinc(np.sqrt(x_grid ** 2 + y_grid ** 2))
fig = plt.figure()
ax = fig.gca(projection = '3d')
ax.plot_surface(x_grid, y_grid, z, edgecolor='b',color='w')
plt.show()

3D標量場
而如果我們希望消除曲面,而僅使用線框進行繪製,這可以使用 plot_wireframe() 函數:

ax.plot_wireframe(x_grid, y_grid, z, cstride=10, rstride=10,color='c')

3D標量場Tips:plot_wireframe() 引數與 plot_surface() 相同,使用兩個可選引數 rstride 和 cstride 用於令 Matplotlib 跳過x和y軸上指定數量的座標,用於減少曲線的密度。

繪製3D曲面

在前述方法中,使用 plot_surface() 來繪製標量:即 f(x, y)=z 形式的函數,但 Matplotlib 也能夠使用更通用的方式繪製三維曲面:

import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
# Generate torus mesh
angle = np.linspace(0, 2 * np.pi, 32)
theta, phi = np.meshgrid(angle, angle)
r, r_w = .25, 1.
x = (r_w + r * np.cos(phi)) * np.cos(theta)
y = (r_w + r * np.cos(phi)) * np.sin(theta)
z = r * np.sin(phi)
# Display the mesh
fig = plt.figure()
ax = fig.gca(projection = '3d')
ax.set_xlim3d(-1, 1)
ax.set_ylim3d(-1, 1)
ax.set_zlim3d(-1, 1)
ax.plot_surface(x, y, z, color = 'c', edgecolor='m', rstride = 2, cstride = 2)
plt.show()

繪製3D曲面
同樣可以使用 plot_wireframe() 替換對 plot_surface() 的呼叫,以便獲得圓環的線框檢視:

ax.plot_wireframe(x, y, z, edgecolor='c', rstride = 2, cstride = 1)

繪製3D曲面

在3D座標軸中繪製2D圖形

註釋三維圖形的一種有效方法是使用二維圖形:

import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
x = np.linspace(-3, 3, 256)
y = np.linspace(-3, 3, 256)
x_grid, y_grid = np.meshgrid(x, y)
z = np.exp(-(x_grid ** 2 + y_grid ** 2))
u = np.exp(-(x ** 2))
fig = plt.figure()
ax = fig.gca(projection = '3d')
ax.set_zlim3d(0, 3)
ax.plot(x, u, zs=3, zdir='y', lw = 2, color = 'm')
ax.plot(x, u, zs=-3, zdir='x', lw = 2., color = 'c')
ax.plot_surface(x_grid, y_grid, z, color = 'b')
plt.show()

註釋三維圖形Axes3D 範例同樣支援常用的二維渲染命令,如plot():

ax.plot(x, u, zs=3, zdir='y', lw = 2, color = 'm')

Axes3D 範例對 plot() 的呼叫有兩個新的可選引數:
zdir :用於決定在哪個平面上繪製2D繪圖,可選值包括 x、y 或 z ;
zs :用於決定平面的偏移。
因此,要將二維圖形嵌入到三維圖形中,只需將二維原語用於 Axes3D 範例,同時使用可選引數,zdirzs,來放置所需渲染圖形平面。
接下來,讓我們實際檢視下在3D空間中堆疊2D條形圖的範例:

import numpy as np
from matplotlib import cm
import matplotlib.colors as col
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
# Data generation
alpha = 1. / np.linspace(1, 8, 5)
t = np.linspace(0, 5, 16)
t_grid, a_grid = np.meshgrid(t, alpha)
data = np.exp(-t_grid * a_grid)
# Plotting
fig = plt.figure()
ax = fig.gca(projection = '3d')
cmap = cm.ScalarMappable(col.Normalize(0, len(alpha)), cm.viridis)
for i, row in enumerate(data):
    ax.bar(4 * t, row, zs=i, zdir='y', alpha=0.8, color=cmap.to_rgba(i))
plt.show()

堆疊2D圖形

3D柱形圖

# plt.show()
import numpy as np
from matplotlib import cm
import matplotlib.colors as col
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
# Data generation
alpha = np.linspace(1, 8, 5)
t = np.linspace(0, 5, 16)
t_grid, a_grid = np.meshgrid(t, alpha)
data = np.exp(-t_grid * (1. / a_grid))
# Plotting
fig = plt.figure()
ax = fig.gca(projection = '3d')
xi = t_grid.flatten()
yi = a_grid.flatten()
zi = np.zeros(data.size)
dx = .30 * np.ones(data.size)
dy = .30 * np.ones(data.size)
dz = data.flatten()
ax.set_xlabel('T')
ax.set_ylabel('Alpha')
ax.bar3d(xi, yi, zi, dx, dy, dz, color = 'c')
plt.show()

3D柱形圖

3D柱體以網格佈局定位,bar3d() 方法接受六個必需引數作為輸入。前三個引數是每個柱體下端的x、y和z座標:

xi = t_grid.flatten()
yi = a_grid.flatten()
zi = np.zeros(data.size)

Tips:bar3d() 方法將座標列表作為輸入,而不是網格座標,因此需要對矩陣 a_grid 和 t_grid 呼叫flatten方法。
bar3d() 方法的另外三個必需引數是每個柱體在每個維度的值。這裡,條形圖的高度取自資料矩陣。條形寬度和深度設定為.30:

dx = .30 * np.ones(data.size)
dy = .30 * np.ones(data.size)
dz = data.flatten()

系列連結

Python-Matplotlib視覺化(1)——一文詳解常見統計圖的繪製
Python-Matplotlib視覺化(2)——自定義顏色繪製精美統計圖
Python-Matplotlib視覺化(3)——自定義樣式繪製精美統計圖
Python-Matplotlib視覺化(4)——新增註釋讓統計圖通俗易懂
Python-Matplotlib視覺化(5)——新增自定義形狀繪製複雜圖形
Python-Matplotlib視覺化(6)——自定義座標軸讓統計圖清晰易懂
Python-Matplotlib視覺化(7)——多方面自定義統計圖繪製
Python-Matplotlib視覺化(8)——圖形的輸出與儲存
Python-Matplotlib視覺化(9)——精通更多實用圖形的繪製