-
[ pytorch ] AutoEncoder 구현하기AI-ML 2023. 9. 5. 08:20728x90반응형
- 목차
키워드.
- - Auto Encoder
- - CNN
- - Encoder Decoder
들어가며.
이번 글에서는 pytorch 라이브러리를 사용하여 AutoEncoder 를 구현하는 내용에 대해서 설명하려고 합니다.
AutoEncoder 의 학습 데이터는 CIFAR10 이미지를 활용하며, AutoEncoder 네트워크의 구조는 아래와 같습니다.
아래 그림 설명에서 상단 부분이 Auto Encoder 의 Encoder 영역이고, 하단부가 Decoder 영역입니다.
사용하는 CIFAR10 이미지는 3x32x32 크기의 이미지로 RGB Channel 을 가지며, 32x32 크기를 가집니다.
Encoder 는 2개의 Convolution Layer 와 Max Pooling Layer 를 사용하구요.
각 과정에서 Tensor 의 사이즈는 절반씩 줄어듭니다. 그리고 결국에 8x8 크기의 Feature Map 으로 축소됩니다.
반대로 Decoder 는 Transpose Convolution 연산을 활용하여 Encoder 의 Inverse Transformation 을 구현합니다.
pytorch AutoEncoder 작성하기.
AutoEncoder Model .
아래의 코드는 AutoEncoder Model 을 구현하는 class 입니다.
Encoder 와 Decoder 모델을 별도로 구현하였구요.
AutoEncoder 모델이 Encoder 와 Decoder 를 개별적으로 참조하고 있습니다.
import torch.nn as nn class AEModelEncoder(nn.Module): def __init__(self): super().__init__() self.seq = nn.Sequential( # 32 -> 32 nn.Conv2d(3, 16, 2, 1, padding=1), nn.ReLU(), # 32 -> 16 nn.MaxPool2d(2, 2), # 16 -> 16 nn.Conv2d(16, 8, 2, 1, padding=1), nn.ReLU(), # 16 -> 8 nn.MaxPool2d(2, 2), ) self.ReLU = nn.ReLU() def forward(self, image_tensor): output = self.seq(image_tensor) return output class AEModelDecoder(nn.Module): def __init__(self): super().__init__() self.seq = nn.Sequential( # inverse max pooling 8 -> 16 nn.ConvTranspose2d(8, 8, 2, 2), nn.ReLU(), # 16 -> 16 nn.ConvTranspose2d(8, 16, 3, 1, padding=1), nn.ReLU(), # inverse max pooling 16 -> 32 nn.ConvTranspose2d(16, 16, 2, 2), nn.ReLU(), # 32 -> 32 nn.ConvTranspose2d(16, 3, 3, 1, padding=1), nn.ReLU(), ) def forward(self, image_tensor): output = self.seq(image_tensor) return output.view(-1, 3, 32, 32) class AEModel(nn.Module): def __init__(self): super().__init__() self.Encoder = AEModelEncoder() self.Decoder = AEModelDecoder() def forward(self, image_tensor): output = self.Encoder(image_tensor) output = self.Decoder(output) return output
CIFAR10 데이터 학습.
아래 코드는 CIFAR10 데이터를 기반으로 AutoEncoder 를 학습합니다.
Loss Function 은 MSE 를 사용하였구요.
원본 이미지와 AutoEncoder 가 생성한 이미지 텐서의 MSE 를 측정합니다.
import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms from torch.utils.data import DataLoader import matplotlib.pyplot as plt transform = transforms.ToTensor() ToPILImage = transforms.ToPILImage() train_data = datasets.CIFAR10("/tmp/data", train=True, download=True, transform=transform) dataloader = DataLoader(train_data, batch_size=100, shuffle=True) class AEModelEncoder(nn.Module): def __init__(self): super().__init__() self.seq = nn.Sequential( # 32 -> 32 nn.Conv2d(3, 16, 2, 1, padding=1), nn.ReLU(), # 32 -> 16 nn.MaxPool2d(2, 2), # 16 -> 16 nn.Conv2d(16, 8, 2, 1, padding=1), nn.ReLU(), # 16 -> 8 nn.MaxPool2d(2, 2), ) self.ReLU = nn.ReLU() def forward(self, image_tensor): output = self.seq(image_tensor) return output class AEModelDecoder(nn.Module): def __init__(self): super().__init__() self.seq = nn.Sequential( # inverse max pooling 8 -> 16 nn.ConvTranspose2d(8, 8, 2, 2), nn.ReLU(), # 16 -> 16 nn.ConvTranspose2d(8, 16, 3, 1, padding=1), nn.ReLU(), # inverse max pooling 16 -> 32 nn.ConvTranspose2d(16, 16, 2, 2), nn.ReLU(), # 32 -> 32 nn.ConvTranspose2d(16, 3, 3, 1, padding=1), nn.ReLU(), ) def forward(self, image_tensor): output = self.seq(image_tensor) return output.view(-1, 3, 32, 32) class AEModel(nn.Module): def __init__(self): super().__init__() self.Encoder = AEModelEncoder() self.Decoder = AEModelDecoder() def forward(self, image_tensor): output = self.Encoder(image_tensor) output = self.Decoder(output) return output model = AEModel() loss_function = nn.MSELoss() optimizer = optim.Adam(model.parameters(), lr=0.01) for epoch in range(20): for i, (train, _) in enumerate(dataloader): output = model(train) loss = loss_function(output, train) optimizer.zero_grad() loss.backward() optimizer.step() if (i % 100 == 0): print(epoch, i, loss.item())
반복 횟수에 따른 MSELoss 는 아래와 같습니다.
이미지 복원해보기.
아래의 이미지는 첫번째 Epoch 에서 복원한 이미지와 원본 이미지입니다.
왼쪽의 이미지가 CIFAR10 원본 이미지이며, 오른쪽 이미지가 AutoEncoder 로 Decoding 한 이미지입니다.
그리고 Epoch 이 증가할수록 이미지 복원력의 정확도가 높아지긴합니다.
다만 약간의 노이즈가 존재하고, 정교한 복원은 어려워보이네요.
반응형'AI-ML' 카테고리의 다른 글
[ CNN ] Feature Map 이해하기 (0) 2023.09.19 [ pytorch ] ConvTranspose2d 알아보기 (0) 2023.09.11 [pytorch] Dropout 알아보기 (0) 2023.08.17 Association Rules (연관규칙) 이해하기 (0) 2023.05.16 [ pytorch ] MaxPool2d, AvgPool2d 알아보기 ( Pooling Layer ) (0) 2023.03.27