EarlyStopping은 특정 평가 지표가 증가·감소하는 현상을 보였을 때, 모델이 over-fitting 되었다고 판단하여 학습을 중단하는 것을 뜻한다. tensorflow나 transformers에서는 자체적으로 EarlyStopping을 지원하지만 Pytorch의 경우 직접 구현한 객체를 사용해야 한다.
EarlyStopping 적용 예시
'학습 시작'을 눌러
실행해 보세요.
실행해 보세요.
Patience:
Pytorch로 구현
import torch
import numpy as np
class EarlyStopping(object):
def __init__(self, patience=2, save_path="model.pth"):
self._min_loss = np.inf
self._patience = patience
self._path = save_path
self.__counter = 0
def should_stop(self, model, loss):
if loss < self._min_loss:
self._min_loss = loss
self.__counter = 0
torch.save(model.state_dict(), self._path)
elif loss > self._min_loss:
self.__counter += 1
if self.__counter >= self._patience:
return True
return False
def load(self, model):
model.load_state_dict(torch.load(self._path))
return model
@property
def counter(self):
return self.__counter
참고: https://stackoverflow.com/questions/71998978/early-stopping-in-pytorch
stackoverflow의 답변을 참고하여 최종적으로 loss가 가장 낮은 모델이 저장되도록 수정한 객체이다.
patience는 바로 중단하지 않고 대기할 epoch(step)이다. Loss가 0.5313 → 0.5314 → 0.5311와 같이 일시적으로 증가하였다가 다시 감소할 수 있기 때문에 patience를 사용한다.
save_path는 현재까지 잘 학습되었다고 판단되는 모델을 저장할 경로이며, model은 학습된 모델이다.
loss는 Validation-Loss 값이다. 만약 loss가 아닌 accuracy, f1-score 등 평가 지표를 사용하기 위해서는 부등호 및 _min_loss를 변경해 주어야 한다.
# 객체 선언
early_stopper = EarlyStopping(patience=3)
# 학습
for epoch in epochs:
model.train()
for data in train_loader:
output = model(**data)
# 생략...
val_loss = evaluate(model, criterion, val_loader)
if early_stopper.should_stop(model, val_loss):
print(f"EarlyStopping: [Epoch: {epoch - early_stopper.counter}]")
break
# 학습된 모델 불러오기
model = early_stopper.load(model)
일반적인 pytorch의 학습 과정이다. 한 epoch이 종료된 후 should_stop을 호출하여 학습을 종료할지 결정하고, load로 저장된 모델을 불러온다.
변형:
class EarlyStopping(object):
"""score로 stopping하기"""
def __init__(self, patience, save_path, eps):
self._max_score = -1
self._patience = patience
self._path = save_path
self._eps = eps
self.__counter = 0
def should_stop(self, model, score):
if score > self._max_score:
self._max_score = score
self.__counter = 0
torch.save(model.state_dict(), self._path)
elif score < self._max_score + self._eps:
self.__counter += 1
if self.__counter >= self._patience:
return True
return False