[深度學習] keras的EarlyStopping使用與技巧

Early Stopping是什麼

具體EarlyStopping的使用請參考官方文檔源代碼
EarlyStopping是Callbacks的一種,callbacks用於指定在每個epoch開始和結束的時候進行哪種特定操作。Callbacks中有一些設置好的接口,可以直接使用,如’acc’, 'val_acc’, ’loss’ 和 ’val_loss’等等。
EarlyStopping則是用於提前停止訓練的callbacks。具體地,可以達到當訓練集上的loss不在減小(即減小的程度小於某個閾值)的時候停止繼續訓練。
 

爲什麼要用

爲了獲得性能良好的神經網絡,網絡定型過程中需要進行許多關於所用設置(超參數)的決策。超參數之一是定型週期(epoch)的數量:亦即應當完整遍歷數據集多少次(一次爲一個epoch)?如果epoch數量太少,網絡有可能發生欠擬合(即對於定型數據的學習不夠充分);如果epoch數量太多,則有可能發生過擬合(即網絡對定型數據中的「噪聲」而非信號擬合)。

早停法旨在解決epoch數量需要手動設置的問題。它也可以被視爲一種能夠避免網絡發生過擬合的正則化方法(與L1/L2權重衰減和丟棄法類似)。

根本原因就是因爲繼續訓練會導致測試集上的準確率下降。
那繼續訓練導致測試準確率下降的原因猜測可能是1. 過擬合 2. 學習率過大導致不收斂 3. 使用正則項的時候,Loss的減少可能不是因爲準確率增加導致的,而是因爲權重大小的降低。

原理

  • 將數據分爲訓練集和驗證集
  • 每個epoch結束後(或每N個epoch後): 在驗證集上獲取測試結果,隨着epoch的增加,如果在驗證集上發現測試誤差上升,則停止訓練;
  • 將停止之後的權重作爲網絡的最終參數。

這種做法很符合直觀感受,因爲精度都不再提高了,在繼續訓練也是無益的,只會提高訓練的時間。那麼該做法的一個重點便是怎樣才認爲驗證集精度不再提高了呢?並不是說驗證集精度一降下來便認爲不再提高了,因爲可能經過這個Epoch後,精度降低了,但是隨後的Epoch又讓精度又上去了,所以不能根據一兩次的連續降低就判斷不再提高。一般的做法是,在訓練的過程中,記錄到目前爲止最好的驗證集精度,當連續10次Epoch(或者更多次)沒達到最佳精度時,則可以認爲精度不再提高了。

直觀理解

Early Stopping

最優模型是在垂直虛線的時間點保存下來的模型,即處理測試集時準確率最高的模型。

爲什麼能減小過擬合

當還未在神經網絡運行太多迭代過程的時候,w參數接近於0,因爲隨機初始化w值的時候,它的值是較小的隨機值。當你開始迭代過程,w的值會變得越來越大。到後面時,w的值已經變得十分大了。所以early stopping要做的就是在中間點停止迭代過程。我們將會得到一箇中等大小的w參數,會得到與L2正則化相似的結果,選擇了w參數較小的神經網絡。

Early Stopping的優缺點

優點:只運行一次梯度下降,我們就可以找出w的較小值,中間值和較大值。而無需嘗試L2正則化超級參數lambda的很多值。

缺點:不能獨立地處理以上兩個問題,使得要考慮的東西變得複雜。舉例如下:

沒有采取不同的方式來解決優化損失函數和降低方差這兩個問題,而是用一種方法同時解決兩個問題 ,結果就是要考慮的東西變得更復雜。之所以不能獨立地處理,因爲如果你停止了優化代價函數,你可能會發現代價函數的值不夠小,同時你又不希望過擬合。

EarlyStopping的使用與技巧

一般是在model.fit函數中調用callbacks,fit函數中有一個參數爲callbacks。注意這裏需要輸入的是list類型的數據,所以通常情況只用EarlyStopping的話也要是[EarlyStopping()]

EarlyStopping的參數:

  • monitor: 監控的數據接口,有’acc’,’val_acc’,’loss’,’val_loss’等等。正常情況下如果有驗證集,就用’val_acc’或者’val_loss’。但是因爲筆者用的是5折交叉驗證,沒有單設驗證集,所以只能用’acc’了。
  • min_delta:增大或減小的閾值,只有大於這個部分纔算作improvement。這個值的大小取決於monitor,也反映了你的容忍程度。例如筆者的monitor是’acc’,同時其變化範圍在70%-90%之間,所以對於小於0.01%的變化不關心。加上觀察到訓練過程中存在抖動的情況(即先下降後上升),所以適當增大容忍程度,最終設爲0.003%。
  • patience:能夠容忍多少個epoch內都沒有improvement。這個設置其實是在抖動和真正的準確率下降之間做tradeoff。如果patience設的大,那麼最終得到的準確率要略低於模型可以達到的最高準確率。如果patience設的小,那麼模型很可能在前期抖動,還在全圖搜索的階段就停止了,準確率一般很差。patience的大小和learning rate直接相關。在learning rate設定的情況下,前期先訓練幾次觀察抖動的epoch number,比其稍大些設置patience。在learning rate變化的情況下,建議要略小於最大的抖動epoch number。筆者在引入EarlyStopping之前就已經得到可以接受的結果了,EarlyStopping算是錦上添花,所以patience設的比較高,設爲抖動epoch number的最大值。
  • mode: 就’auto’, ‘min’, ‘,max’三個可能。如果知道是要上升還是下降,建議設置一下。筆者的monitor是’acc’,所以mode=’max’。

min_delta和patience都和「避免模型停止在抖動過程中」有關係,所以調節的時候需要互相協調。通常情況下,min_delta降低,那麼patience可以適當減少;min_delta增加,那麼patience需要適當延長;反之亦然。

 

class RocAucMetricCallback(keras.callbacks.Callback):
    def __init__(self, predict_batch_size=1024):
        super(RocAucMetricCallback, self).__init__()
        self.predict_batch_size = predict_batch_size

    def on_batch_begin(self, batch, logs={}):
        pass

    def on_batch_end(self, batch, logs={}):
        pass

    def on_train_begin(self, logs={}):
        if not ('val_roc_auc' in self.params['metrics']):
            self.params['metrics'].append('val_roc_auc')

    def on_train_end(self, logs={}):
        pass

    def on_epoch_begin(self, epoch, logs={}):
        pass

    def on_epoch_end(self, epoch, logs={}):
        logs['roc_auc'] = float('-inf')
        if (self.validation_data):
            logs['roc_auc'] = roc_auc_score(self.validation_data[1],
                                            self.model.predict(self.validation_data[0],
                                                               batch_size=self.predict_batch_size))
            print('ROC_AUC - epoch:%d - score:%.6f' % (epoch + 1, logs['roc_auc']))
my_callbacks = [
        RocAucMetricCallback(),  # include it before EarlyStopping!
        EarlyStopping(monitor='roc_auc', patience=20, verbose=2, mode='max')
    ]

    mlp.fit(X_train_pre, y_train_pre,
            batch_size=512,
            epochs=500,
            class_weight="auto",
            callbacks=my_callbacks,
            validation_data=(X_train_pre_val, y_train_pre_val))

 

擴充

如果不用early stopping降低過擬合,另一種方法就是L2正則化,但需嘗試L2正則化超級參數λ的很多值,個人更傾向於使用L2正則化,嘗試許多不同的λ值。