Pytorch로 시작하는 딥러닝 입문(07-04. 문자 단위 RNN: Char RNN)

2024. 6. 24. 20:49딥러닝 모델: 파이토치

모든 시점의 입력에 대해서 모든 시점에서 출력을 하는 다대다 RNN

 

♣ 문자 단위 RNN

입출력의 단위를 단어 레벨이 아니라 문자 레벨로 하여 RNN을 구현한 것을 문자 단위 RNN이라고 한다. 

RNN 구조는 같지만 입, 출력의 단위가 문자로 바뀌었을 뿐이다. 

 

필요한 도구 import 하기

 

훈련 데이터 전처리하기

문자 시퀀스 apple을 입력받으면 pple!를 출력하는 RNN을 구현한다. RNN의 동작을 이해하는 것이 목적이다. 

입력 데이터와 레이블 데이터에 대해서 문자 집합(vocabulary)를 만든다. 이 문자 집합은 중복을 제거한 문자들의 집합이다. 

현재 문자 집합에는 총 5개의 문자가 있다. !, a, e, l, p이다. 이제 하이퍼파라미터를 정의하는데, 이때 입력은 원-핫 벡터를 사용할 것이므로 입력의 크기는 문자 집합의 크기여야만 한다. 

 

이제 문자 집합에 고유한 정수를 부여한다. 

 

나중에 예측 결과를 다시 문자 시퀀스로 보기 위해서 반대로 정수로부터 문자를 얻을 수 있는 index_to_char을 만든다. 

 

입력 데이터와 레이블 데이터의 각 문자들을 정수로 맵핑한다.

 

파이토치 nn.RNN( )은 기본적으로 3차원 데이터를 입력받는다. 그 때문에 배치 차원을 추가한다. 

 

 

입력 시퀀스의 각 문자들을 원-핫 벡터로 바꾼다. 

 

 

입력 데이터와 레이블 데이터를 텐서로 바꾼다. 

 

 

모델 구현하기

 

 

 

모델에 입력을 넣어서 출력 확인하기

 

[1, 5, 5] 는 각각 배치 차원, 시점(time steps), 출력의 크기이다. 나중에 정확도를 측정할 때 이를 모두 펼쳐서 계산하는데, 이때는 view를 사용하여 배치 차원과 시점 차원을 하나로 만든다. 

 

 

 

레이블 데이터의 크기 확인하기

레이블 데이터는 (1, 5)의 크기를 가지는데, 마찬가지로 나중에 정확도를 측정할 때 이를 펼쳐서 계산한다. 이때 [5] 의 크기를 가지게 된다.