DeepLearning/Computer Vision

[Contrastive Learning] Contrastive Learning이란

yooj_lee 2022. 5. 8. 00:01
300x250


오늘은 contrastive learning에 대해 정리를 해보겠습니다. 처음에 facenet에서 triplet loss를 접하고 흥미 있는 분야라고만 생각해왔는데 self-supervised learning 분야에서 많이 발전을 이룬 것 같습니다. 해당 포스트에서 정리한 내용은 survey 논문인 "Contrastive Rerpesentation Learning: A Framework and Review"를 읽고 정리한 내용입니다. 하기한 내용에 질문 혹은 오류가 있을 경우 댓글 부탁드립니다.


Contrastive Representation Learning: A Framework and Review

포스트의 목차는 다음과 같습니다. 이번 포스팅에서는 2. Contrastive Learning Architecture Framework 부분까지만 정리하도록 하겠습니다.

  1. Contrastive Learning
  2. Contrastive Learning Architecture Framework
  3. Contrastive Loss
  4. Architectures of Contrastive Learning

Contrastive Learning

Contrastive Learning (이하, CRL)이란 입력 샘플 간의 비교를 통해 학습을 하는 것으로 볼 수 있습니다. CRL의 경우에는 self-supervised learning(자기지도학습)에 사용되는 접근법 중 하나 (물론, 지도학습의 맥락에서 CRL이 수행되기도 합니다)로, 사전에 정답 데이터를 구축하지 않는 판별 모델이라고 할 수 있습니다. (판별 모델: Discriminative Approach)

따라서, 데이터 구축 비용이 들지 않음과 동시에 학습 과정에 있어서 보다 용이하다는 장점을 가져가게 됩니다. 이러한 데이터 구축 비용 이외에도 label이 없기 때문에,

1. 보다 일반적인 feature representation
2. 새로운 class가 들어와도 대응이 가능

하다는 장점이 추가적으로 존재합니다.

이후, classification 등 다양한 downstream task에 대해서 네트워크를 fine-tuning시키는 방향으로 활용하곤 합니다. 

feature를 학습한 이후의 활용

 


Contrastive Representation Learning

Contrastive Representation Learning을 설명하기 앞서 Representation Learning을 설명하도록 하겠습니다. Representation Learning(표현 학습)은 크게 두 가지 접근법이 존재합니다. 하나는 생성모델의 측면, 나머지 하나는 판별모델의 측면입니다. 

각각의 접근법은 장단이 존재합니다.

생성모델로 데이터의 표현을 학습하는 경우, 비지도 학습이기 때문에 데이터 구축 비용이 낮다는 장점이 있습니다. 또한 저차원 표현을 학습하는 데에 있어 목적함수가 보다 일반적이라는 장점이 있습니다.

판별모델의 경우에는 계산 비용이 적고, 비교적 학습이 용이하다는 장점이 있습니다. 대부분 라벨링된 데이터에 의존하기 때문에 데이터 구축 비용이 크다는 장점이 있습니다. 비용이라 함은 단순히 금전적 비용을 떠나 시간적 비용, 민감한 정보 유출, 데이터 라벨링 작업자의 편견 개입 등 보다 포괄적인 의미입니다. (추가적으로 판별 모델의 경우에는 데이터가 속한 클래스를 판별하는 목적을 지녔기 때문에, 보다 지엽적인 목적함수라고 할 수 있습니다. 실제로 판별모델을 학습하는 과정 중 학습되는 representation은 texture에 보다 집중을 한다는 주장을 하는 논문 또한 발표되었습니다)

Contrastive Learning도 representation learning을 수행하기 위한 하나의 방법입니다. Contrastive Learning은 앞서 말했듯이 입력 샘플 간의 비교를 통해 표현을 학습하게 됩니다. 따라서, 목적은 심플합니다. 학습된 표현 공간 상에서 "비슷한" 데이터는 가깝게, "다른" 데이터는 멀게 존재하도록 표현 공간을 학습하는 것입니다.

같은 이미지에서 나온 이미지 패치는 positive, 다른 이미지에서 나온 이미지 패치는 negative

 

조금 자세히 말씀을 드리자면, 여러 입력쌍에 대해서 유사도(유사한지 아닌지)를 라벨로 판별 모델을 학습합니다. 이때 유사함의 여부는 데이터 자체로부터 정의될 수 있습니다 (즉, self-supervised learning이 가능합니다).

Contrastive 방법의 경우, 다른 task로 fine-tuning을 수행할 때에 모델 구조 수정 없이 이루어질 수 있다는 점에서 훨씬 간편합니다.

위에 기반하여 Contrastive Learning은 representation 학습에 있어서 간단하면서도 효과적이라는 장점이 있습니다. 

 


Example: Instance Discrimination Task

CRL 아키텍처의 하나인 Instance Discrimination Task에 대해 말씀드리겠습니다 (self-supervised learning에서 pre-text task가 Instance Discrimination이라고 생각하시면 쉽습니다). 이 Instance Discrimination Task는 Unsupervised Feature Learning via Non-Parametric Instance Discrimination(Zhirong Wu et al., 2018) (이하, InstDisc)에서 처음 제안되었습니다. 

