從圖表秒懂機器學習模型的原理:以 matplotlib 視覺化 scikit-learn 的分類器(KNN、邏輯斯迴歸、SVM、決策樹、隨機森林)

如果你在學機器學習,應該會聽過「邏輯斯迴歸」。邏輯斯迴歸是個好用的分類器演算法,而想在 scikit-learn 中匯入、訓練和使用它也相當簡單。然而,邏輯斯迴歸是如何分類資料的,就需要一點技巧來解釋了。

很多人會從數學式子切入,但對如在下我這種數學不好的人來說,抽象的方程式永遠非常難懂。而有些書或網站會附上邏輯斯函數的圖表,但它們通常跟後面的程式範例沒有直接關係(此外這些書或網站經常也只是抄襲其他人的範例)。最後,已嚴然成為 Python 機器學習代名詞的 scikit-learn 儘管在其網站上有很多視覺化範例,但它們大多極其複雜,要拿來應用實在是有難度。

而這就是這篇文章的出發點:記錄我對於 scikit-learn 幾種模型之視覺化的簡單研究成果。將機器學習模型視覺化,對於教學其實有很大的好處 — — 這能讓人們從更直覺的方式來理解模型是如何分類資料,而不必只能從抽象的數學式子去想像,還能跟資料產生連結。此外,我也需要找個地方來記錄這些程式碼,以後就不必花時間重寫了。

我會稍微解釋一下某些東西,但這畢竟不是新手教學文,我仍假設你對其他套件(NumPy, matplotlib 等)有些基礎概念。

準備(以及釐清)資料

首先,我們自然需要資料。而為了能夠在二維平面上繪圖,這個資料必須只有 2 個特徵(自變數)和 2 個標籤(分類)。然而,我並不想用 scikit-learn 的隨機資料產生功能(例如 make_blobs() 或 make_moons())。

剛好,scikit-learn 內建的乳癌資料集雖有 30 個特徵,但實際用 PCA(principal component analysis,主成分分析)來篩選後,可發現變異度最大的 2 個就能解釋原資料的 99.8% 變異:

篩選特徵

PCA 是一種非監督式機器學習演算法,最主要的用處之一就是用來篩選資料,把變異解釋度最大的 N 筆資料留下來。這麼一來就能在維持差不多預測準確率的情況下減少機器學習的訓練時間。

下面的程式碼複雜一點,列出所有特徵的變異解釋能力和其名稱,好讓我們知道留下來的是哪些資料:

可見 feature 23(worst area)和 feature 3(mean area)就是前面 PCA 保留下來的前 2 大特徵。

最後,此資料集有 569 筆資料,分類又只有 0 和 1,剛好是二元分類問題。下面就來整理整理資料,分割出訓練集與測試集:

在 matplotlib 的許多繪圖功能中,你可以用參數 c 給資料指定額外的值,然後用 cmap 指定一個 color map,這樣 c 值的差異就會用顏色反映出來。

藉由顏色的識別,可看到資料確實大致分成兩類,但究竟哪個是標籤 0 或 1 呢?哪一個又代表良性或惡性腫瘤?

分類內容的探討

下面來費點功夫,把各標籤的名稱跟數量也一併標出來(這邊的過程稍微複雜一點,就不多解釋了):

對照 scikit-learn 網站上對這資料集的說明:

可確認分類 0(右側綠色)代表惡性腫瘤,分類 1(左側灰色)則代表良性腫瘤。我自己一開始也跟很多網站一樣搞反了,所以資料意義的確認上真是不可不慎哪。

KNN

KNN(K-nearest neighbors)是所有機器學習模型中最好懂的:找出 K 個跟測試資料最接近的點,再統計這些點的分類做為預測結果。當然,你也可以將距離當成權重,使較近的點具有更強的影響力。

首先來看 KNN 對測試集的預測標籤和實際標籤,以視覺化比較的結果:

現在點外側的顏色對應到預測標籤,內側則對應實際標籤。這麼一來,就很容易看出分類的效果,以及兩組資料的邊界有些點預測錯誤。

k-neighbors 的繪製

那麼,要怎麼實際看到 KNN 挑選出來的最近 k 鄰呢?這就可以用模型的 kneighbors() 來取得這 k 個點的距離和索引(目前用不到距離,但也許你能拿來做些什麼)。注意這裡的索引是訓練集的索引,畢竟 KNN 是拿訓練集的點作為參考。

可以把 KNeighborsClassifier() 的 weights 參數改成 ‘distance’ 來使用距離當權重。下面是 k = 7 且 weights = ‘distance’ 時跑出來的結果:

可見 7 個值有 4 個是分類 1(淺藍色),機率應為 57%,但分類 0 有一個點比較接近測試資料,使分類 0 加權後提高了機率。

邏輯斯迴歸

