본문 바로가기

AI/논문리뷰

[논문리뷰] FixMatch: Simplifiying Semi-Supervised Learning with Consistency and Confidence

논문 링크 : https://arxiv.org/abs/2001.07685

 

FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence

Semi-supervised learning (SSL) provides an effective means of leveraging unlabeled data to improve a model's performance. In this paper, we demonstrate the power of a simple combination of two common SSL methods: consistency regularization and pseudo-label

arxiv.org

 

해당 논문은 Semi-supervised learning에 대한 method 중 하나인 FixMatch에 대한 논문이다. Google Reserach에서 집필했으며, NeurIPS 2020에 실렸다. 당시에 SOTA 급의 performance를 보여주었으며, method가 매우 simple하다는 장점이 있다.

 

SSL에서 unlabel을 학습하는데 자주 사용되는 방법은 크게 두가지로, arificial label을 이용한 pseudo-labeling과 consistency loss가 있다. consistency loss는 unlabeled data의 raw data와 augmentation한 데이터의 softmax probability값의 차이를 줄여주는 방법으로 학습을 진행하는 것이다.

 

-Method

 

Method의 기본적인 idea는 위에서 설명한 두 가지 방법을 동시에 진행하는 것이다. 위의 그림과 같이 먼저 unlabel 데이터에 대해 weak agumentation과 strong agumentation을 취해준다. weak agumentation은 기본적인 filp, rotation등의 augmentation을 사용하였고, strong agumentation은 사람도 잘 알아보기 어렵게 강하게 aguemntation을 취한것인데, 여기서는 RandAugmentation과 CTAgumentation 이라는 기법을 사용하였다.

 

그 다음, agumentation 된 두 이미지를 동일 한 Model에 forward시킨다. 그러면, softmax를 통과한 predction value값이 output으로 나올 것이고, weak agumentation을 통과한 output에서 가장 probability가 높은 값의 class를 pesudo-labeling해준다. (one-hot-vector로 만들어준다고 생각하면 된다.) 기존의 UDA나 consistency에서 사용했던 sharpening 기법과 동일하게 엔트로피를 낮춰주는 행위와 동일하다고 한다. 그렇게 만든 pseudo-label을 strong augmentation의 output으로 나온 값과 CrossEntropy연산을 하여 Loss term을 둔것이 FixMatch이다. 

 

각 프로세스를 수식으로 살펴보면 다음과 같다.

 

supervised loss (labeled data)

 

unsupvervised loss (unlabeled data)

알파 : weak agumentation, A : strong agumentation, H: CrossEntropy

 

첫번째 수식은 labeled data를 학습할때 사용하는 loss term이다. pb는 label one-hot vector이고, pm은 weak aguementation한 prediction probabilty이다. 기존의 supervised와 동일하게 one-hot vector와 model을 통과한 probability에 대한 CrossEtropy를 loss term으로 사용한다.

 

두번 째 수식은 unlabeled data를 학습할때 사용하는 loss term으로 1로 표시된 함수는 indicator function이다. 함수 인자를 보면 qb의 max값을 취하는데, qb는 unlabel 데이터의 weak agument를 취하여 모델을 통과시킨 output 값이다. 따라서 해당 output probability value에서 가장 높은 값의 label을 선택하는 함수라고 할 수 있다. 옆의 타우는 threshold 값으로, output의 high probabilty가 특정 threshold를 넘어야 pseudo labeling을 하도록 설계한다. 보통 threshold 값은 0.8 ~ 0.95 값을 사용한다고 알고있다.

 

그렇게 구한 pesudo label인 one-hot vector가 q헷b이다. 즉 해당 loss term은 pesudo label한 vector와 strong agument output값의 CrossEtropy를 구한다.

 

