Machine Learning/Learning Method

[논문] Align Representations with Base: A New Approach to Self-Supervised Learning (CVPR 2022)

진성01 2024. 8. 28. 16:24

이 글에서는 2022 CVPR에 게재된 "Align Representations with Base: A New Approach to Self-Supervised Learning" (Shaofeng Zhang et al) 논문을 정리한다.

 

논문 제목을 읽어보면 'A New Approach to Self-Supervised Learning'이라고 해서 새로운 SSL 방법론에 대한 논문이라고 생각할 수 있지만, 정확히는 Positive pairs만을 이용하여 학습하는 contrastive learning(대조 학습) 기법의 새로운 접근법이라고 이해하면 될 것 같다.

 

Contrastive Learning?


contrastive learning에 대해서 간단하게 소개하면, 하나의 이미지에서 각각 다른 augmentation 한 쌍을 생성한 후, 두 샘플은 결국 의미론적으로 같으므로 두 augmentation된 image를 postive pairs로 취급하여 embedding의 유사성(similarity)을 높이는 방향으로 학습을 진행하는 것이다.

 

Background


위에 설명했듯, 이 논문은 contrastive learning에 새로운 방식으로 접근한 연구이다. 따라서 들어가기에 앞서 현재까지 연구되어 온 contrastive learning의 background를 먼저 알아보자.

Negative-requiring methods

Negative-requiring methods는 contrastive learning에서 positive pairs(긍정적 샘플 쌍)와 더불어 negative pairs(부정적 샘플 쌍)을 사용하는 방식을 일컫는다. postive pairs는 similarity를 높이는 방향이라면, negativer pairs(다른 이미지에서 생성된 embedding)는 similarity가 낮아지도록 학습을 진행한다. 이 방식은 dimensional collapse(후술)를 방지하는 대신, 계산 복잡도가 qudratic(O(d^2))인 단점이 있다.

  • SimCLR: 가장 대표적인 contrastive learning 기법이다. postive, negative pairs를 이용한다.
  • Barlow Twins: cross-correlation matrix를 이용하여 loss를 계산하는 방식이다.

Negative-free methods

Negatie-free methods는 negative pairs를 사용하지 않는다. 이로 인해 생길 수 있는 문제인 dimensional collapse는 연구마다 다른 방법을 이용해 해결하였다.

  • BYOL: negative pairs가 없어서 생기는 학습의 불안정성으로 online network와 target network의 도입으로 해결함
  • DINO: distillation 기법을 이용하였으며, cluster separation을 이용해 nagative pairs를 대체하였다.

 

Introduction


논문에서는 기존 contrastive learning의 두 가지 문제점을 제시한다.

 

첫 번째 문제는 Negative requiring methods의 경우, 계산 복잡도가 O(d^2)라는 것이다(quadratic). 예시로 Barlow Twins method를 들 수 있다. Barlow Twin method는 위에서 설명했듯이 correlation matrix를 이용해 loss를 계산한다(위의 식에서 C가 correlation matrix). 따라서 backpropagation때 quadratic complexity를 가지게 된다. 이는 계산상의 비효율성 문제를 야기한다. 


그렇다면 negative-free methods를 사용하면 어떨까?

 

두 번째 문제는 Negative-free methods를 사용할 시 dimensional collapse문제가 야기된다는 것이다. 만약 negative pairs가 없이 positive pairs만을 사용한다면 위의 그림에서와 같이 서로 다른 방향을 가리켜야 하는 두 임베딩이 같은 방향을 가리킬 수 있게 된다. negative requiring의 경우 negative pairs로 이러한 문제를 해결할 수 있지만 nagative-free는 이 문제를 해결하기 어렵다. 

 

이 논문에서는 negative-free methods를 이용하면서도 dimensional collapse문제를 해결할 수 있는 방법을 제안한다.

 

Align Representations with Base


Overview

ARB의 전체 구조

위의 그림이 논문에서 제시한 ARB의 전체 구조이다. 천천히 살펴보자.

  • 먼저 기존의 contrastive learning과 동일하게 하나의 배치에서 서로 다른 두 방식으로 augmentation을 수행한다. 이를 통해 X^A, X^B를 생성한다. 그 후 각각의 X에 동일한 Encoder와 MLP를 통과시켜 일차적으로 임베딩Z를 뽑아낸다.
  • Random Shuffle이라는 부분은 계산량을 줄이기 위해 들어간 부분인데, 이 부분은 모델 전체 구조와는 크게 상관이 없으니 이후에 다루고 여기선 넘어가자.
  • 생성된 임베딩 Z의 closest base를 각각 구하고, 이를 B라 칭한다. 그 후 Z^A는 B^B를 따라가도록, Z^B는 B^A를 따라가도록 loss를 설계한다. 
    • 이 부분이 ARB의 핵심인데, 기존의 contrastive learning의 경우 Z^A와 Z^B가 서로 가까워지도록(positvie pairs) loss가 설계되었다면 ARB는 Z^A와 B^B 즉, Z^A와 Z^B의 가장 가까운 직교 벡터를 따라가도록 loss가 설계되었다는 것이다. 

즉 ARB의 핵심은 Z의 closest base(가장 가까운 직교벡터) B를 구하고, closs align을 통해 학습시킴으로써 dimensional collapse 문제를 해결한다.

 

어떻게 closest base를 구하는지, 그리고 이러한 방식이 어떻게 dimensional collapse문제를 해결하는지 살펴보자.

 

Nearest Orthonormal Basis (NOBS)

Z는 N의 batch size와 d의 feature vector dimension을 갖는 임베딩 batch이고, B0는 이러한 base matrix 즉, 모든 행 벡터들이 서로 내적값이 0이 되는 matrix이다(다시 말해 모든 행 벡터가 직교). 그렇다면 NOBs(Z에서 가장 가까운 base)는 다음 식과 같이 M(Z)로 정의할 수 있다.

 

