Pytorch에서 EarlyStopping

EarlyStopping은 특정 평가 지표가 증가·감소하는 현상을 보였을 때, 모델이 over-fitting 되었다고 판단하여 학습을 중단하는 것을 뜻한다. tensorflowtransformers에서는 자체적으로 EarlyStopping을 지원하지만 Pytorch의 경우 직접 구현한 객체를 사용해야 한다.  


EarlyStopping 적용 예시

'학습 시작'을 눌러
실행해 보세요.
Patience:

 


Pytorch로 구현

import torch
import numpy as np 


class EarlyStopping(object):
    def __init__(self, patience=2, save_path="model.pth"):
        self._min_loss = np.inf
        self._patience = patience
        self._path = save_path
        self.__counter = 0

    def should_stop(self, model, loss):
        if loss < self._min_loss:
            self._min_loss = loss
            self.__counter = 0
            torch.save(model.state_dict(), self._path)
        elif loss > self._min_loss:
            self.__counter += 1
            if self.__counter >= self._patience:
                return True
        return False
   
    def load(self, model):
        model.load_state_dict(torch.load(self._path))
        return model
    
    @property
    def counter(self):
        return self.__counter

참고: https://stackoverflow.com/questions/71998978/early-stopping-in-pytorch

stackoverflow의 답변을 참고하여 최종적으로 loss가 가장 낮은 모델이 저장되도록 수정한 객체이다.

patience는 바로 중단하지 않고 대기할 epoch(step)이다. Loss가 0.5313 → 0.5314 → 0.5311와 같이 일시적으로 증가하였다가 다시 감소할 수 있기 때문에 patience를 사용한다.

save_path현재까지 잘 학습되었다고 판단되는 모델을 저장할 경로이며, model학습된 모델이다. 

lossValidation-Loss 값이다. 만약 loss가 아닌 accuracy, f1-score 등 평가 지표를 사용하기 위해서는 부등호 및 _min_loss를 변경해 주어야 한다. 

# 객체 선언
early_stopper = EarlyStopping(patience=3)

# 학습
for epoch in epochs:
    model.train()
    for data in train_loader:
        output = model(**data)
        # 생략...

    val_loss = evaluate(model, criterion, val_loader)

    if early_stopper.should_stop(model, val_loss):
        print(f"EarlyStopping: [Epoch: {epoch - early_stopper.counter}]")
        break

# 학습된 모델 불러오기
model = early_stopper.load(model)

일반적인 pytorch의 학습 과정이다. 한 epoch이 종료된 후 should_stop을 호출하여 학습을 종료할지 결정하고, load저장된 모델을 불러온다.


변형:

class EarlyStopping(object):
    """score로 stopping하기"""
    def __init__(self, patience, save_path, eps):
        self._max_score = -1
        self._patience = patience
        self._path = save_path
        self._eps = eps
        self.__counter = 0

    def should_stop(self, model, score):
        if score > self._max_score:
            self._max_score = score
            self.__counter = 0
            torch.save(model.state_dict(), self._path)
        elif score < self._max_score + self._eps:
            self.__counter += 1
            if self.__counter >= self._patience:
                return True
        return False