최종 final loss는 supervised loss와 unsupervised loss를 더한 위와 같은 값을 사용하는데, 이때 람다는 0~1 사이 값으로, 일반적인 SSL에서는 처음에는 0의 값을 가졌다가 학습을 하면서 점점 값을 높여주는 방식으로 조절한다. 하지만, FixMatch에서는 람다 값을 따로 조절할 필요가 없다. 이유는 결국 람다 값은 unlabeled의 영향력을 결정하는 역할을 하는 것인데, FixMatch는 unsupervised loss에 있는 threshold값이 그 역할을 하기 때문이라고 해석했다. 예를 들어 학습 초기에는 모델이 불안정하기 때문에 unlabel data에 대해서는 대부분 max(qb)값이 threshold 값을 넘지 못할 것이다. 하지만 학습을 하면할수록 threshold값을 넘는 unlabeled data가 많아질 것이기 때문에 학습을 함에따라 람다 값을 따로 조절할 필요가 없는 것이다.

 

- Augmentation in FixMatch

 해당 논문에서는 Strong Augmentation과 Weak Augmentation을 사용하며, Weak Augmentation으로 사용하는 기법의 예시로는 standarad flip and shif (horizontally, vertically)를 사용했다. RandAgument, CTAugment, Cutout 등을 사용했다.

 또한 추가적으로 Mixup, adverserial perturbation과 같은 다양한 Augmentation을 Strong Augmentation으로 적용할 수 있다는 점이 해당 method의 또다른 장점이라고 한다.

 

- Experiment

아래와 같은 Table로 실험결과가 나타났다.

 실험결과를 보면, CIFAR-10 전체적으로 최근 성능이 가장 좋은 ReMixMatch의 성능을 이겼으며, CIFAR-10에서 SOTA를 달성하였다. 또한 label data가 40개밖에 없는 극한의 상황에서도 잘 working하는 것을 실험을 통해 알 수 있다. CIFAR-100에서는 FixMatch의 성능이 더 안좋은 경우가 발생하였지만, 저자는 ReMixMatch보다 FixMatch가 알고리즘 적으로 훨씬 더 simple하다는 점을 어필한다. (알고리즘의 simple하여 상대적으로 hyperparameter가 적다.)

 일반적으로 FixMatch에서 Strong Agument로 RA를 사용한 경우가 성능이 좋은 경우가 대부분이었지만, 상황에 따라 CTA의 성능이 더 좋은 경우도 있었다. 

 

- Ablation Study

 hyperparameter인 Threshold는 unlabeled data의 quality와 quantity를 결정하며, 둘은 trade off 관계이다. 즉, Threshold값이 높으면, 기준을 충족하지 못하는 unlabeled data의 수가 많아질 것이기 때문에 양은 줄어들지만, 해당 조건을 통과한 unlabeled data는 높은 질의 정보를 가질것임을 직관적으로 이해할 수 있다.

 위의 설명처럼 Sharpening 기법은 Pseudo label의 soft version이라고 한다. 위의 예시와 같이 UDA, MixMatch, ReMixMatch 등에서 Sharpening 기법을 사용하며, 이는 class간의 entropy를 줄이는 과정이므로, pesudo label 처럼 entropy를 최소로 만들지는 않지만 이러한 행위의 soft version이라고 할 수 있는 것이다.

 위의 (a)와 (b) 두 그래프가 있다. 빨간색 점선은 FixMatch의 lower bound 성능이다. (a)를 통해서 threshold 값이 커질수록 Error rate가 작아지는 경향성을 확인할 수 있다. 즉, threshold 값이 커지는 것이 좋다는 말이다. 다시말해 Quantity보다 Quality가 중요하다는 것을 알 수 있다.

 (b)에서는 shapening에서 사용하는 temperture 값을 여러값으로 조정하여 실험 결과를 확인한다. temperture, threshold값을 바꾸면서 실험한 결과 어떤 값이라도 lower bound인 (pseudo label을 사용한)FixMatch의 성능을 이길 수 없었다. 즉, Pesudo Label이 Sharpening보다 더 좋다는 것을 의미한다.

 

이상 논문 리뷰를 마칩니다 :)