본문 바로가기
DL/CNN

[DL][CNN] 설명 가능한 AI (Explainable Artificial Intelligence, XAI)와 Feature Map 시각화, PyTorch 예제

by 어떻게든 되겠지~ 2025. 1. 15.

1. Explainable Artificial Intelligence(XAI) 란?

Explainable AI란 Deep Learning 처리 결과를 사람이 이해할 수 있는 방식으로 제시하는 기술이다. Deep Learning에서 Model 내부는 Black Box 같아 내부에서 어떻게 동작하는지 설명하기 어렵다.

따라서 Deep Learning을 통해 얻은 결과는 신뢰하기 어려운데 처리 과정을 시각화해야 할 필요성이 있다.

 

Model을 구성하는 각 중간 게층부터 최종 분류기까지 입력된 이미지에 대해 Feature map이 어떻게 추출되고 학습하는지를 시각적으로 설명할 수 있어야만 결과에 대한 신뢰성을 얻을 수 있다.

 

이제부터 CNN 내부 과정에 대한 시각화를 진행하는데 시각화 방법에는 Filter에 대한 시각화, Feature map에 대한 시각화가 있다.

 

2. Feature Map 시각화

Feature Map은 입력 이미지 또는 다른 Feature Map처럼 필터를 입력에 적용한 결과이다.      
따라서 Feature Map을 시각화한다는 의미는 Feature Map에서 입력 특성을 감지하는 방법을 이해할 수 있도록 돕는 것이다.

1) 필요 라이브러리 호출

import matplotlib.pyplot as plt
from PIL import Image
import cv2

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import ToTensor
import torchvision.transforms as transforms
import torchvision.models as models

device = "cuda" if torch.cuda.is_available() else "cpu"

2) 모델 정의 및 선언

class XAI(torch.nn.Module):
    def __init__(self, num_classes=2):
        super(XAI, self).__init__()
        # input : 100 x 100 x 3
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(256, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.classifier = nn.Sequential(
            nn.Linear(512, 512, bias=False),
            nn.Dropout(0.5),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(-1, 512)
        x = self.classifier(x)
        return F.log_softmax(x)
        
        
model = XAI()
model.to(device)

3) Feature Map 확인을 위한 Class 정의

class LayerActivations:
    features = []   
    def __init__(self, model, layer_sum):
        self.hook = model[layer_sum].register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        self.features = output.detach().numpy() 
        
    def remove(self):
        self.hook.remove()

 

※ hook

PyTorch는 hook 기능을 사용하여 각 Layer의 Activation Function 및 Gradient 값을 확인할 수 있다. 따라서, register_forward_hook()의 목적은 Forward 과정 중에 각 Network Module의 입력 및 출력을 가져오는 것이다.

 

4) Image 확인

img = cv2.imread("../data/cat.jpg")    
plt.imshow(img)
img = cv2.resize(img, (100, 100), interpolation=cv2.INTER_LINEAR)  # INTER_LINEAR : 보간법000
img = ToTensor()(img).unsqueeze(0) 
print(img.shape)

출력 결과

5) 첫 번째 Convolution Layer에서의 Feature Map 시각화

result = LayerActivations(model.features, 0)

model(img)  
activations = result.features

fig, axes = plt.subplots(4, 4)  
fig = plt.figure(figsize=(12,8))    
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05) 

for row in range(4):
    for column in range(4):
        axis = axes[row][column]
        axis.get_xaxis().set_ticks([])
        axis.get_yaxis().set_ticks([])
        axis.imshow(activations[0][row*10 + column])
plt.show();

6) 20 번째 Convolution Layer에서의 Feature Map 시각화

result = LayerActivations(model.features, 20)

model(img)
activations = result.features

fig, axes = plt.subplots(4, 4)  
fig = plt.figure(figsize=(12,8))    
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05) 

for row in range(4):
    for column in range(4):
        axis = axes[row][column]
        axis.get_xaxis().set_ticks([])
        axis.get_yaxis().set_ticks([])
        axis.imshow(activations[0][row*10 + column])
plt.show()

7) 40 번째 Convolution Layer에서의 Feature Map 시각화

result = LayerActivations(model.features, 40)

model(img)
activations = result.features

fig, axes = plt.subplots(4, 4)  
fig = plt.figure(figsize=(12,8))    
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05) 

for row in range(4):
    for column in range(4):
        axis = axes[row][column]
        axis.get_xaxis().set_ticks([])
        axis.get_yaxis().set_ticks([])
        axis.imshow(activations[0][row*10 + column])
plt.show()

 

0번째, 20번째 40번쨰 Convolution Layer로 갈수록 Image의 원래 형태를 찾아볼 수 없고 Image의 특징들만 전달되는 것을 확인할 수 있습니다.

반응형