Loading [MathJax]/jax/output/CommonHTML/jax.js
본문 바로가기
DL/RNN

[DL][RNN] LSTM(Long Short - Term Memory) 구조 및 PyTorch 구현

by 어떻게든 되겠지~ 2024. 8. 27.

 

 

LSTM(Long Short - Term Memory) 이란

LSTM은 RNN 기법 중 하나로 Cell, Input Gate, Output Gate, Forget Gate를 이용해 기존 RNN의 문제인 기울기 소멸 문제(Vanishing Gradient)를 방지하도록 개발된 모델이다

 

RNN은 이전 단계의 출력을 다음 단계의 입력으로 사용하는 순환 구조로, Sequence 데이터에서 패턴을 학습하는 데 적합하다     
하지만, RNN은 긴 시퀀스에서 초기 정보가 뒤로 갈수록 희미해지는 '장기 의존성 문제(Long-term dependency problem)'가 발생할 수 있다. 이는 RNN의 Gradient가 시간 경과에 따라 급격히 커지거나 작아지는 '기울기 소멸(vanishing gradient)'이나 '기울기 폭발(exploding gradient)'로 인해 발생한다

 

따라서 이러한 문제를 해결하기 위해 LSTM은 정보를 선택적으로 기억하고, 잊고, 출력하는 과정을 제어하는 "Cell State""Gate"구조를 사용한다

 

LSTM의 세가지 주요 Gate

  • Input Gate(입력 게이트) : 현재 입력 정보를 얼마나 기억할지 결정
  • Forget Gate(망각 게이트) : 이전 Cell State 정보를 얼마나 잊을지 결정
  • Output Gate(출력 게이트) : 현재 Cell State에서 얼마나 정보를 출력할지 결정

LSTM의 동작과정

1. Forget Gate

이 단계에서는 이전 Cell State의 정보를 얼마나 잊을지 결정한다

시그모이드 함수(σ)의 값이 0에 가까우면 이전 Cell Stete 정보를 잊고, 1에 가까우면 정보를 유지한다

2. Input Gate

이 단계에서는 현재 Input이 Cell State에 얼마나 반영될지를 결정한다

시그모이드 함수의 값과 함께 현재 Input에 대한 새로운 정보가 tanh를 통해 생성되어 Cell State에 추가된다

3. Cell State Update

Forget Gate와 Input Gate를 통해 결정된 정보들이 반영되어 Cell State가 Update 된다

4. Output Gate

Update 된 Cell State를 바탕으로 최종적으로 출력할 정보를 결정한다

시그모이그 함수로 출력의 중요성을 결정하고, 하이퍼볼릭 탄젠트 함수를 거친 Cell State와 곱하여 최종 출력을 생성한다

LSTM의 장점

  • 장기 의존성 문제 해결
  • 유연한 정보 저장 및 삭제
  • 다양한 시퀀스 데이터에 적용가능

LSTM(Long Short - Term Memory) 구조

Hidden state(ht) : 단기 정보 제공      
Cell State(ct) : 장기적으로 정보 유지

Forget Gate(망각 게이트)

망각 게이트는 과거 정보를 어느 정도 잊을지를 결정한다             
과거 은닉 상태(ht1)와 현재 데이터({x_t})를 입력 받아 시그모이드를 취한 후(ft),

그 값(ft)을 과거 Cell State(Ct1)에 곱한다     
이 때, 시그모이드의 출력(ft)이 0이면 과거 Cell State 정보는 버리고, 1 이면 과거 Cell State 정보는 온전히 보존한다

즉, 과거 Cell State에서 사용하지 않을 데이터에 대한 가중치로 사용된다

0 ~ 1 사이의 출력 값을 가지는 ht1과 xt를 입력값으로 받는다        
이 때, xt는 새로운 입력값이고, ht1은 이전 Hidden Layer에서 입력 되는 값이다       
즉, ht1와 xt를 이용하여 이전 Cell State 정보를 현재 반영할지 결정하는 역할을 한다       

ft=σ(Wf[ht1,xt]+bf)=σ(Wxfxt+Whfht1+bf)

- 계산한 값이 1이면 이전의 Cell State 정보를 유지
- 계산한 값이 0이면 초기화

 

Input Gate(망각 게이트)

현재 Input이 Cell State에 얼마나 반영될지를 결정한다
과거 정보(ht1)와 현재 데이터(xt)를 입력받아 시그모이드와 하이퍼볼릭 탄젠트 함수를 기반으로 현재 정보에 대한 보존량을 결정        
현재 시점에서 새롭게 들어온 정보를 셀 상태에 추가할지, 추가한다면 얼마나 반영할지를 결정하는 역할을 한다 (1이면 입력이 들어올 수 있도록 허용, 0이면 차단)

 

 

