딥러닝 모델 학습에서 학습률은 아주 중요하다. 하지만 학습 단계에 따라 최적의 학습률은 계속해서 달라진다. 따라서 scheduler를 사용해 자동을 학습률을 조정하기도 한다. 논문을 읽다 보면 아래와 같은 방법을 쓰기도 한다.
- N번째 iteration까지 Linear Warm-up
- Loss가 수렴하지 않을 때, 학습률을 10으로 나눔
Pytorch는 위 기능을 기본으로 제공하지 않기 때문에 직접 구현해 사용해야 한다.
참고로, warmup은 학습 초기에 학습률을 서서히 증가시키는 방법이다. 초기에 가중치가 급격하게 변하는 현상을 방지한다.
class WarmupScheduler:
"""Warmup learning rate and dynamically adjusts learning rate based on training loss.
:param optimizer: torch optimizer
:param initial_lr: initial learning rate
:param min_lr: minimum learning rate
:param warmup_steps: number of warmup steps
:param decay_factor: decay factor
"""
def __init__(
self, optimizer, initial_lr, min_lr=1e-6, warmup_steps=10, decay_factor=10
):
self.optimizer = optimizer
self.initial_lr = initial_lr
self.min_lr = min_lr
self.warmup_steps = warmup_steps
self.decay_factor = decay_factor
assert self.warmup_steps > 0, "Warmup steps must be greater than 0"
assert self.decay_factor > 1, "Decay factor must be greater than 1"
self.global_step = 0
self.best_loss = float("inf")
# Store initial learning rates
for param_group in self.optimizer.param_groups:
param_group["lr"] = 0 # Start with 0 LR
def step(self, loss):
"""Update learning rate based on current loss."""
self.global_step += 1
if self.global_step <= self.warmup_steps:
# Linear warmup
warmup_lr = self.initial_lr * (self.global_step / self.warmup_steps)
for param_group in self.optimizer.param_groups:
param_group["lr"] = warmup_lr
else:
# Check if loss increased
if loss > self.best_loss:
for param_group in self.optimizer.param_groups:
new_lr = max(param_group["lr"] / self.decay_factor, self.min_lr)
param_group["lr"] = new_lr
self.best_loss = min(self.best_loss, loss)
def get_lr(self):
"""Return current learning rates."""
return [param_group["lr"] for param_group in self.optimizer.param_groups]
위 예시는 첫 warmup_steps epoch 동안 학습률을 0에서 서서히 증가시킨다. 이후에는 loss를 확인해, loss가 수렴하지 않을 때 학습률을 10씩 나눠가며 조정한다.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = WarmupScheduler(optimizer, warmup_steps=1000, initial_lr=1e-3)
# Train
for epoch in range(epochs):
for batch, label in dataloader:
output = model(batch)
loss = compute_loss(output, label)
loss.backward()
optimizer.step()
optimizer.zero_grad()
# Update learning rate based on loss
scheduler.step(loss.item())
print(f"Epoch {epoch}: LR = {scheduler.get_lr()}")