Instance Discrimination Task의 경우, 다음과 같이 네트워크가 대략적으로 구성되고. 하나의 sample에서 두 가지의 view가 생성됨을 알 수 있습니다. 이때, 같은 이미지에서 나온(같은 인덱스에 위치한) pair는 무조건 positive pair이고, 그를 제외한 다른 인덱스 내의 view와는 모두 negative입니다. pair의 구성은 다음과 같이 이루어집니다.

contrastive learning의 pair 구성

 

instance discrimination을 위한 contrastive learning의 architecture는 다음과 같이 구성됩니다.

전체적인 구조

 

1) Data Augmentation을 통한 input pair 생성

 

쉽게 말해 같은 이미지에서 생성이 된다면 positive pair이고, pair 내 두 이미지가 다른 이미지로부터 나왔다면 negative pair 입니다. (위의 contrastive의 pair 구성 이미지를 확인하시는 게 가장 직관적입니다)

positive pair를 구성할 때는 원본 이미지에서 image transformation을 적용한 augmented image를 생성하여 pair를 구성하게 됩니다. 이때, augmentation (transformation)은 random crop, resizing, blur, color distortion, perspective distortion 등을 포함합니다. 

다양한 image transformation의 적용

 

2) Generating Representation (= Feature Extraction)

입력 이미지 쌍을 생성했다면 해당 이미지 쌍으로 representation 학습, 즉 특징을 추출해야 합니다. contrastive learning network 내 이와 같은 부분을 feature encoder $e$로 칭하고, $e$는 다음과 같이 특징벡터 $v$를 출력하는 함수로 표현될 수 있습니다.

$$e(.) → \mathbf{v}=e(\mathbf{x}), \mathbf{v}\in \mathbb{R}^d $$

이때, Encoder의 구조는 특정되지 않고 어떤 backbone network든 사용 가능합니다. 참고로 InstDisc에서는 ResNet 18을 사용했습니다.

 

3) Projection Head

projection head $h(.)$에서는 encoder에서 얻은 특징 벡터 $\mathbf{v}$를 더 작은 차원으로 줄여주는 작업을 수행합니다 (간혹 여러 representation을 결합하는 방식으로 projection을 수행하기도 합니다. 이 경우에는 contextualization head라고 참고 논문에서는 지칭하고 있습니다. 하지만, InstDisc에서의 projection head는 2048차원의 특징 벡터 $\mathbf{v}$를 128차원의 metric embedding $\mathbf{z}$로 projection해주는 용도, 즉 dimensionality reduction으로 사용되고 있습니다). projection head $h$는 다음과 같이 metric embedding $\mathbf{z}$를 출력하는 함수로 표현될 수 있습니다.

$$ h(.) → \mathbf{z}=h(\mathbf{v}), \mathbf{z} \in \mathbb{R}^{d'}, d' < d $$

projection head 같은 경우엔 간단한 MLP (Multi-layer perceptrons) 구조를 취합니다. 이후, unit vector로 정규화해줍니다. 

cf.) metric embedding
contrastive loss는 기본적으로 각 pair의 유사도를 측정하게 됩니다. (이러한 유사도는 거리가 될 수도 있고, pair 가 공유하는 entropy로 계산이 될 수도 있습니다만 이는 이후의 포스팅에서 더 자세히 다루도록 하겠습니다.) 유사도는 metric으로 나타낼 수 있고, 이에 loss에 input으로 들어가는 z를 metric embedding이라고 표현하는 것입니다. (따라서 projection head 내에서는 feature representation space에서 metric representation space로 projection했다고 볼 수 있겠습니다)

 

4) Loss 계산

앞서 contrastive learning의 목적(objective)은 positive pair의 embedding은 가깝게, negative pair의 embdding은 멀게하는 것이라고 말씀 드렸습니다. Loss는 이러한 objective를 직접적으로 수행하는 부분입니다. 이를 Contrastive Loss로 부르겠습니다. Contrastive Loss와 같은 경우에는 InfoNCE, NT-Xent 등이 많이 사용되고 있습니다.

  • i번째 입력쌍에 대한 Loss의 일반항

$$\mathcal{L} = - log \frac{exp(\mathbf{z}_i^T\mathbf{z'}_i/\tau)}{\sum_{j=0}^K exp(\mathbf{z}_i^T\mathbf{z'}_j/\tau)}$$

  • $\mathbf{z}_i^T\mathbf{z'}_i$: 두 벡터 $\mathbf{z}, \mathbf{z'}$의 내적. 이때, $\mathbf{z'}$는 $\mathbf{z}$의 변형(transformation;augmented $\mathbf{z}$)
  •  $\tau$는 하이퍼 파라미터로, 두 벡터 간의 내적이 전체 loss에 어느 정도 영향을 미치는지 조절
  • 분모의 sum은 $\mathbf{z}_i$에 대해서 하나의 positive pair와 $K$개의 negative pair에 대해서 계산된다.

 

5) 학습 완료 후

