Notice
Recent Posts
Recent Comments
«   2026/01   »
1 2 3
4 5 6 7 8 9 10
11 12 13 14 15 16 17
18 19 20 21 22 23 24
25 26 27 28 29 30 31
Tags
more
Archives
Today
Total
관리 메뉴

No Limitation

분포 거리에 대한 개념적 고찰 (KL Divergence, JS Divergence, Wasserstein Distance) 본문

ML & DL & RL

분포 거리에 대한 개념적 고찰 (KL Divergence, JS Divergence, Wasserstein Distance)

yesungcho 2025. 4. 29. 21:04

본 포스팅은 분포 거리를 근사할 때 많이 활용되는 KL Divergence와 그것의 한계를 극복하기 위한 다양한 거리 측정 방법의 개념을 다루고자 합니다. 

 

특별히 JS Divergence와 Wasserstein Distance의 경우 GAN 포스팅에서도 다루었기 때문에 참고하시면 도움이 될 것 같습니다.

https://yscho.tistory.com/106

 

KL Divergence (Kullback-Leibler Divergence)

KL Divergence를 들어가기 전에 먼저 정보 이론을 이해할 필요가 있다. 데이터의 확률 분포를 $P$ 라고 할 때, 어떤 사건 $x$ 에 대한 발생 확률을 $P(x)$ 라고 한다면, 정보 이론에서는 이 사건을 통해 얻을 수 있는 정보량을 $ - \log P(x) $로 정의한다. $log$ 를 취한 것을 제외하고 음수가 붙는데 이는 사건이 더 자주 발생하는 경우 얻을 수 있는 정보량은 적음을 의미한다. 반대로 사건이 발생할 확률이 적은 경우는 사건이 발생하게 되면 그 사건이 의미하는 바가 크기 때문에 정보량이 더 많아지는 특징이 있다. 따라서 우리는 전체 확률 분포 $P$ 에서 특정 사건 $x$ 들에서 발생하는 평균적인 정보량을 아래와 같이 정의할 수 있다. 그리고 우리는 이것을 엔트로피 (Entropy), 혹은 자가 엔트로피 (Self-entropy) 라고 부른다. 즉, 이 Entropy는 "자기 자신의 데이터를 설명하는 데 드는 평균 정보량" 을 지칭한다.

 

자가 엔트로피 수식

 

그렇다면 만약 우리가 실제 데이터의 확률 분포가 $P$ 라고 할 때, 우리가 어떠한 모델을 통해 예측한 확률 분포를 $Q$ 라고 가정한다면 이들 사이에서는 차이가 발생할 수 있다. 이 때 우리는 다음과 같은 수식을 구축할 수 있다. 

교차 엔트로피 수식

 

즉 정보량을 구성하는 분포가 $Q$ 로 바뀌되, 데이터는 $P$ 에서 온 것이기 때문에 $P$ 에서 온 데이터를 $Q$ 를 통해 정보량을 계산한다. 우리는 이것을 교차 엔트로피 (Cross Entropy)라고 부른다. 즉, 이 Entropy는 "로부터 생성된 데이터를, Q 라는 분포로 설명하려고 할 때 드는 평균 정보량"을 지칭한다. 

 

그래서 이 교차 엔트로피는 분류 문제에서 softmax를 통과하고 나온 확률 값 P, Q를 근사시켜 모델이 예측한 분포가 실제 정답 분포 (보통 분류 문제의 경우 이산 형태의 one-hot 분포) 를 예측하게끔 하는 방법으로 동작을 수행한다. 

 

자 그렇다면 KL-Divergence는 무엇인가?

KL-Divergence 수식


바로 이 교차 엔트로피에서 자가 엔트로피 값을 뺀 값을 의미한다. 즉, 의미론적으로 "모델 $Q$ 를 사용해서 실제 분포 $P$ 를 설명할 때 드는 평균 정보량" 에서 "실제 분포 $P$ 를 자기 자신으로 설명할 때 드는 평균 정보량"을 뺀 값을 의미하며, 이는 "모델이 실제 데이터에 비해 추가로 발생시키는 정보량"을 의미한다. 만약에 Q가 굉장히 P와 유사하다면 모델이 실제 데이터에 비해 추가로 발생시키는 정보량은 거의 없을 것이다. 왜냐하면 P랑 유사할 것이기 때문이다. 하지만 P랑 차이가 있다면 정보량의 차이가 발생할 것이기에, 우리가 만약 Q를 P로 근사하고 싶다면, 바로 이 KL-Divergence를 최소화하게끔 Q를 찾는 것이 학습의 목적이 되는 것이다. 

 

그리고 이 KL-Divergence에서 중요한 특징은 "모델"이 "실제 데이터"에 비해 추가로 발생시키는 정보량을 의미한다고 했을 때, 위 수식이 의미하는 바와 일치한다. 하지만 반면 P와 Q의 위치를 바꾸어서 아래와 같은 구조를 상상해보자.