1. Input Gate 계산

it는 Input Gate의 출력으로, 현재 입력된 정보가 Cell State에 얼마나 반영될지, 즉 현재 전보의 보존량을 결정한다.

값이 1에 가까울수록 정보가 많이 반영되고 0에 가까울수록 정보 반영이 차단된다

2. 임시 Cell State 계산

~ct는 현재 입력과 이전 은닉 상태 정보를 요약한 임시 Cell State로, 현재 시점에서 Cell State에 반영될 후보 정보이다

 

it=σ(Wi[ht1,xt])+bi=σ(Wxixt+Whiht1+bi)~ct=tanh(Wc[ht1,xt])+bc=tanh(Wxcxt+Whcht1+bc)

 

 

 

아래 그림은 Forget Gate와 Input Gate 예시 그림입니다

 

Cell State Update

Cell State는 LSTM의 "메모리" 역할을 하며, 시퀀스의 각 타임스텝에서 중요한 정보를 기억하고 유지하는 기능을 한다.
'총합(sum)'을 사용하여 셀 값을 반영하며, 이것으로 기울기 소멸 문제가 해결된다.      
셀 업데이트 방법   
Forget Gate 와 Input Gate의 이전 단계 셀 정보를 계산하여 현재 단계의 셀 상태(Cell state)를 업데이트 한다

 

ct=ftct1+it~ct
- 이전 정보 유지 : 망각 게이트 ft 의 결정에 따라 이전 셀 상태 ct1 의 정보를 일부 유지하거나 삭제
- 새로운 정보 추가 : 입력 게이트 it 임시 Cell State ~ct의 곱을 통해 새로운 정보를 셀 상태에 추가

Output Gate(출력 게이트)

현재 타임스텝에서 어떤 정보를 Hidden State(ht)로 전달할지를 결정한다.        
Output Gate는 이전 은닉 상태(ht1)와 t번째 입력(xt)과 Update 된 셀 Cell state를 기반으로 하여 현재 정보를 얼마나 Hidden State로 보낼지를 제어한다       

 

ot=σ(wo[ht1,xt]+bo)=σ(Wxoxt+Whoht1+bo)ht=ottanh(ct)

- 계산한 값이 1이면 의미 있는 결과로 최종 출력
- 계산한 값이 0이면 해당 연산 결과를 출력하지 않음

 

 

LSTM 전체 게이트 Flow

  1. Gate 계산하기(ft,it~ct)
  2. Cell State Update
  3. Hidden State update

 


LSTM(Long Short - Term Memory) 역전파

LSTM은 Cell을 통해 역전파를 수행하기 때문에 '중단 없는 기울기(Uninterrupted Gradient flow)'라고도 한다 (LSTM의 셀 상태(cell state) 덕분에 기울기 소실(vanishing gradient) 문제를 해결할 수 있다는 점을 강조한 것)             
즉, 최종 오차는 모든 노드에 전파되는데, 이때 Cell을 통해 중단없이 전파한다

LSTM의 Cell State는 아래와 같은 이유로 'Uninterrupted Gradient Flow'를 유지할 수 있다
1. Forget Gate와 Input Gate의 역할     
    Forget Gate와 Input Gate는 Cell State에서 필요한 정보만을 선택적으로 잊거나 추가한다        
    이 때문에 Gradient가 크게 변동하지 않고, 장기적인 의존성을 유지할 수 있게 된다      
2. Cell State의 직선적 경로     
   Cell State는 LSTM 내에서 곱셈과 덧셈 연산을 통해 변화        
   특히, Forget Gate와 Input Gate는 곱셈(정보를 잊거나 보존)과 덧셈(새로운 정보를 추가)만으로 Update 된다       
   이러한 구조 덕분에 Gradient가 곱셈을 통해 기울기가 소멸되는 것을 방지할 수 있다(왜냐하면 RNN에서는 계속 곱합으로써 기울기가 소멸)     
   즉, Cell State는 Time Step에 따라 정보가 흐르는 직선 경로를 제공하므로, 기울기가 소실되지 않고 역방향으로 전파된다       
3.
Gradient 소실 방지
   Cell State는 각 Time Step에서 Update되지만 필요할때만 변경되므로, 불필요한 정보로 인해 Gradient가 왜곡되는 것을 막아준다     
   결과적으로 LSTM은 매우 긴 시퀀스에서도 중요한 정보를 유지하면서 학습할 수 있게 된다   


 

이제 PyTorch를 통해 LSTM Cell 및 LSTM 계층을 구현해보도록 하겠습니다

 