M(Z)의 정의를 보면, (Z-B0)^2이 가장 작아지게 하는 B0이므로 Z에서 가장 가깝게 하는 조건이라고 할 수 있다. 이 식의 해를 구하면 아래의 식과 같이 정의 할 수 있는데 증명은 논문에 여러 줄로 설명되어 있으니 여기서는 다루지 않는다.

 

결과를 보면, M(Z)는 Z와 ∑(correlation matrix of Z)의 역행렬로 구성된다는 것을 알 수 있고, 이를 행렬분해 하면 다음과 같이 eigenvalue matrix와 eigenvector matrix의 행렬곱으로 표현 될 수 있다(이 부분은 선형대수학에서 배울 수 있다). 

 

Non-full rank cases

※이 section은 선형대수학 지식이 없다면 건너뛰자

 

선형대수학을 열심히 공부했다면 위의 정리에서 의문점을 하나 발견할 수 있다. 바로 ∑의 역행렬이 존재하지 않는다면 M(Z) 즉, closest base를 구하는 게 불가능하다는 것이다.

  • 실제로 ∑는 ZTZ이므로 ∑의 rank는 Z의 rank와 동일하다.
  • 그런데 Z는 N*d의 행렬이고, 따라서 batchsize인 N이 Z의 rank가 된다.
  • 그 말은,  d>N이라면, Z의 full rank는 d가 되고, Z의 rank는 N이므로 Z는 full rank가 아니다!
  • 그러므로 ∑도 full rank가 아니게 된다.
  • ∑가 full-rank가 아니므로 ∑는 역행렬이 존재하지 않는다.

즉 결론은 d>N이면 ∑의 역행렬을 구할 수 없다는 것이다. 

이러한 case를 나타내 주는것이 위의 그림이다. 왼쪽 행렬처럼 full rank가 아닐경우 대각 성분 중 0이 존재하게 된다. 이 경우 역행렬을 구할 수 없는데 논문에서는 λI를 더해주어 psuedo base를 구하는 technique을 사용하였다.

 

ARB Loss

ARB Loss는 결과적으로 다음과 같다. Z^A와 B^B의 행렬곱(즉, similarity of embeddings)이 커질수록 loss값이 낮아지도록 설계되었다. Z^A와 B^B가 같은 방향을 가리킬 때 loss는 가장 낮게 된다. 

 

Why can minimizing LARB avoid collapse?


그렇다면 왜 이런 방식이 dimension collapse를 피할 수 있는지 알아보자.

가장 나쁜 case로 서로 다른 image에서 추출된 embedding(즉, 서로 다른 방향을 가리켜야 하는 두 임베딩 벡터) zi, zj가 동일하다고 가정하자. 이 경우 LARB loss의 gradient는 어떻게 될까?

 

위의 그림을 보면, LARB를 Zi에 대해 미분한 값이 나온다. 마찬가지로 LARB 를 Zj로 미분할 수도 있을 것이다. 

그 후 그 아래 수식처럼 두 미분값을 내적하자. 이 값은 무조건 0일 것이다. 왜냐하면 각 미분값에는 Bi^B, Bj^B가 존재하고 이 둘은 서로 직교 벡터라는 것을 B의 정의를 통해 알고있다(B자체가 closest base니까 당연한 말). 그렇다는 말은, Zi의 gradient와 Zj의 gradient가 서로 정 반대방향을 가리키고, 이 둘의 합은 0이 아니라는 것이므로, 그 다음 parameters update때는 두 벡터를 직교방향으로 떨어뜨려 놓으려고 할 것이다!

 

이러한 방식으로 LARB의 dimensional collapse를 막을 수 있다는 것을 증명한다.

 

Reduce computation complexity


위의 overview에서 설명할때 skip했던 Random shuffle을 이제 살펴보자.

 

Nearest Orthogonal Basis (NOB)를 구하는데는 O(d^3)의 time complexity를 가진다. 이는 back propagation에서 일어나는 연산이 아님에도 불구하고 꽤나 큰 연산량이다. 따라서 논문에서는 차원 크기 d의 벡터를 p개로 자른후 각각의 NOB를 구하는 방식을 채택한다. 위의 그림은 p=3인 경우를 나타내고 있다. 이 방식을 적용할 경우 연산량을 O((d/p)^3)으로 줄일 수 있다.

 

그러나 단점도 존재한다. p개로 분할하지 않을 경우 완벽하게 서로 orthogonal한 벡터를 찾을 수 있지만 p개로 자르면 다시 합쳤을 때 분할된 차원 끼리만 orthogonal을 보장할 수 있다. 위의 그림에서 보면, G1A부분의 3개의 벡터 (초록, 파랑, 주황)은 서로 orthogonal을 보장할 수 있지만 모두 합쳐진 전체 벡터는 orthogonal이 아니다. 

 

이 문제를 일부 완화하기 위해 차원을 random으로 shuffle하는 전략을 선택한다. 만약 shuffle하지 않고 고정된다면 G1A, G2A 두 집단의 각도는 항상 orthogonal하지 않은 방향으로 학습될 것이다. 하지만 학습때마다 shuffle을 하면 모든 dimension측면에서 orthogonal하도록 보완할 수 있다.

 

결론적으로 ARB의 computational complexity는 다음과 같아진다.

 

Barlow Twins(negative requiring)과 ARB의 computational complexity 비교

 

Experiments


ARB가 downstream task에서 좋은 성능을 나타내는 것을 알 수 있다.
Baseline과의 비교. embedding distance와 embedding std를 눈여겨볼 만 하다.
Other methods와의 비교. Batch size와 output dimension에 대해 강건하다.