邏輯斯迴歸(logistic regression)是二元分類器,藉由邏輯斯函數來將特徵資料投射到介於 0 到 1 之間的值,來判斷資料是否屬於某個分類。如果要預測的標籤超過 2 個以上,則可使用「一對多」的方式來預測每個標籤的機率,再決定最終預測結果(這便是 sciki-learn 的邏輯斯迴歸的做法;沒搞錯的話,這種版本的邏輯斯函數稱為 softmax)。

sciki-learn 邏輯斯迴歸模型的 coef_ 和 intercept_ 屬性會傳回係數跟截距,用這些資料算出的線性方程式再代入 scipy 的 expit()(即 logistic sigmoid 函數)後就會得到邏輯斯函數。(當然你也可以不用 expit() 而自己套公式,但我們就省點麻煩吧。)

雖然前面說要用 2 個特徵,但既然邏輯斯函數會把特徵映射到新的 Y 軸,只用 1 個特徵來做會比較好理解(反正前面已經看到,即使 1 個特徵也具備 98.2% 的變異解釋度):

上面可以看到,測試集資料換算後剛好都落在邏輯斯函數上,而只要根據每個點的 Y 軸值是否大於 0.5,就能預測它是否屬於標籤 1(既然是二元分類,不是 1 就一定是 0)。我們也能藉由內外顏色來觀察哪些資料是預測錯誤的。

事實上,如果對前面程式中的 y_t 做四捨五入,它就會投射到 Y = 1 或 Y = 0,變成一般教材中常見的邏輯斯迴歸圖:

有意思的是,y_t.round() 得到的內容會和 predict(呼叫 model.predict() 的結果)完全一致。而若拿 y_t 和 pred_prob(呼叫 model.predict_proba() 傳回的機率值)來比較,也會發現 y_t 的內容正是 pred_prob 對於標籤 1 的預測機率:

會得到

這證實了我們在程式內求出的邏輯斯函數,跟 scikit-learn 本身算出的是一樣的。

多特徵的邏輯斯函數

現在我們要回到 2 個特徵的問題;用 2 個特徵訓練完模型時,model.coef_[0] 得到的係數就會是 2 個,model.intercept_[0] 則仍是一個。你可以拿它們各別求出 2 條邏輯斯函數,不過這麼一來其實跟前面做的事是一樣的。

反而,我們稍微改一下算法,使得可以用 2 個特徵來直接算出標籤 1 的機率:

就和前面一樣,這個結果和你呼叫 model.predict()、model.predict_proba() 的結果是一模一樣的。

scikit-learn 的邏輯斯迴歸對二元分類會採用 ovr(one vs. rest,一對多)策略,對 3 個以上的標籤則會用 multinomial(即 softmax),除非你用 multi_class 參數來強制指定。我想多元預測的函數求法也是一樣的,不過這就不是這篇文要討論的主題了。

邏輯斯函數的決策邊界

那麼,邏輯斯迴歸的資料預測,對最開始的原始資料會有何影響呢?下面就來在資料當中畫出決策邊界(同樣用係數和截距求出該線的方程式)。比對一下前面的圖表,便不難看出為何有些資料會預測錯誤:

SVM

支援向量機(support vector machine)的分類效果跟邏輯斯迴歸很像,原理卻大不同。SVM 是藉由將資料投射到更高維度的方式,來找出能夠分隔資料的超平面(可想像成馬路的中線),這個超平面兩側到某分類之資料的邊界(margin,人行道邊緣)必須盡量拉大。這個分界可以是線性的,也可以藉由 kernel 函式轉換來求出非線性的超平面 — — 我就點到為止,再講下去就是讓人聽不懂的數學啦。

下面先來看線性版本,畫出超平面的基本原理和前面的邏輯斯迴歸很像,只是多了邊界而已:

至於非線性 SVM,畫法就不太一樣了。基本上這在 scikit-learn 官網上有蠻明確的範例,我只有稍微簡化。簡單來說,就是直接使用模型本身的 decision_function() 來取得超平面跟邊界的函數:

SVM 的 kernel 有幾種可以選擇,預設是 rbf(徑向基函數)。linear 就是前面的線性版本。至於 poly 和 sigmoid 對這份資料的分類效果不佳,所以就不示範了,這裡也不討論各個 kernel 的原理。

繪製支援向量

最後,SVM() 的 support_vectors_ 屬性 — — LinearSVC() 沒有這玩意 — — 會包含一系列座標,就是訓練集中位於邊界內的資料點,你可藉此看看支援向量的視覺化:

你也可以用 model.support_ 取回這些訓練集資料的索引,和前面 KNN 的做法很像。

理論上,所謂的支援向量是指能夠用來定義最大邊界的資料點(這麼做的 SVM 即 hard-margin SVM),但資料分界不夠明確時效果就很差,所以現在一般會使用所謂的 soft-margin SVM,也就是能容忍誤差或「雜訊」。這便是為何有許多點會落在邊界之內。這樣做是必須的,畢竟兩筆資料有一部分為重疊。