1) 라이브러리 호출

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import Parameter # 파라미터 목록을 갖고있는 라이브러리
from torch import Tensor
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision.transforms as transforms
import torchvision.datasets as datasets

import math

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
cuda = True if torch.cuda.is_available() else False # GPU 사용에 필요

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor # GPU 사용에 필요

torch.manual_seed(125)
if torch.cuda.is_available():
    torch.cuda.manual_seed(125)

2) 데이터 셋 준비 및 전처리

import torchvision.transforms as transforms

mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, ), (1.0, ))
])

 

from torchvision.datasets import MNIST
# MNIST : 훈련데이터 7만개, 검증 데이터 1만개
download_root = '../chapter7/data/MNIST_DATASET'

train_dataset = MNIST(download_root, train=True, transform=mnist_transform, download=True)
valid_dataset = MNIST(download_root, train=False, transform=mnist_transform, download=True)
test_dataset = MNIST(download_root, train=False, transform=mnist_transform, download=True)

 

Fashion MNIST 데이터 셋을 다운받고 정의한 전처리 과정에 따라 전처리를 수행합니다

batch_size = 64

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

 

3) LSTM Cell 구현

class LSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size, bias=True):
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        #####
        # 왜 4 * hidden_size 인가?
        self.x2h = nn.Linear(input_size, 4*hidden_size, bias=bias)
        self.h2h = nn.Linear(hidden_size, 4*hidden_size, bias=bias)
        #####
        
        self.reset_parameters()

    # 모델 파라미터 초기화
    def reset_parameters(self):
        std = 1.0 / math.sqrt(self.hidden_size)
        for w in self.parameters():
            w.data.uniform_(-std, std) # 난수를 위해서 사용

    def forward(self, x, hidden):
        ht_1, ct_1 = hidden
        x = x.view(-1, x.size(1))

        gates = self.x2h(x) + self.h2h(ht_1)
        gates = gates.squeeze()
        input_gate, forget_gate, cell_gate, output_gate = gates.chunk(4, 1)

        forget_gate = F.sigmoid(forget_gate)    # 망각게이트에 시그모이드 활성화 함수 적용
        input_gate = F.sigmoid(input_gate)      # 입력게이트에 시그모이드 활성화 함수 적용
        cell_gate = F.tanh(cell_gate)           # 셀 게이트에 하이퍼볼릭 탄젠트 활성화 함수 적용
        output_gate = F.sigmoid(output_gate)    # 출력게이트에 시그모이드 활성화 함수 적용

        # Cell State는 망각 게이트와 입력 게이트의 합
        # f_t * c_{t-1} + i_t * c_t
        # Element - wise Product
        ct = torch.mul(forget_gate, ct_1) + torch.mul(input_gate, cell_gate)
        ht = torch.mul(output_gate, F.tanh(ct))
        return (ht, ct)

 

4*hidden_size인 이유

게이트는 Forget Gate , Input Gate, Cell, Output Gate로 구성되어 있다        
따라서 위에서 계산된 x2h(x) + h2h(x)를 네 개로 쪼개지는 상황이기 때문에 은닉층의 유닛 개수에 4를 곱해준 것이다

그 후, torch.chunk(chunks=4, dim=1)을 통해 각각의 Gate로 쪼개어 사용한다

3) LSTM Cell을 사용한 LSTM Model 구현

class LSTMModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim, bias=True):
        super(LSTMModel, self).__init__()
        self.hidden_dim = hidden_dim

        self.layer_dim = layer_dim
        self.lstm = LSTMCell(input_size=input_dim, hidden_size=hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        if torch.cuda.is_available():
            # (은닉층의 계층 개수, 배치크기, 은닉층의 뉴런 개수) 형태를 갖는 Hidden State를 0으로 초기화
            h0 = Variable(torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).cuda())
            # (은닉층의 계층 개수, 배치크기, 은닉층의 뉴런 개수) 형태를 갖는 셀 상태를 0으로 초기화
            c0 = Variable(torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).cuda()) 
        else:
            h0 = Variable(torch.zeros(self.layer_dim, x.size(0), self.hidden_dim))
            c0 = Variable(torch.zeros(self.layer_dim, x.size(0), self.hidden_dim))

        outs = []
        # (은닉층 계층의 개수, 배치크기, 은닉층의 뉴런 개수) 크기를 갖는 Cell, Hidden State에 대한 Tensor
        hn = h0[0, :, :]
        cn = c0[0, :, :]

        # LSTM Cell 계층을 반복하여 쌓아올림
        for seq in range(x.size(1)):
            hn, cn = self.lstm(x[:, seq, :], (hn, cn))
            outs.append(hn)

        out = outs[-1].squeeze()
        out = self.fc(out)
        return out

