Pytorch에서 EarlyStopping

EarlyStopping은 특정 평가 지표가 증가·감소하는 현상을 보였을 때, 모델이 over-fitting 되었다고 판단하여 학습을 중단하는 것을 뜻한다. tensorflowtransformers에서는 자체적으로 EarlyStopping을 지원하지만 Pytorch의 경우 직접 구현한 객체를 사용해야 한다. 

class EarlyStopping(object):
    """Stop training when loss does not decrease

    :param patience: number of epochs to wait before stopping
    :param path_to_save: path to save the best model
    """

    def __init__(self, patience, path_to_save):
        self._min_loss = float("inf")
        self._patience = patience
        self._path = path_to_save
        self.__check_point = None
        self.__counter = 0

    def should_stop(self, loss, model=None, epoch=None):
        """Check if training should stop.
        If 'model' or 'epoch' is None, the checkpoint will not be saved.

        :param loss: current loss
        :param model: model to save as checkpoint
        :param epoch: current epoch to mark check point
        :return: True if training should stop, False otherwise
        """
        if loss < self._min_loss:
            self._min_loss = loss
            self.__counter = 0
            if (model is not None) and (epoch is not None):
                self.__check_point = epoch
                torch.save(model.state_dict(), self._path)
        elif loss > self._min_loss:
            self.__counter += 1
            if self.__counter == self._patience:
                return True
        return False

    def load(self, weights_only=True):
        """Load best model weights

        :param weights_only: load only weights (default: True)
        :return: best model weights
        """
        return torch.load(self._path, weights_only=weights_only)

    @property
    def check_point(self):
        """Return check point index

        :return: check point index
        """
        if self.__check_point is None:
            raise ValueError("No check point is saved!")
        return self.__check_point

참고: https://stackoverflow.com/questions/71998978/early-stopping-in-pytorch

stackoverflow의 답변을 참고하여 최종적으로 loss가 가장 낮은 모델이 저장되도록 수정한 객체이다.

patience는 바로 중단하지 않고 대기할 epoch(step)이다. Loss가 0.5313 → 0.5314 → 0.5311와 같이 일시적으로 증가하였다가 다시 감소할 수 있기 때문에 patience를 사용한다.

lossValidation-Loss 값이다. 만약 loss가 아닌 accuracy, f1-score 등 평가 지표를 사용하기 위해서는 부등호 및 _min_loss를 변경해 주어야 한다. 

# 객체 선언
early_stopper = EarlyStopping(patience=3)

# 학습
for epoch in epochs:
    model.train()
    for data in train_loader:
        output = model(**data)
        # 생략...

    val_loss = evaluate(model, criterion, val_loader)

    if early_stopper.should_stop(val_loss, model, epoch):
        print(f"EarlyStopping: [Epoch: {early_stopper.check_point}]")
        break

# 학습된 모델 가중치 불러오기
weights = early_stopper.load()
model.load_state_dict(weights)

일반적인 pytorch의 학습 과정이다. 한 epoch이 종료된 후 should_stop을 호출하여 학습을 종료할지 결정하고, load저장된 모델 가중치를 불러온다.


변형:

class EarlyStopping(object):
    """score로 stopping하기"""
    def __init__(self, patience, save_path, eps):
        self._max_score = -1
        self._patience = patience
        self._path = save_path
        self._eps = eps
        self.__counter = 0

    def should_stop(self, model, score):
        if score > self._max_score:
            self._max_score = score
            self.__counter = 0
            torch.save(model.state_dict(), self._path)
        elif score < self._max_score + self._eps:
            self.__counter += 1
            if self.__counter >= self._patience:
                return True
        return False