線性判別模型(LDA)在圖形識別領域(比如臉部辨識等圖形影象識別領域)中有非常廣泛的應用。LDA是一種監督學習的降維技術,也就是說它的資料集的每個樣本是有類別輸出的。這點和PCA不同。PCA是不考慮樣本類別輸出的無監督降維技術。LDA的思想可以用一句話概括,就是「投影后類內方差最小,類間方差最大」。我們要將資料在低維度上進行投影,投影后希望每一種類別資料的投影點儘可能的接近,而不同類別的資料的類別中心之間的距離儘可能的大。即:將資料投影到維度更低的空間中,使得投影后的點,會形成按類別區分,一簇一簇的情況,相同類別的點,將會在投影后的空間中更接近方法。
LDA演演算法的一個目標是使得不同類別之間的距離越遠越好,同一類別之中的距離越近越好。那麼不同類別之間的距離越遠越好,我們是可以理解的,就是越遠越好區分。同時,協方差不僅是反映了變數之間的相關性,同樣反映了多維樣本分佈的離散程度(一維樣本使用方差),協方差越大(對於負相關來說是絕對值越大),表示資料的分佈越分散。所以上面的「欲使同類樣例的投影點儘可能接近,可以讓同類樣本點的協方差矩陣儘可能小」就可以理解了。
$J(w)=\frac{w^T|\mu_1 - \mu_2~|2}{s2_1+s2_2}$
如上述公式 $J(w)$ 所示,分子為投影資料後的均值只差,分母為方差之後,LDA的目的就是使得 $J$ 值最大化,那麼可以理解為最大化分子,即使得類別之間的距離越遠,同時最小化分母,使得每個類別內部的方差越小,這樣就能使得每個類類別的資料可以在投影矩陣 $w$ 的對映下,分的越開。
需要注意的是,LDA模型適用於線性可分資料,對於上述實戰中用到的MNIST手寫資料(其實是分線性的),但是依然可以取得較好的分類效果;但在以後的實戰中需要注意LDA在非線性可分資料上的謹慎使用。
LDA在圖形識別領域(比如臉部辨識,艦艇識別等圖形影象識別領域)中有非常廣泛的應用,因此我們有必要了解一下它的演演算法原理。不過在學習LDA之前,我們有必要將其與自然語言處理領域中的LDA區分開,在自然語言處理領域,LDA是隱含狄利克雷分佈(Latent DIrichlet Allocation,簡稱LDA),它是一種處理檔案的主題模型,我們本文討論的是線性判別分析,因此後面所說的LDA均為線性判別分析。
LDA除了可以用於降維以外,還可以用於分類。一個常見的LDA分類基本思想是假設各個類別的樣本資料符合高斯分佈,這樣利用LDA進行投影后,可以利用極大似然估計計算各個類別投影資料的均值和方差,進而得到該類別高斯分佈的概率密度函數。當一個新的樣本到來後,我們可以將它投影,然後將投影后的樣本特徵分別帶入各個類別的高斯分佈概率密度函數,計算它屬於這個類別的概率,最大的概率對應的類別即為預測類別。
Part 1 Demo實踐
Part 2 基於LDA手寫數位分類實踐
# 基礎陣列運算庫匯入
import numpy as np
# 畫相簿匯入
import matplotlib.pyplot as plt
# 匯入三維顯示工具
from mpl_toolkits.mplot3d import Axes3D
# 匯入LDA模型
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
# 匯入demo資料製作方法
from sklearn.datasets import make_classification
# 製作四個類別的資料,每個類別100個樣本
X, y = make_classification(n_samples=1000, n_features=3, n_redundant=0,
n_classes=4, n_informative=2, n_clusters_per_class=1,
class_sep=3, random_state=10)
# 將四個類別的資料進行三維顯示
fig = plt.figure()
ax = Axes3D(fig, rect=[0, 0, 1, 1], elev=20, azim=20)
ax.scatter(X[:, 0], X[:, 1], X[:, 2], marker='o', c=y)
plt.show()
# 建立 LDA 模型
lda = LinearDiscriminantAnalysis()
# 進行模型訓練
lda.fit(X, y)
LinearDiscriminantAnalysis()
# 檢視 LDA 模型的引數
lda.get_params()
{'covariance_estimator': None,
'n_components': None,
'priors': None,
'shrinkage': None,
'solver': 'svd',
'store_covariance': False,
'tol': 0.0001}
# 進行模型預測
X_new = lda.transform(X)
# 視覺化預測資料
plt.scatter(X_new[:, 0], X_new[:, 1], marker='o', c=y)
plt.show()
# 進行新的測試資料測試
a = np.array([[-1, 0.1, 0.1]])
print(f"{a} 類別是: ", lda.predict(a))
print(f"{a} 類別概率分別是: ", lda.predict_proba(a))
a = np.array([[-12, -100, -91]])
print(f"{a} 類別是: ", lda.predict(a))
print(f"{a} 類別概率分別是: ", lda.predict_proba(a))
a = np.array([[-12, -0.1, -0.1]])
print(f"{a} 類別是: ", lda.predict(a))
print(f"{a} 類別概率分別是: ", lda.predict_proba(a))
a = np.array([[0.1, 90.1, 9.1]])
print(f"{a} 類別是: ", lda.predict(a))
print(f"{a} 類別概率分別是: ", lda.predict_proba(a))
[[-1. 0.1 0.1]] 類別是: [0]
[[-1. 0.1 0.1]] 類別概率分別是: [[9.37611354e-01 1.88760664e-05 3.36891510e-02 2.86806189e-02]]
[[ -12 -100 -91]] 類別是: [1]
[[ -12 -100 -91]] 類別概率分別是: [[1.08769337e-028 1.00000000e+000 1.54515810e-221 9.05666876e-183]]
[[-12. -0.1 -0.1]] 類別是: [2]
[[-12. -0.1 -0.1]] 類別概率分別是: [[1.60268201e-07 1.46912978e-39 9.99999840e-01 3.57001075e-28]]
[[ 0.1 90.1 9.1]] 類別是: [3]
[[ 0.1 90.1 9.1]] 類別概率分別是: [[8.42065614e-08 9.45021749e-11 8.63060269e-02 9.13693889e-01]]
# 匯入手寫資料集 MNIST
from sklearn.datasets import load_digits
# 匯入訓練集分割方法
from sklearn.model_selection import train_test_split
# 匯入LDA模型
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
# 匯入預測指標計算函數和混淆矩陣計算函數
from sklearn.metrics import classification_report, confusion_matrix
# 匯入繪圖包
import seaborn as sns
import matplotlib
# 匯入MNIST資料集
mnist = load_digits()
# 檢視資料集資訊
print('The Mnist dataeset:\n',mnist)
# 分割資料為訓練集和測試集
x, test_x, y, test_y = train_test_split(mnist.data, mnist.target, test_size=0.1, random_state=2)
The Mnist dataeset:
{'data': array([[ 0., 0., 5., ..., 0., 0., 0.],
[ 0., 0., 0., ..., 10., 0., 0.],
[ 0., 0., 0., ..., 16., 9., 0.],
...,
[ 0., 0., 1., ..., 6., 0., 0.],
[ 0., 0., 2., ..., 12., 0., 0.],
[ 0., 0., 10., ..., 12., 1., 0.]]), 'target': array([0, 1, 2, ..., 8, 9, 8]), 'frame': None, 'feature_names': ['pixel_0_0', 'pixel_0_1', 'pixel_0_2', 'pixel_0_3', 'pixel_0_4', 'pixel_0_5', 'pixel_0_6', 'pixel_0_7', 'pixel_1_0', 'pixel_1_1', 'pixel_1_2', 'pixel_1_3', 'pixel_1_4', 'pixel_1_5', 'pixel_1_6', 'pixel_1_7', 'pixel_2_0', 'pixel_2_1', 'pixel_2_2', 'pixel_2_3', 'pixel_2_4', 'pixel_2_5', 'pixel_2_6', 'pixel_2_7', 'pixel_3_0', 'pixel_3_1', 'pixel_3_2', 'pixel_3_3', 'pixel_3_4', 'pixel_3_5', 'pixel_3_6', 'pixel_3_7', 'pixel_4_0', 'pixel_4_1', 'pixel_4_2', 'pixel_4_3', 'pixel_4_4', 'pixel_4_5', 'pixel_4_6', 'pixel_4_7', 'pixel_5_0', 'pixel_5_1', 'pixel_5_2', 'pixel_5_3', 'pixel_5_4', 'pixel_5_5', 'pixel_5_6', 'pixel_5_7', 'pixel_6_0', 'pixel_6_1', 'pixel_6_2', 'pixel_6_3', 'pixel_6_4', 'pixel_6_5', 'pixel_6_6', 'pixel_6_7', 'pixel_7_0', 'pixel_7_1', 'pixel_7_2', 'pixel_7_3', 'pixel_7_4', 'pixel_7_5', 'pixel_7_6', 'pixel_7_7'], 'target_names': array([0, 1, 2, 3, 4, 5, 6, 7,
[ 0., 0., 13., ..., 15., 5., 0.],
[ 0., 3., 15., ..., 11., 8., 0.],
...,
[ 0., 4., 11., ..., 12., 7., 0.],
[ 0., 2., 14., ..., 12., 0., 0.],
[ 0., 0., 6., ..., 0., 0., 0.]],
[[ 0., 0., 0., ..., 5., 0., 0.],
[ 0., 0., 0., ..., 9., 0., 0.],
[ 0., 0., 3., ..., 6., 0., 0.],
...,
[ 0., 0., 1., ..., 6., 0., 0.],
[ 0., 0., 1., ..., 6., 0., 0.],
[ 0., 0., 0., ..., 10., 0., 0.]],
[[ 0., 0., 0., ..., 12., 0., 0.],
[ 0., 0., 3., ..., 14., 0., 0.],
[ 0., 0., 8., ..., 16., 0., 0.],
...,
[ 0., 9., 16., ..., 0., 0., 0.],
[ 0., 3., 13., ..., 11., 5., 0.],
[ 0., 0., 0., ..., 16., 9., 0.]],
...,
[[ 0., 0., 1., ..., 1., 0., 0.],
[ 0., 0., 13., ..., 2., 1., 0.],
[ 0., 0., 16., ..., 16., 5., 0.],
...,
[ 0., 0., 16., ..., 15., 0., 0.],
[ 0., 0., 15., ..., 16., 0., 0.],
[ 0., 0., 2., ..., 6., 0., 0.]],
[[ 0., 0., 2., ..., 0., 0., 0.],
[ 0., 0., 14., ..., 15., 1., 0.],
[ 0., 4., 16., ..., 16., 7., 0.],
...,
[ 0., 0., 0., ..., 16., 2., 0.],
[ 0., 0., 4., ..., 16., 2., 0.],
[ 0., 0., 5., ..., 12., 0., 0.]],
[[ 0., 0., 10., ..., 1., 0., 0.],
[ 0., 2., 16., ..., 1., 0., 0.],
[ 0., 0., 15., ..., 15., 0., 0.],
...,
[ 0., 4., 16., ..., 16., 6., 0.],
[ 0., 8., 16., ..., 16., 8., 0.],
[ 0., 1., 8., ..., 12., 1., 0.]]]), 'DESCR': ".. _digits_dataset:\n\nOptical recognition of handwritten digits dataset\n--------------------------------------------------\n\n**Data Set Characteristics:**\n\n :Number of Instances: 1797\n :Number of Attributes: 64\n :Attribute Information: 8x8 image of integer pixels in the range 0..16.\n :Missing Attribute Values: None\n :Creator: E. Alpaydin (alpaydin '@' boun.edu.tr)\n :Date: July; 1998\n\nThis is a copy of the test set of the UCI ML hand-written digits datasets\nhttps://archive.ics.uci.edu/ml/datasets/Optical+Recognition+of+Handwritten+Digits\n\nThe data set contains images of hand-written digits: 10 classes where\neach class refers to a digit.\n\nPreprocessing programs made available by NIST were used to extract\nnormalized bitmaps of handwritten digits from a preprinted form. From a\ntotal of 43 people, 30 contributed to the training set and different 13\nto the test set. 32x32 bitmaps are divided into
## 輸出範例影象
images = range(0,9)
plt.figure(dpi=100)
for i in images:
plt.subplot(330 + 1 + i)
plt.imshow(x[i].reshape(8, 8), cmap = matplotlib.cm.binary,interpolation="nearest")
# show the plot
plt.show()
# 建立 LDA 模型
m_lda = LinearDiscriminantAnalysis()
# 進行模型訓練
m_lda.fit(x, y)
LinearDiscriminantAnalysis()
# 進行模型預測
x_new = m_lda.transform(x)
# 視覺化預測資料
plt.scatter(x_new[:, 0], x_new[:, 1], marker='o', c=y)
plt.title('MNIST with LDA Model')
plt.show()
# 進行測試集資料的類別預測
y_test_pred = m_lda.predict(test_x)
print("測試集的真實標籤:\n", test_y)
print("測試集的預測標籤:\n", y_test_pred)
測試集的真實標籤:
[4 0 9 1 4 7 1 5 1 6 6 7 6 1 5 5 4 6 2 7 4 6 4 1 5 2 9 5 4 6 5 6 3 4 0 9 9
8 4 6 8 8 5 7 9 6 9 6 1 3 0 1 9 7 3 3 1 1 8 8 9 8 5 4 4 7 3 5 8 4 3 1 3 8
7 3 3 0 8 7 2 8 5 3 8 7 6 4 6 2 2 0 1 1 5 3 5 7 6 8 2 2 6 4 6 7 3 7 3 9 4
7 0 3 5 8 5 0 3 9 2 7 3 2 0 8 1 9 2 1 9 1 0 3 4 3 0 9 3 2 2 7 3 1 6 7 2 8
3 1 1 6 4 8 2 1 8 4 1 3 1 1 9 5 4 8 7 4 8 9 5 7 6 9 0 0 4 0 0 4]
測試集的預測標籤:
[4 0 9 1 8 7 1 5 1 6 6 7 6 2 5 5 8 6 2 7 4 6 4 1 5 2 9 5 4 6 5 6 3 4 0 9 9
8 4 6 8 1 5 7 9 6 9 6 1 3 0 1 9 7 3 3 1 1 8 8 9 8 5 8 4 9 3 5 8 4 3 9 3 8
7 3 3 0 8 7 2 8 5 3 8 7 6 4 6 2 2 0 1 1 5 3 5 7 1 8 2 2 6 4 6 7 3 7 3 9 4
7 0 3 5 1 5 0 3 9 2 7 3 2 0 8 1 9 2 1 9 9 0 3 4 3 0 8 3 2 2 7 3 1 6 7 2 8
3 1 1 6 4 8 2 1 8 4 1 3 1 1 9 5 4 9 7 4 8 9 5 7 6 9 6 0 4 0 0 9]
# 進行預測結果指標統計 統計每一類別的預測準確率、召回率、F1分數
print(classification_report(test_y, y_test_pred))
precision recall f1-score support
0 1.00 0.93 0.96 14
1 0.86 0.86 0.86 22
2 0.93 1.00 0.97 14
3 1.00 1.00 1.00 22
4 1.00 0.81 0.89 21
5 1.00 1.00 1.00 16
6 0.94 0.94 0.94 18
7 1.00 0.94 0.97 18
8 0.80 0.84 0.82 19
9 0.75 0.94 0.83 16
accuracy 0.92 180
macro avg 0.93 0.93 0.93 180
weighted avg 0.93 0.92 0.92 180
# 計算混淆矩陣
C2 = confusion_matrix(test_y, y_test_pred)
# 打混淆矩陣
print(C2)
# 將混淆矩陣以熱力圖的防線顯示
sns.set()
f, ax = plt.subplots()
# 畫熱力圖
sns.heatmap(C2, cmap="YlGnBu_r", annot=True, ax=ax)
# 標題
ax.set_title('confusion matrix')
# x軸為預測類別
ax.set_xlabel('predict')
# y軸實際類別
ax.set_ylabel('true')
plt.show()
[[13 0 0 0 0 0 1 0 0 0]
[ 0 19 1 0 0 0 0 0 0 2]
[ 0 0 14 0 0 0 0 0 0 0]
[ 0 0 0 22 0 0 0 0 0 0]
[ 0 0 0 0 17 0 0 0 3 1]
[ 0 0 0 0 0 16 0 0 0 0]
[ 0 1 0 0 0 0 17 0 0 0]
[ 0 0 0 0 0 0 0 17 0 1]
[ 0 2 0 0 0 0 0 0 16 1]
[ 0 0 0 0 0 0 0 0 1 15]]
LDA演演算法的主要優點:
LDA演演算法的主要缺點:
本專案連結:https://www.heywhale.com/home/column/64141d6b1c8c8b518ba97dcc
參考連結:https://tianchi.aliyun.com/course/278/3426
本人最近打算整合ML、DRL、NLP等相關領域的體系化專案課程,方便入門同學快速掌握相關知識。宣告:部分專案為網路經典專案方便大家快速學習,後續會不斷增添實戰環節(比賽、論文、現實應用等)。
上述對於你掌握後的期許:
這三塊領域耦合情況比較大,後續會通過比如:搜尋推薦系統整個專案進行耦合,各項演演算法都會耦合在其中。舉例:知識圖譜就會用到(圖演演算法、NLP、ML相關演演算法),搜尋推薦系統(除了該領域召回粗排精排重排混排等演演算法外,還有強化學習、知識圖譜等耦合在其中)。餅畫的有點大,後面慢慢實現。