Pytorch Warmup + Scheduler

딥러닝 모델 학습에서 학습률은 아주 중요하다. 하지만 학습 단계에 따라 최적의 학습률은 계속해서 달라진다. 따라서 scheduler를 사용해 자동을 학습률을 조정하기도 한다. 논문을 읽다 보면 아래와 같은 방법을 쓰기도 한다.

  • N번째 iteration까지 Linear Warm-up
  • Loss가 수렴하지 않을 때, 학습률을 10으로 나눔

Pytorch는 위 기능을 기본으로 제공하지 않기 때문에 직접 구현해 사용해야 한다.

참고로, warmup은 학습 초기에 학습률을 서서히 증가시키는 방법이다. 초기에 가중치가 급격하게 변하는 현상을 방지한다.

class WarmupScheduler:
    """Warmup learning rate and dynamically adjusts learning rate based on training loss.

    :param optimizer: torch optimizer
    :param initial_lr: initial learning rate
    :param min_lr: minimum learning rate
    :param warmup_steps: number of warmup steps
    :param decay_factor: decay factor
    """

    def __init__(
        self, optimizer, initial_lr, min_lr=1e-6, warmup_steps=10, decay_factor=10
    ):
        self.optimizer = optimizer
        self.initial_lr = initial_lr
        self.min_lr = min_lr
        self.warmup_steps = warmup_steps
        self.decay_factor = decay_factor

        assert self.warmup_steps > 0, "Warmup steps must be greater than 0"
        assert self.decay_factor > 1, "Decay factor must be greater than 1"

        self.global_step = 0
        self.best_loss = float("inf")

        # Store initial learning rates
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = 0  # Start with 0 LR

    def step(self, loss):
        """Update learning rate based on current loss."""
        self.global_step += 1

        if self.global_step <= self.warmup_steps:
            # Linear warmup
            warmup_lr = self.initial_lr * (self.global_step / self.warmup_steps)
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = warmup_lr
        else:
            # Check if loss increased
            if loss > self.best_loss:
                for param_group in self.optimizer.param_groups:
                    new_lr = max(param_group["lr"] / self.decay_factor, self.min_lr)
                    param_group["lr"] = new_lr
            self.best_loss = min(self.best_loss, loss)

    def get_lr(self):
        """Return current learning rates."""
        return [param_group["lr"] for param_group in self.optimizer.param_groups]

위 예시는 첫 warmup_steps epoch 동안 학습률을 0에서 서서히 증가시킨다. 이후에는 loss를 확인해, loss가 수렴하지 않을 때 학습률을 10씩 나눠가며 조정한다.

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = WarmupScheduler(optimizer, warmup_steps=1000, initial_lr=1e-3)

# Train
for epoch in range(epochs):
    for batch, label in dataloader:
        output = model(batch)
        loss = compute_loss(output, label)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Update learning rate based on loss
        scheduler.step(loss.item())

    print(f"Epoch {epoch}: LR = {scheduler.get_lr()}")