- 논문: arXiv
- 공식 구현: Pytorch-vision
- 분석 코드: Github
본문에 L000
으로 적힌 링크는 줄번호로, 클릭하면 Pytorch에서 어떻게 구현되어 있는지 확인할 수 있다.
Abstract
Transformer는 자연어 처리 분야에서 활발히 사용되고 있지만, 비전(vision) 문제에 적용된 경우는 제한적이다. 우리는 이미지 조각을 순수한 transformer에 입력해 분류 문제를 풀었다. Vision Transformer(ViT)는 CNN과 비교해 SOTA를 달성했으며, 더 적은 연산 비용이 든다.
Introduction
Self-attention 구조의 transformer가 자연어 처리에서 좋은 성능을 보이고 있지만, 비전 분야는 여전히 CNN이 우세하다. 이로 인해 ResNet 기반의 모델이 SOTA를 보이고 있다.
우리는 자연어 처리에 영감을 받아 기본 transformer에 이미지를 넣어봤다. 이미지는 조각으로 나누어져 일련의 선형 임베딩으로 입력된다. 이미지 조각은 자연어 처리에서 단어 토큰과 같이 다루어진다.
Transformer는 중간 사이즈의 데이터를 학습했을 때 ResNet보다 낮은 정확도를 보이는데, CNN과 달리 inductive bias가 부족하기 때문으로 보인다 (translation equivariance, locality 등). 따라서 충분한 데이터가 없다면 쉽게 일반화되지 않는다.
하지만 큰 데이터셋을 학습할 때는 Vision Transformer(ViT)가 좋은 성능을 보인다. 다음은 데이터셋 별 모델 정확도이다.
- ImageNet: 88.55%
- ImageNet-ReaL: 90.72%
- CIFAR-100: 94.55%
- VTAB(19-task): 77.63%
Method
Vision Transformer (ViT)
Transformer는 일련의 1D token embedding을 입력으로 받는다. 우리는 이미지를 일련의 2D patch로 나누어 사용한다.
Transformer는 정해진 크기의 latent vector를 가지기 때문에 이미지 patch가 정해진 차원으로 매핑될 수 있도록 한다.
$$z_0 = [x_{class};x^1_pE;...;x_p^NE]+E_{pos}$$
BERT와 마찬가지로 [class]
토큰은 학습 가능한 임베딩 벡터($x_{class}$)로 encoder를 거쳐 출력으로 나간다. Classification Head는 1-layer MLP로 구현한다 (L243).
CLS(class) 토큰은 첫 번째 임베딩 벡터로 학습 가능한 랜덤한 값으로 초기화된다 (L220). 이 토큰은 학습 과정에서 encoder 내 모든 이미지 조각의 정보를 반영하며, 이미지를 대표하는 값을 갖게 된다. 이후 encoder 출력으로 나가 분류 문제를 푸는데 활용한다 (L301).
Position embedding은 patch에 더해진다. 학습 가능한 1D 임베딩을 사용하며, 2D-aware 방식과 큰 성능 차이를 발견하지 못했다.
Encoder는 multihead self-attention과 MLP block으로 만들어진다. 정규화를 모든 블록 전에 추가하며 모든 블록 뒤에 residual connection(L115)을 적용한다.
(encoder_layer): EncoderBlock(
(ln): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(self_attention): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
(ln): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): MLPBlock(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Linear(in_features=3072, out_features=768, bias=True)
)
)
ViT는 CNN처럼 이미지에 특화된 inductive bias가 없다. 따라서 position embedding을 이용해 위치 정보를 조정하며, patch의 2차원 위치 정보를 처음부터 학습해야 한다.
대안으로 CNN의 feature map을 입력으로 사용하는 방법이 있다 (L213). 이러한 방법을 hybrid
라고 표현한다. 만약에 1x1 필터를 사용하면 이미지를 feature map 차원으로 flatten하는 과정이 된다.
Fine-tuning and Higher Resolution
우리는 ViT를 큰 데이터셋에 사전 학습시키고, 작은 downstream 문제에 fine-tune 했다. 이때 사전 학습된 prediction head를 지우고 0으로 초기화한 feedforward layer를 추가했다. 고해상도 이미지를 처리할 때도 patch 크기를 유지했으며, 시퀀스 길이는 길어진다. 하지만 transformer는 고정된 길이를 입력 받기 때문에 사전 학습된 position embedding에 2D interpolation을 적용해 사용한다.
Experiments
Setup
ViT는 BERT 기본 설정을 활용한다. 또한 크기에 따라 다음과 같은 표기법을 사용한다. 예: ViT-L/16 = "Large" variant with 16 x 16 patch. Patch 크기가 작을수록 많은 연산을 수행한다.
Baseline CNN으로 ResNet을 변형해 사용하며, 이 모델을 BiT
라고 표기한다.
학습에 사용한 설정은 다음과 같다.
- Adam($\beta_1=0.9$, $\beta_2=0.999$)
- batch size: 4096
- weight decay: 0.1
- linear learning-rate warmup
Fine-tuning은 다음과 같다.
- SGD + momentum
- batch size: 512
Comparison to State Of The Art
작은 모델인 ViT-L/16이 BiT-L을 앞선다. 심지어 이전 SOTA보다 연산량도 적다.
Pre-training Data Requirements
데이터 크기가 얼마나 중요할까? 작은 데이터(ImageNet)를 학습한 ViT-Large는 ViT-Base보다 낮은 성적을 보인다. 큰 데이터(JFT-300M)를 학습했을 때 큰 모델이 좋은 성능을 보였다. 데이터가 작을 때 BiT CNN이 ViT보다 좋은 성적을 보이지만, 데이터가 커지면 그 반대가 된다.
작은 데이터에 대해 ViT는 ResNet보다 쉽게 overfit 되는 경향이 있다. 이를 통해 convolutional inductive bias는 작은 데이터를 학습하는데 유리하지만, 충분히 큰 데이터는 직접적으로 패턴을 분석하는 것이 유리하다는 사실을 추론할 수 있다.
Scaling Study
- ViT는 성능과 비용 측면에서 ResNet을 압도한다. ViT는 연산 비용이 약 2 ~ 4배 정도 적다.
- 데이터가 작을 때 Hybrid가 약간 더 좋은 성능을 보인다. 하지만 데이터가 커지면 차이가 없어진다.
- 아직 ViT는 포화(saturate) 상태가 아니기 때문에 후속 연구가 이어질 수 있다. (모델을 키우면 성능도 커질 것으로 기대한다.)
Inspecting Vision Transformer
첫 레이어는 이미지를 저차원으로 매핑시킨다. 위 이미지는 학습된 필터 중 PCA를 통해 찾아낸 주요 28개 필터 모습이다. 이미지 patch에서 구조를 찾아내기 위한 모양으로 보인다.
이후 position embedding이 더해진다. 가까운 patch는 유사한 position embedding을 보인다. 위 이미지는 patch와 position embedding 간의 유사도를 2차원으로 나타낸다.
Attention 가중치를 바탕으로 어느 정도 깊이(거리)의 네트워크에서 전반적인 정보를 수집해 내는지 확인했다. 여기서 "attention distance"는 CNN의 receptive field 크기와 같다. 몇몇 head는 초기에 대부분의 정보를 잡아내기도 했다. 다른 head는 지속적으로 작은 attention distance를 보였다. 이렇게 강한 localized attention은 hybrid model에서 적게 나타났다. 이는 CNN이 지역적인 정보를 찾기 때문에 attention head에서 지역적인 패턴을 찾을 필요가 없기 때문으로 보인다. 다시 말해, CNN은 지역적인 정보를, Attention은 넓은 범위의 정보를 찾는데 유리하다고 볼 수 있다. 이러한 정보를 바탕으로 분류에 필요한 이미지 부분을 찾아낸다.
Attention 가중치를 이미지에 투영한 예시다. 강아지의 윤곽(귀, 앞발 등)에 강한 가중치를 주어 중요도가 높은 정보로 판단한다. 반면 뒤에 사람은 낮은 가중치를 준다. 따라서 분류 문제를 풀 때, 강아지가 있는 부분은 강하게, 사람이 있는 부분은 약하게 반영된다.
Self-supervision
BERT를 참고해 self-supervision을 위한 masked patch prediction을 수행했다. ViT-B/16을 기준으로, ImageNet을 이용해 바닥부터 학습하는 경우보다 2% 성능 향상이 있었지만, supervised pre-training보다는 4% 뒤쳐졌다.
Conclusion
Transformer를 이미지 인식에 바로 적용해봤다. Vision Transformer는 이미지 분류에서 SOTA를 뛰어 넘었으며, 상대적으로 비용이 적게 든다.
하지만 여전히 문제가 남아있다.
- ViT를 detection, segmentation 등 다른 문제에 적용
- self-supervised pre-training 방법 탐구
- 성능 향상을 위한 ViT 모델 크기 키우기