기존 KL-Divergence 수식
P와 Q의 위치가 변한 KL-Divergence 수식

 

즉 이는 의미론적으로 "실제 데이터"가 "모델이 알고 있는 지식"에 비해 추가로 발생시키는 정보량을 의미한다. 즉, 저 KL-Divergence를 업데이트한다는 것은 실제 데이터의 분포를 모델이 알고 있는 지식으로 분포시킨다는 의미가 된다. 하지만 실제 데이터의 분포는 변하지 않으며 우리가 학습을 시키는 것은 '모델이 알고 있는 지식'이기 때문에, 결국 엉뚱한 최적화를 수행하게 될 수도 있다. 즉, KL Divergence는 교환 법칙이 성립될 수 없음에 주의해야 한다

 

JS Divergence (Jensen-Shannon Divergence)

하지만, 이 KL-Divergence가 갖는 문제점이 하나 있다. 예를 들어 생성자가 내뱉는 생성 분포 $Q$ 에서 출발해서 데이터 분포 $P$ 를 근사시키는 GAN을 학습시킨다고 가정해보자. 이 상황에서 운이 안 좋게 GAN을 초반에 학습할 때 P와 Q가 전혀 다른 분포를 가지는 경우 (지지 집합이 매칭되지 않는 경우, support mismatch) $P(x)$는 양의 값을 가져도 $Q(x)$는 0에 근사한 값을 가질 수 있는데, 이 때 log 안의 수식에서 분모가 0에 가까워지므로 손실이 발산하는 문제가 발생한다. 즉, 이런 경우 학습이 안정적이지 않을 우려가 있어, 이를 보완할 필요가 있었다. JS Divergence는 이런 부분을 보완하는 데에 적합한 손실 함수 중 하나이다. 수식은 아래와 같다.

JS-Divergence 수식


위 수식을 보면 JS-Divergence는 P와 Q의 중간 분포인 M을 사용한다. 그리고 이를 $KL(P||M)$ 과 $KL(M||P)$ 을 각각 계산해서 평균을 취한다. 이 때문에 JS-Divergence는 KL-Divergence와 다르게 대칭성을 가지게 된다. 이는 JS-Divergence가 metric으로 사용하기가 더 편리한 장점을 가짐을 시사한다. 또한 저 loss의 경우 아무리 P와 Q가 극명하게 달라도 upper bound로 $log2$ 값을 가지게 된다. 이해를 위해 둘의 극단적인 케이스를 고려하면, 예를 들어 P(x)=1, Q(x)=0이라고 하면 M(x)=1/2가 된다. 이를 JS 수식에 넣게 되면 JS 값은 $log2$ 가 된다. 즉 아무리 P와 Q가 차이가 많이 나도 손실 값이 발산하지 않게 되며 더 안정적인 학습이 가능하도록 만들어준다. 이것이 JS-Divergence가 사용되는 이유다. 

 

Wasserstein Distance 

하지만 저 upper bound는 다른 문제를 낳게 된다. 예를 들어, 다음과 같은 그림을 생각해보자.

파란색이 실제 데이터의 분포라고 할 때, 빨간색 분포는 generator가 형성한 생성 분포를 의미한다. 이 때 빨간색 분포에서 모여있는 일부 샘플들이 discriminator를 속이는데 성공한다면, 저 나머지 real data의 distribution은 학습에 크게 기여를 안하게 된다. 애초에 손실 값이 bounded되어 있기 때문에 모델은 discriminator를 속이는 데 성공한 샘플만 잘 만들어도 모델이 충분히 수렴하기 때문에 다른 영역의 분포를 근사시킬 이유가 없어지게 된다. 또한 이는 집중된 영역을 제외한 다른 영역에 대해서는 discriminator가 너무도 잘 분별하는 문제가 발생해 Generator가 gradient를 거의 타지 않는 문제가 발생하게 된다. 바로 이 문제를 mode collapse라고 부른다. 이를 수학적으로 설명하면 아래와 같다. 

GAN Loss

 

일단 Generator는 다음과 같은 loss를 줄이려고 하는데 만약에 Discriminator가 완벽하게 동작하게 되면, $P_{G}$에서 오는 D(x)의 경우 0이 되어

GAN Loss에서 Discriminator가 완벽하게 동작하는 경우

 

사실상 Generator가 0이 되어 더 이상 업데이트가 되지 않는 문제가 발생한다.

 

이 문제를 Wasserstein Distance는 critic f라는 개념을 활용해서 이 문제를 풀고자 한다. 개념적으로는 P하고 Q 사이에서 한 분포를 다른 분포로 "옮기기" 위해 필요한 최소 비용을 계산하는 수식이다. 수식은 아래와 같다. 

Wasserstein Distance 본래 수식

 

