Machine Learning/Graph model

그래프 모델에서의 추론 - 사슬

진성01 2023. 2. 8. 19:06

앞에서 알아본 그래프 모델을 통해 추론 문제를 해결해 볼 것이다. 여기서는 사슬 형태를 가지는 그래프에서의 추론 문제를 볼 것이다. 사슬 모양 그래프는 다음과 같다.

사슬 모양의 비방향성 그래프

위 그래프의 결합 분포는 다음 형태를 띈다.

N개의 노드들이 각각 K개의 상태를 가지는 이산 변수를 표현한다고 가정하자. 이 경우 포텐셜 함수는 K * K 행렬을 구성하게 되므로 총 (N-1)K^2 개의 매개 변수를 가지게 된다.

이제 관측된 값이 하나도 없을 때 p(xn)을 구하는 법을 알아보자. 이는 다음과 같이 xn을 제외한 모든 변수의 합산을 통해 구할 수 있다.

xn을 제외한 모든 변수를 시그마로 합산하였다.

이 경우 결합 분포는 가능한 x값 하나에 대해 하나씩의 숫자들의 집합으로 표현할 수 있다. 각각 K개의 상태를 가지는 변수가 N개이기 때문에 총 가능한 x의 값은 K^N개이다. 따라서 계산에 필요한 비용은 사슬 길이 N에 대해 기하급수적으로 증가하게 된다.

 

따라서 조금 더 효율적인 방법을 생각해보자. 합산과 곱의 순서를 재배치해보자. 일단, xN에 종속적인 유일한 포텐셜함수를 먼저 계산하자.

xN에 종속적인 유일한 포텐셜 함수

위의 포텐셜 함수를 계산하기 위해선 xN-1만 필요하게 된다. 이러한 방식으로 연쇄적으로 계산하면 다음과 같은 식이 도출된다.

포텐셜함수n-1, n으로부터 연쇄적으로 계산하여 사슬 양끝에 도달하였다.

이렇게 재배열한 이유는 계산량을 줄이기 위한 목적이다. 다음의 원리와 비슷하다.

위의식에서 우변은 세 번의 연산을 필요로하지만, 좌변은 두 번의 연산을 필요로한다. 위의 재배열에서 이러한 곱의 분배 성질을 이용한 것이다.

 

이 재배열 식을 바탕으로 주변 확률을 구하는 데 드는 비용을 계산해보자. 총 N-1개의 합산을 시행해야 하는데, 이때 각각의 합산은 K개의 상태에 대한 것이다. 그리고 각 상태는 두 개의 변수에 대한 함수와 연관되어 있다. 예를 들어 x1에 대한 합산은 함수 ψ1,2(x1, x2)만을 필요로 한다. 이 때 함수는 K * K개의 숫자에 대한 테이블에 해당한다. 각각의 x2값에 대해 이 테이블을 x1에 합산해야 하며, 이 과정에서는 O(K^2)의 계산 비용이 든다. 이러한 합산과 곱셈이 N-1개 있으므로 주변 분포 p(xn)을 계산하는 데는 총 O(NK^2)의 비용이 든다. 가장 쉬운 구현법의 경우 사슬 길이에 따라 기하급수적으로 증가했던 반면 이 경우 사슬 길이에 대해 선형적으로 증가하는 것을 알 수 있다. 이는 그래프의 조건부 독립성을 이용하여 계산을 효율적으로 만들 수 있었다.

 

이 계산법의 해석법에 대해 알아보자. 위의 식을 "메시지를 그래프를 따라 전달한다"는 해석이다. 식 8.52로부터 주변 분포 p(xn)을 다음과 같이 나타낼 수 있다.

그림 8.52 참조

µα(xn)을 노드 xn-1에서 xn으로 앞으로 전달하는 메시지로 해석하자. 그리고 반대로 µβ(xn)은 노드 xn+1에서 xn으로 뒤로 전달되는 메시지이다. 이를 재귀적으로 분해하면 위의 식을 다음과 같이 풀어나갈 수 있다.

µα(xn)를 재귀적으로 풀어나가는 과정

µα(xn)에서 포텐셜함수를 하나와 µα(xn-1)로 분해하였다. 이는 재귀적으로 반복되어 마지막엔 결국 다음 함수만 남을 것이다.

따라서 우린 이 함수의 값을 알아낸다면, 위의 풀어나갔던 과정을 다시 거슬러 올라가 식을 모두 계산할 수 있게 된다. 이와 마찬가지로 µβ(xn)도 가능하다.

µβ(xn)도 재귀적으로 풀어나갈 수 있다.

모든 과정은 다음과 같은 그래프로 나타낼 수 있다.

마르코프 연쇄

다음과 같은 그래프를 마르코프 연쇄(Markov chain)라고 한다.

 

이제 xn뿐만 아니라 사슬의 모든 노드 1, ..., N에 대해 주변 분포를 구한다고 해보자. 위의 과정을 N번 반복한다면 총 O(N^2K^2)의 계산 비용이 들 것이다. 하지만 이 방식은 계산을 매우 낭비하는 일이다. 겹치는 연산이 너무 많기 때문이다. 사실 꼭 필요한 연산 과정은 x1부터 xN까지 메시지를 한 번 전달하고, 반대로 xN부터 x1까지 한 번 전달하면 모두 충족시킬 수 있다. 이 과정에서 중간 계산 메시지들을 저장해 놓기만 한다면 모든 메시지를 처리할 수 있기 때문이다.즉, 다시 말해 모든 노드가 앞 방향 혹은 뒤 방향으로 밖에 메시지를 보내지 못하기 때문에 이 메시지만 다 저장해놔도 중복 계산을 하지 않아도 된다는 것이다. 이렇게 된다면 전체 노드에 대해 주변 분포를 구하는데 단일 분포를 구하는 과정보다 두 배 많은 비용이 필요할 뿐이다.