對 SVM() 來說,你可以用參數 C 和 gamma 來微調分界方式:C 即為用來控制 soft-margin 的損失函數,設得越高對誤差的容忍就越小,gamma 則代表各別訓練資料點的影響能力,值越大曲線會越曲折。這些知識牽涉到數學,在此就不多介紹啦。

但下面來做個簡單實驗,改用 make_circles() 來產生兩群不重疊、形成圓圈狀的資料,然後改變 C 與 gamma 參數。這麼一來,你就能看到支援向量剛好都落在 margin 上,很類似 hard-margin SVM 的做法:

支援向量在更高維度空間之探討

最後,我們來看看另一個有趣的玩意。前面提到 SVM 會將資料投射到更高維度來尋找更明顯的分界線,但這種投射過程看起來是什麼樣子呢?

下面的程式使用 scikit-learn 的 RBF 函式來換算出訓練集資料點的 Z 軸,好把它們投射在三維座標軸裡,並也把支援向量畫出來:

輸入長度為 N 的資料時,不管有幾個特徵,RBF 函數都會傳回一個 N x N 陣列。我將每一行子陣列的值加總,用這個值來代表 Z 軸。

身為數學苦手,且不管是 scikit-learn 或 scipy 的 RBF 功能,要如何拿來換算資料的說明文件都很少。但如上所見,兩個圈圈在三維空間上下明顯分開來了 — — 如果你從 Z 軸上方直直往下看,就會跟前面的二維圖一樣 。

此外,在三維空間的版本中,可見兩群資料之間最靠近的幾個點就是支援向量。也就是說,SVM 找到的超平面(在三維空間中是個二維曲面)從支援向量之間的位置切過去,實現了分類的目的。

當然,要實際畫出超平面本身就困難多了,畫出來說不定還讓電腦跑得很慢,所以這篇文就做到這裡吧。

下面是把乳癌資料集的兩大特徵套用 RBF 函數,並標出支援向量的結果:

決策樹

決策樹(decision tree)藉由建立樹狀的節點結構來判定資料分類,簡單又有效。但相較於前面的模型,決策樹的圖形化方式就比較不同了。

scikit-learn 自己提供了繪製決策樹的功能(這你可能在網站或書上看過)。此外,為了能在樹的節點中顯示合理的數值跟標籤,下面我們就不使用 PCA 與資料標準化,用完整的 30 個特徵下去訓練:

文字的輸出結果為

留意 export_text() 的 feature_names 參數只吃 list,所以直接給它 ndarray 會出錯,要先轉換一下。

產生的圖則為:

文字版比較單純,但圖片版會正確顯示用來判定的特徵以及標籤/分類的名稱。此外,你也能看到框的顏色反映了對某分類的判定機率。

比較麻煩的是,如果決策樹的層級變多,tree.plot_tree() 畫出來的文字和框就會變得難以閱讀,而這功能也沒有什麼調整空間。這時你可以選擇把它畫成一張更大的圖,並直接輸出到檔案:

隨機森林

隨機森林(random forest)就是用一群決策樹來預測。靠著俗稱的「群眾智慧」現象,隨機森林不僅能進一步提高準確率,還能避免單一決策樹可能的過度配適(overfitting)問題,是集成學習(ensemble learning)的代表之一。

RandomForestClassifier() 模型的 estimators_ 屬性會包含它所有的決策樹(不指定時預設為 100 棵),所以只要把這些樹用 tree.plot_tree 全部畫出來即可。當然,為了能輸出成圖形,這裡我們還是限制了樹的層級跟數量:

ROC

ROC(receiver operating characteristic)曲線也是個很常看到的圖表,代表模型預測時的真陽性率(TPR)和假陽性率(FPR)的關係。簡單來說,這曲線離對角線越遠、畫出的區域(auc 或 area under curve)越大,就代表預測效果越好。

因此,本文最後就來以這個收尾吧,畫一下前面幾個模型的 ROC 曲線。

scikit-learn 裡畫 ROC 最簡單的方式是使用 plot_roc_curve()(ROC 繪圖只適用於二元分類問題),不過有個小問題:這功能會開自己的繪圖視窗。解決辦法是把它嵌進一個子圖表內,這樣才能調整大小。除此以外,plot_roc_curve() 本質上似乎是 plt.plot(),所以能傳同樣的參數進去。

你也能看到 plot_roc_curve() 會產生自己的圖例,算是蠻方便的。

I just like to write weird stuff that have very little to do with my actual work. My normal blog is https://krantasblog.blogspot.com.

I just like to write weird stuff that have very little to do with my actual work. My normal blog is https://krantasblog.blogspot.com.