EarlyStopping은 특정 평가 지표가 증가·감소하는 현상을 보였을 때, 모델이 over-fitting 되었다고 판단하여 학습을 중단하는 것을 뜻한다. tensorflow나 transformers에서는 자체적으로 EarlyStopping을 지원하지만 Pytorch의 경우 직접 구현한 객체를 사용해야 한다.
class EarlyStopping(object):
"""Stop training when loss does not decrease
:param patience: number of epochs to wait before stopping
:param path_to_save: path to save the best model
"""
def __init__(self, patience, path_to_save):
self._min_loss = float("inf")
self._patience = patience
self._path = path_to_save
self.__check_point = None
self.__counter = 0
def should_stop(self, loss, model=None, epoch=None):
"""Check if training should stop.
If 'model' or 'epoch' is None, the checkpoint will not be saved.
:param loss: current loss
:param model: model to save as checkpoint
:param epoch: current epoch to mark check point
:return: True if training should stop, False otherwise
"""
if loss < self._min_loss:
self._min_loss = loss
self.__counter = 0
if (model is not None) and (epoch is not None):
self.__check_point = epoch
torch.save(model.state_dict(), self._path)
elif loss > self._min_loss:
self.__counter += 1
if self.__counter == self._patience:
return True
return False
def load(self, weights_only=True):
"""Load best model weights
:param weights_only: load only weights (default: True)
:return: best model weights
"""
return torch.load(self._path, weights_only=weights_only)
@property
def check_point(self):
"""Return check point index
:return: check point index
"""
if self.__check_point is None:
raise ValueError("No check point is saved!")
return self.__check_point
참고: https://stackoverflow.com/questions/71998978/early-stopping-in-pytorch
stackoverflow의 답변을 참고하여 최종적으로 loss가 가장 낮은 모델이 저장되도록 수정한 객체이다.
patience는 바로 중단하지 않고 대기할 epoch(step)이다. Loss가 0.5313 → 0.5314 → 0.5311와 같이 일시적으로 증가하였다가 다시 감소할 수 있기 때문에 patience를 사용한다.
loss는 Validation-Loss 값이다. 만약 loss가 아닌 accuracy, f1-score 등 평가 지표를 사용하기 위해서는 부등호 및 _min_loss를 변경해 주어야 한다.
# 객체 선언
early_stopper = EarlyStopping(patience=3)
# 학습
for epoch in epochs:
model.train()
for data in train_loader:
output = model(**data)
# 생략...
val_loss = evaluate(model, criterion, val_loader)
if early_stopper.should_stop(val_loss, model, epoch):
print(f"EarlyStopping: [Epoch: {early_stopper.check_point}]")
break
# 학습된 모델 가중치 불러오기
weights = early_stopper.load()
model.load_state_dict(weights)
일반적인 pytorch의 학습 과정이다. 한 epoch이 종료된 후 should_stop을 호출하여 학습을 종료할지 결정하고, load로 저장된 모델 가중치를 불러온다.
변형:
class EarlyStopping(object):
"""score로 stopping하기"""
def __init__(self, patience, save_path, eps):
self._max_score = -1
self._patience = patience
self._path = save_path
self._eps = eps
self.__counter = 0
def should_stop(self, model, score):
if score > self._max_score:
self._max_score = score
self.__counter = 0
torch.save(model.state_dict(), self._path)
elif score < self._max_score + self._eps:
self.__counter += 1
if self.__counter >= self._patience:
return True
return False