4) Loss Function, Optimizer 정의 및 모델 선언

# MNIST 데이터 셋 : [60000, 28, 28]
input_dim = 28
hidden_dim = 128
# 2 Layer LSTM
layer_dim = 2
# MNIST Data class수
output_dim = 10

model = LSTMModel(input_dim, hidden_dim, layer_dim, output_dim)
if torch.cuda.is_available():
    model.cuda()

criterion = nn.CrossEntropyLoss()
learning_rate = 0.1
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

 

Layer_dim을 2로 설정하여 Multi Layer LSTM으로 설정하였다

2 Layer  LSTM의 구조
2계층 LSTM은 두 개의 LSTM 레이어가 순차적으로 연결된 구조

1 Layer LSTM: Input Data를 처리하고, 그 출력(은닉 상태)이 2계층 LSTM의 입력으로 전달
2 Layer LSTM: 1계층 LSTM의 출력(Hidden State)을 입력으로 받아 다시 처리한 후 최종 출력을 생성

1. 1 Layer LSTM
  - Input : Input Sequence X=[x1,x2,,xt]와 초기 Hidden State h0(1) 및 초기 Cell State C0(1)
  - Forget Gate : 입력 xt와 Hidden State(h(1)t1)를 이용해 1 Layer의 Forget Gate 계산
  - Input Gate : 동일하게 xt와 h(1)t1을 사용해 1 Layer의 Input Gate 계산
  - Cell State : Forget Gate와 Input Gate의 값으로 Cell State Ct(1)가 Update
  - Output Gate : Upate 된 Cell State Ct(1)로부터 현재 Hidden State (ht(1))가 계산 되어 2 Layer LSTM으로 전달

2. 2 Layer LSTM
  - Input : 1 Layer LSTM의 Hidden State 출력 ht(1)이 2 Layer LSTM의 입력이 된다
  - Forget Gate : 2 Layer LSTM에서는 입력 ht(1)와 2 Layer LSTM의 이전 Hidden State h(2)t1를 사용해 Forget Gate가 계산
  - Input Gate : 2계층의 Input Gate는 ht(1)과 h(2)t1를 사용해 계산
  - Cell State : Forget Gate와 Input Gate 결과로 2 Layer LSTM Cell State Ct(2)가 Update
  - Output Gate : Upate 된 Cell State Ct(2)로부터 2 Layer LSTM의 현재 Hidden State ht(2)가 계산되고, 이 Hidden State가 모델의 최종 출력이 된다

5) 모델 학습 함수

seq_dim = 28 
loss_list = []
iter = 0

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        
        # images.shape : [64, 1, 28, 28]    
        # MNIST Data set이 흑백(Channel이 1이므로 차원 1개 삭제)
        if torch.cuda.is_available():
            images = Variable(images.view(-1, seq_dim, input_dim).cuda())
            labels = Variable(labels.cuda())
        else:
            images = Variable(images.view(-1, seq_dim, input_dim))
            labels = Variable(labels)
        # images.shape[64, 28, 28]

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        if torch.cuda.is_available():
            loss.cuda()

        loss.backward()
        optimizer.step()
        loss_list.append(loss.item())
        iter += 1

        if iter % 500 == 0:
            correct = 0
            total = 0
            for images, labels in valid_loader:
                if torch.cuda.is_available():
                    images = Variable(images.view(-1, seq_dim, input_dim).cuda())
                else:
                    images = Variable(images.view(-1, seq_dim, input_dim))
                
                outputs = model(images)
                _, pred = torch.max(outputs.data, 1)
                total += labels.size(0)

                if torch.cuda.is_available():
                    correct += (pred.cpu() == labels.cpu()).sum()
                else:
                    correct += (pred == labels).sum()

            accuracy = 100 * correct / total    
            print('Iteration : {} / {}, Loss : {}, Accuracy : {}'.format(iter, n_iters*num_epochs , loss.item(), accuracy))

6) 모델 평가함수

def evaluate(model, val_iter):
    corrects, total, total_loss = 0, 0, 0
    model.eval()

    for images, labels in val_iter:
        if torch.cuda.is_available():
            images = Variable(images.view(-1, seq_dim, input_dim).cuda())
        else:
            images = Variable(images.view(-1, seq_dim, input_dim))

        logit = model(images).to(device)
        loss = F.cross_entropy(logit, labels, reduction='sum')
        _, pred = torch.max(logit.data, 1)
        total += labels.size(0)
        total_loss += loss.item()
        corrects += (pred == labels).sum()

        avg_loss = total_loss / len(val_iter.dataset)
        avg_accuracy = corrects/total
        return avg_loss, avg_accuracy
반응형