네트워크가 학습이 완료된 후에는 projection head 이후부터는 버리고 encoder만 transfer learning을 위한 feature extractor로 사용하게 됩니다. 이후, predictor를 뒤에 결합하여 새로운 task에 적용할 수 있도록 fine-tuning을 거치게 됩니다 (전형적인 pretext-downstream task 구조)

이 예시에서 보신 것과 같이 contrastive learning framework에서는 어떤 augmentation을 적용하느냐가 모델 성능에 큰 영향을 미치게 됩니다.

색상, 형태, edge 등 low-level의 시각적 단서에만 네트워크가 의존하여 표현을 학습하지 않도록, 이미지 전체가 담고 있는 추상적인 의미(image semantic)를 잘 파악할 수 있도록 다양한, 그러나 image semantic을 변화시키지 않는 augmentation을 적용하여 입력 이미지 페어를 구성하는 것이 중요할 것입니다.

 


용어 정리

앞서 instance discrimination task 사례에서 살펴봤듯이 contrastive learning framework에서는 input pair 생성, encoder, projection head, loss 등 다양한 모듈이 있습니다. 이러한 모듈에 덧붙여 survey paper에서 사용하는 용어 (혹은 concept)에 대해 정리하고 이 포스트를 마치겠습니다.

1. query, key

기준 벡터와 비교 벡터에 대해 query, key 비유를 사용했습니다 (사전에 비유를 했다고 생각하면 쉽습니다). 앞서 벡터라는 표현은 이미지, representation, metric embedding 이 세가지를 모두 아우르는 표현입니다.

 

2. Similarity Distribution

입력 샘플쌍의 결합 분포로 다음과 같이 표현할 수 있습니다.

$$p^+(q,k^+)$$

각 key는 similarity distribution (query와 비슷한 sample들의 분포)에서 추출하면 $k=k^+$가 되고, dissimilarity distribution(query와 비슷하지 않은 sample들의 분포)에서 추출할 경우 $k=k^-$가 됩니다. 

실제 학습 진행을 하면서 distribution을 직접 가정한다기보다는 어떤 식으로 input pair를 구성할 것인지를 결정하는 부분이라고 보면 됩니다. 앞서 InstDisc에서는 data augmentation을 통해서 같은 이미지에서 augmented되었다면 positive, 다른 이미지에서 augmented되었다면 negative라고 정의한 것처럼, 어떤 pair를 positive로 구성할 것인지 혹은 어떤 pair를 negative로 구성할 것인지에 대한 것입니다.

 

3. Model

파라미터(네트워크의 경우 가중치)가 존재하는 모든 모듈의 총체로, $f(x;\theta): \mathcal{X} → \mathbb{R}^{|\mathcal{Z}|}$로 표현될 수 있습니다. (input space $\mathcal{X}$에서 metric embedding $|\mathcal{Z}|$ 차원의 실수 공간인 $\mathbb{R}^{|\mathcal{Z}|}$로 매핑시키는 함수 $f$)

survey paper에서는 encoder와 transform head로 분리하여 설명하고 있습니다.

 

4. Encoder

입력 view에서 representation vector로의 mapping을 학습하는 부분입니다. 분류기의 입력으로 넣어주는 등 Encoder가 학습한 representation을 다른 모델의 입력으로 사용하기도 합니다 (이 경우에는 encoder 부분을 freeze하는 방식일 것입니다).
혹은 encoder 위에 layer를 쌓아 fine-tuning 시에 활용하기도 합니다.

 

5. Transform Head

feature embedding $\mathbf{v}$를 metric embedding $\mathbf{z}$로 변환하는 모듈입니다. 앞서 말했던 바와 같이 여러 representation을 결합하거나 contrastive loss에 넣기 전에 조금 더 차원을 줄이는 용도로 사용합니다.

 

6. Contrastive Loss

query, positive key, negative key로 구성된 metric embedding 쌍 $\{(\mathbf{z}, \mathbf{z}^+),(\mathbf{z}, \mathbf{z}^-)\}$에 적용됩니다.

embedding 간의 유사도를 측정하고 positive pair의 유사도는 높게 negative pair의 유사도는 낮게 강제하는 역할을 하게 됩니다. 유사도를 측정하는 부분(scoring function)과 실제 loss의 형태 (cross entropy, distance-based loss 등), 이 두 가지 부분으로 크게 나눌 수 있습니다.

위와 같이 학습된 representation은 positive pairs의 입력 공간 내에 존재하는 trivial noise에는 invariant해야 하며, negative pair 간의 차이를 설명하기 위한 covariant representation은 잘 반영하고 있어야 합니다.

 


결론

이번에는 contrastive learning은 무엇인지, 기존 representation learning 방법론에 대비하여 어떤 차이 혹은 공통점이 있는지, 그 장점은 무엇인지 등에 알아보고 Instance Discrimination Task를 푸는 사례를 통해 contrastive learning의 concept에 대해 간략하게 알아보았습니다. 다음에는 Contrastive Loss에 대해 정리를 해보겠습니다. 

 

Reference

Contrastive Representation Learning: A Framework and Review (PH Le-Khac et al., 2020)
Unsupervised Feature Learning via Non-Parametric Instance Discrimination (Zhirong Wu et al., 2018)

 

 

300x250