본문 바로가기

네이버 부스트캠프 AI Tech

[CLIP] Learning Transferable Visual Models From Natural Language Supervision

CutMix로 날먹하기에 살짝 양심에 찔려, 다른 논문 한편을 추가로 준비하고자 한다.

 

CLIP Contrastive Language-Image Pre-treaining

 

들어가기 앞서- zero-shot이란?

모델이 학습 과정에서 배우지 않은 작업을 수행하는 것. →General한 task 수행

 

[ 배경 ]

이미지에 대한 날것의 텍스트로 부터 학습하는 것은 어려운 일이다. → pre-training된 task로 이미지에 맞는 caption을 예측하고, pre-training이후 visual concept에 학습한다. 

 

NLP는 웹에 있는 수많은 사람들이 라벨링한 데이터셋이 존재한다. 하지만 Vision의 경우에는 web-scale이 아닌, ImageNet과 같이 대량으로 라벨링된 데이터가 표준이다. →이미지도 NL처럼 웹에서 가져올 수는 없을 까?

 

그동안 NL을 Image표현에 사용하려는 많은 시도가 있었으나, zero-shot setting에서는 문제가 있었다. 또, 성능이 좋을지라도, 1000개에서 18291개의 클래스로 섬세하게 모델을 디자인 했기 때문에 좋은 성능을 낼 수 있었다. 대부분 softmax를 사용하며, dynamic output에 대한 방법으로는 부족했다. soft-max는 유연성을 감소시키고, zero-shot의 한계를 가져왔다.

 

[ 기존 연구와 차이점 ]

Image representations의 양의 차이. CLIP은 400 million pairs를 만들었다. 

OCR, geo-localization, action recognition 등 많은 일을 할 수 있다.

또한, 30개의 존재하는 데이터셋을 zero-shot으로 실험한 결과, 해당 데이터셋으로 학습된 모델과 경쟁할 수준이었다.

 

[ 접근 방법 ] - Natural Language Supervision

연구의 햇김은, NL에서 supervision을 통한 perception을 이용했다. "gold label"과 같이 라벨링이 필요한 이미지와 달리, NL은 인터넷에 양도 많기 때문에, 잠재적인 단어들에 대한 잠재적인 strength가 있었다. 이 NL이, zero-shot transfer가 유연하게 동작하도록 연결시켜주었다. 

 

- 개인적인 이해 - 

CLIP의 방법은 이전의 기억을 끄집어 내자면, "The photo of ?" 와 같은 방식으로 query를 날리면, 그에 해당하는 percentage를 return해 줬던것으로 기억한다. CLIP에서 NL을 사용하면 좋은 점에 대해서는, 자연어에서 얻을 수 있는 방대한 양의 데이터와, 또 그 단어들 사이의 연관성까지 고려할 수 있기 때문에 zero-shot transfer가 가능한 것이 아닌가 하는 생각이 든다.

 

[ 접근 방법 ] - Creating a Sufficiently Lager Dataset

MS-COCO나 YFCC100M과 같은 데이터셋이 있었으나, 영어로 Title이 있거나 descriptions만 있는 이미지를 추린 결과 ImageNet과 비슷한 양의 사진들이 나왔다.

NLP는 인터넷의 방대한 양의 데이터로부터 왔기 때문에, 이 데이터셋으로는 조금 부족했다. 따라서 인터네셍서 400 million의 (Image, text) pair을 만들었다. 

 

[ 접근 방법 ] - Selecting an Efficient Pre-Training Method

초기 접근: image CNN과 text transformer 동시에 학습. → transformer based language model이 zero-shot ImageNet classification에 약함을 확인. → 이유는 각 이미지에 대한 정확한 다음 단어를 예측하고자 했기 때문. 이를 해결코자, bag-of-words 인코딩을 통해, 학습 효율을 증진시켰다.

 

- 개인적인 이해 -

"The photo of" 와 같이 단어를 묶어, 이미지에 대한 가장 정확한 단어를 계속 예측하는 것이 아닌, 전체적인 이미지를 표현할 수 있는 단어를 뱉도록 NL encoder를 학습시킨것이 아닌가 생각한다.

N개의 (image, text) pair의 batch들을 가지고 학습을 진행한다.

CLIP은 image encoder과 text encoder에서 나온 multi-modal embedding space를 이용해 학습을 진행한다.

먼저 NxN 개의 가능한 pairing들을 확인한다.

embeding 결과가 서로 일치하는 N개의 pair들의 cosine 유사도는 최대화 하고,

일치하지 않는 N^2 - N 개의 pair들의 cosine 유사도는 최소화한다.

 

대용량의 데이터셋을 사용했기 때문에, over-fitting은 주된 문제가 아니다.

또, 사전 weight를 사용하지 않고 처음부터 학습시켰다.

마지막으로, embedding space로 projection할 때 non-linear projection을 사용하지 않고, linear projection을 사용했다. →non-linear과 linear의 큰 차이를 못 느꼈기 대문.

 

[ 접근 방법 ] -Choosing and Scaling a Model

Image Encoder로 두개의 구조를 고려.

1. ResNet50 → 널리 성능이 증명되었기 때문, 구조에서 ResNet=D, rect-2 blur pooling, global average pooling layer를 attention pooling mechanism으로 바꾸는 등, 꽤 많은 변화를 주었다.

2. ViT → 구조를 조금만 변경. 추가적인 transformer 전 patch와 position embedding에 normalization layer를 추가하고, 조금 다른 initialization scheme을 사용했다.

 

Text Encoder

Transformer → 8개의 attention head가 있는 모델을 base. 계산 효율을 위해, max sequence length를 76으로 잡았다. Masked self-attention가 text encoder에 사용되었다.

 

이전 연구: 규모가 큰 모델의 경우 모델의 width나 depth를 증가시켰다.

CLIP: ResNet의 경우 width, depth, resolution 모두 증가시켰다. → 확장시킨다면 성능이 좋다는것을 확인

Text encoder의 경우 ResNet의 width에 맞춰 width만 확장시켯다. → text encoder의 capacity에 CLIP의 성능이 크게 변하지 않음을 확인

 

[ 접근 방법 ] - Training

5개의 ResNet과 3개의 ViT사용

ResNet: ResNet-50, ResNet-101, EfficientNet-style model 3가지(ResNet의 연산량의 4x, 16x 64x -의역)

ViT: ViT-B/32, Vit-B/16, ViT-L/14

epoch: 32 epoch.

optimizer: Adam - weight decay regularization, gain이나 bias가 아닌 모든 weight에 적용.

decay: cosine schedule로 learning rate 조절

Initial hyper-parameter: grid searches, random search, manual tuning on baseline ResNet-50 model ( for 1 epoch )

이후 heuristically 적용됨.

Training 가속화와 메모리 절약을 위한 방법

minibatch size: 32,768 

gradient checkpointing, half-precision Adam statics, half-precision stochastically rounded text encoder weights

 

가장 큰 ResNet 모델인 RN50x64의 경우 학습에 18일 소요 (592대 V100 GPU)

가장 큰 ViT 모델의 경우 12일이 걸림(256대 V100 GPU)

ViT-L/14에서 성능을 향삭시키기 위해, 336 pixel resolution으로 1 epoch 추가로 학습.