Pytorch로 시작하는 딥러닝 입문(07-02. 장단기 메모리(Long Short-Term Memory: LSTM))

2024. 6. 23. 14:13딥러닝 모델: 파이토치

♣ 바닐라 RNN(기본 RNN)의 한계

바닐라 RNN은 비교적 짧은 시퀀스에 대해서만 효과를 보인다. time step이 길어질수록 앞의 정보가 뒤로 충분히 전달되지 못하는 현상이 발생한다. 위의 그림은 첫 번째 입력값인 x1이 뒤로 갈수록 색이 점점 옅어지는 것으로 정보량이 손실되어 가는 것을 의미한다. 

 

가장 중요한 정보가 시점의 앞쪽에 위치한다면 바닐라 RNN에서는 엉뚱한 출력값을 내놓을 수 있다. 이를 장기 의존성 문제(the problem of Long-Term Dependencies)라고 한다. 

 

 

♣ 바닐라 RNN 내부 탐색하기

 

위 그림은 바닐라 RNN의 내부 구조이다(편향 b는 생략).  

 

바닐라 RNN은 xt와 ht-1 이라는 두 개의 입력이 각각의 가중치와 곱해져서 메모리 셀의 입력이 된다. 이를 하이퍼볼릭탄젠트 함수의 입력으로 사용하고, 이 값은 은닉층의 출력인 은닉 상태가 된다. 

 

 

♣ LSTM

위 그림은 LSTM의 전체적인 내부 모습을 보여준다. LSTM은 은닉층의 메모리셀에 입력 게이트, 망각 게이트, 출력 게이트를 추가하여 불필요한 기억을 지우고 기억해야 할 것들을 정한다. 즉 은닉 상태를 계산하는 식이 바닐라 RNN보다 복잡해졌으며 셀 상태(cell state)라는 값을 추가했다. 위 그림에서는 t시점의 셀 상태를 Ct로 표현한다. LSTM은 바닐라 RNN과 비교하여 긴 시퀀스의 입력을 처리하는 데 탁월한 성능을 보인다. 

 

 

셀 상태는 위의 그림에서 왼쪽에서 오른쪽으로 가는 굵은 선이다. 이전 시점의 셀 상태가 다음 시점의 셀 상태를 구하기 위한 입력으로 사용된다. 

 

은닉 상태값과 셀 상태값을 구하기 위해서 새로 추가 된 3개의 게이트를 사용한다. 각 게이트는 삭제 게이트, 입력 게이트, 출력 게이트라고 부르며 이 3개의 게이트에는 공통적으로 시그모이드 함수가 존재한다. 시그모이드 함수를 지나면 0과 1사이의 값이 나오게 되는데 이 값들을 가지고 게이트를 조절한다. 

 

 

입력 게이트

입력 게이트는 현재 정보를 기억하기 위한 게이트이다. 현재 시점 t의 x값과 입력 게이트로 이어지는 가중치 Wxi를 곱한 값과 이전 시점 t-1의 은닉 상태가 입력 게이트로 이어지는 가중치 Whi를 곱한 값을 더하여 시그모이드 함수를 지난다. 이를 it라고 한다. 

 

그리고 현재 시점 t의 x값과 입력 게이트로 이어지는 가중치 Wxg를 곱한 값과 이전 시점 t-1의 은닉 상태가 입력 게이트로 이어지는 가중치 Whg를 곱한 값을 더하여 하이퍼볼릭탄젠트 함수를 지난다. 이를 gt라고 한다. 

 

시그모이드 함수를 지나 0과 1 사이의 값과 하이퍼볼릭탄젠트 함수를 지나 -1과 1 사이의 값 두 개가 나오게 된다. 이 두 값을 이용하여 이번에 선택된 기억할 정보의 양을 정한다. 

 

 

삭제 게이트

 

기억을 삭제하기 위한 게이트이다. 현재 시점 t의 x값과 이전 시점 t-1의 은닉 상태가 시그모이드 함수를 지나게 된다. 시그모이드 함수를 지나면 0과 1 사이의 값이 나오는데, 이 값이 곧 삭제 과정을 거친 정보의 양이다. 0에 가까울수록 정보가 많이 삭제된 것이고 1에 가까울수록 정보를 온전히 기억한 것이다. 이를 가지고 셀 상태를 구하게 된다. 

 

 

셀 상태(장기 상태)

 

셀 상태 Ct를 장기 상태라고도 부른다. 셀 상태를 구하는 방법을 알아보자. 

 

입력 게이트에서 구한 it, gt 두 값에 대해 원소별 곱을 진행한다. 같은 크기의 두 행렬이 있을 때 같은 위치의 성분끼리 곱하는 것이다. 여기서는 식으로 ○ 로 표현한다. 이 값이 이번에 선택된 기억할 값이다. 

 

입력 게이트에서 선택된 기억을 삭제 게이트의 결과값과 더한다. 이 값을 현재 시점 t의 셀 상태라고 하며, 이 값은 다음 t+1 시점의 LSTM 셀로 넘겨진다. 

 

만약 삭제 게이트의 출력값인 ft가 0이 된다면 이전 시점의 셀 상태값인 Ct-1은 현재 시점의 셀 상태값을 결정하기 위한 영향력이 0이 된다. 0이면 모든 정보가 삭제된 것이기 때문이다. 그렇기 때문에 오직 입력 게이트의 결과만이 현재 시점의 셀 상태값 Ct를 결정할 수 있다. 이는 삭제 게이트가 완전히 닫히고 입력 게이트를 연 상태를 의미한다. 

 

반대로 입력 게이트의 it 값을 0이라고 한다면 현재 시점의 셀 상태값 Ct는 오직 이전 시점의 셀 상태값 Ct-1에만 의존한다. 이는 입력 게이트를 완전히 닫고 삭제 게이트만을 연 상태를 의미한다. 

 

결과적으로, 삭제 게이트는 이전 시점의 입력을 얼마나 반영할지를 의미하고, 입력 게이트는 현재 시점의 입력을 얼마나 반영할지를 결정한다. 

 

 

 

출력 게이트와 은닉 상태(단기 상태)

출력 게이트는 현재 시점 t의 x값과 이전 시점 t-1의 은닉 상태가 시그모이드 함수를 지난 값이다. 해당 값은 현재 시점 t의 은닉 상태를 결정하는 일에 쓰인다. 

 

은닉 상태를 단기 상태라고도 한다. 은닉 상태는 장기 상태의 값이 하이퍼볼릭탄젠트 함수를 지나 -1과 1 사이의 값이다. 해당 값은 출력 게이트의 값과 연산되면서 값이 걸러지는 효과가 발생한다. 단기 상태의 값은 출력층으로도 향한다. 

 

 

 

♣ 파이토치의 nn.LSTM( )

파이토치에서 LSTM 셀을 사용하는 방법은 매우 간단하다. 기존의 RNN셀을 사용하려고 했을 때의 코드는 다음과 같다. 

 

LSTM셀은 이와 유사하게 다음과 같이 사용한다.