즉, 여기서 P, Q의 모든 가능한 결합 분포 (joint distribution) 에서 가장 결정적으로 x와 y 사이의 거리를 최소화할 수 있는 하한선 (inf) 을 찾는 것이 바로 WD의 개념이다. 하지만 수학적으로 이 수식을 그대로 사용하는 것은 어려우니, 이를 Kantorovich-Rubinstein Duality 개념을 활용해서 아래와 같은 형태로 수식을 변경할 수 있다. (이 부분은 아직까지 공부를 못해서 나중에 따로 정리를 해야할 거 같음)

Kantorovich-Rubinstein Duality로 인해 변경된 Wasserstein Distance 수식

 

 

결국 저기서 가장 중요한 수식은 바로 $f$와 $Lip_{1}$ 이다. 즉, 기존의 generator의 생성을 판단하는 데에, critic $f$ 함수를 사용해서 실제 데이터 $P$와 Generator의 생성 데이터 $Q$ 간의 비교를 scoring할 수 있는 공간으로 매핑하고 그 공간이 가장 잘 구분될 수 있는 상한선 (sup)을 설정해서, 그 샘플들을 바탕으로 loss를 줄여가는 방법이다. 그리고 이를 위해, 출력의 변화가 입력의 변화보다 더 크게하지 않는 제약식 (1-Lipschitz 제약)을 두어 loss가 mode collapse에 빠지는 것을 방지하였다. 어떻게 그것이 가능했는지 수식으로 디테일을 살펴보자. 

 

일단 우리는 WD를 구하기 위해 f가 최적화가 되어야 한다. 즉, 기존 discriminator처럼 real vs fake를 분류하는 식으로 ( $D(x)$, $1-D(x)$ 형태로 입력을 주지 않고) 하는 것이 아니라, 둘다 $f(x)$, $f(G(x))$ 처럼 $f$ 함수에 묶어서 같은 공간에서 학습시키도록 하는 방법이다. 즉, 이 경우 기존 Discriminator 의 경우의 GAN loss랑 비교하면

GAN의 Discriminator, Generator loss

 

 

이 경우 Discriminator처럼 Generator가 내놓는 것을 Real vs Fake로 분류하는 것이 아니라, 둘의 차이가 어느 정도 있는 공간을 찾아서, 이걸 바탕으로 distance를 계산하고 이를 바탕으로 이 distance를 최소화하는 식으로 동작시킨다. 즉 극단적으로 1-D(x)가 0이 되어 문제가 발생하는 mode collapse 문제를 극복할 수 있는 것이다

 

자 그렇다면 저 $f$가 이해하는 잠재공간이 매우 중요한데, 이 때 이 "둘의 차이가 어느 정도 있는 공간"을 찾는 것이 중요한 부분이 된다. 바로 저 objective function의 $1-Lip$ 제약이랑 $sup$ 부분인데 이 부분을 적용할 수 있는 방법이 바로 gradient penalty라는 방법이다. 

 

우선 저 $1-Lip$ 제약의 개념부터 하나씩 살펴보자

1-Lipschitz 제약

 

즉, 이는 " 입력이 조금 바뀌면 출력도 그만큼만 바뀌어야 한다" 라는 제약으로 $f$를 너무 급격히 튀지 않는 함수로 매핑하고자 한다. 즉, 이는 "부드럽게 변하는 함수를 통해 두 분포의 차이를 최대한 크게 볼 수 있는 방법을 찾는 것"이 WD의 핵심이라고 볼 수 있다.

 

그럼 저 1-Lip 제약을 어떻게 적용할까. 위 수식을 다음과 같은 형태로 바꿀 수 있다.

1-Lipschitz 제약, gradient 형태로 변형 가능

 

즉, 이는 f가 모든 점에 대해 gradient norm이 1이하여야 함을 시사한다. 그러면 이 finding을 마치 loss로 구성을 할 수 있는데 아래와 같이 구성이 가능하다.

GP term
GP 제약 term에 들어갈 real-fake 사이의 interpolation input

 

여기서 중요한 점은 저 GP term을 구성할 때 real과 fake에서 나온 점 중 하나를 가지고 GP term을 만든 것이 아니라, real과 fake의 중간 사이에 있는 interpolate한 점 x_hat을 가지고 GP을 구성한다. 왜냐하면 그 경로상에 있는 점의 gradient가 1 이내가 보장되면 충분하기 때문에, 실제로 interpolation한 다양한 점들을 바탕으로 저 GP term을 구축한다고 한다.

 

따라서 저 GP term이 추가된 최종적인 WGAN의 loss는 아래와 같다.

GP 제약 term을 고려한 최종 WD loss

 

즉, 이를 바탕으로 D가 너무 강해져 수렴이 잘 안되는 문제를 두 분포의 차이를 최대화하도록 잠재공간을 찾는 f를 활용해서 이 문제를 해결하고자 한 것이 바로, WD의 목적이다.