ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [ pytorch ] AutoEncoder 구현하기
    AI-ML 2023. 9. 5. 08:20
    728x90
    반응형

    - 목차

     

    키워드.

    • - Auto Encoder
    • - CNN
    • - Encoder Decoder

     

    들어가며.

    이번 글에서는 pytorch 라이브러리를 사용하여 AutoEncoder 를 구현하는 내용에 대해서 설명하려고 합니다.

    AutoEncoder 의 학습 데이터는 CIFAR10 이미지를 활용하며, AutoEncoder 네트워크의 구조는 아래와 같습니다.

    아래 그림 설명에서 상단 부분이 Auto Encoder 의 Encoder 영역이고, 하단부가 Decoder 영역입니다.

    사용하는 CIFAR10 이미지는 3x32x32 크기의 이미지로 RGB Channel 을 가지며, 32x32 크기를 가집니다.

    Encoder 는 2개의 Convolution Layer 와 Max Pooling Layer 를 사용하구요.

    각 과정에서 Tensor 의 사이즈는 절반씩 줄어듭니다. 그리고 결국에 8x8 크기의 Feature Map 으로 축소됩니다.

    반대로 Decoder 는 Transpose Convolution 연산을 활용하여 Encoder 의 Inverse Transformation 을 구현합니다.

     

     

    pytorch AutoEncoder 작성하기.

     

    AutoEncoder Model .

    아래의 코드는 AutoEncoder Model 을 구현하는 class 입니다.

    Encoder 와 Decoder 모델을 별도로 구현하였구요.

    AutoEncoder 모델이 Encoder 와 Decoder 를 개별적으로 참조하고 있습니다.

    import torch.nn as nn
    
    class AEModelEncoder(nn.Module):
    
        def __init__(self):
            super().__init__()
    
            self.seq = nn.Sequential(
                # 32 -> 32
                nn.Conv2d(3, 16, 2, 1, padding=1),
                nn.ReLU(),
                # 32 -> 16
                nn.MaxPool2d(2, 2),
                # 16 -> 16
                nn.Conv2d(16, 8, 2, 1, padding=1),
                nn.ReLU(),
                # 16 -> 8
                nn.MaxPool2d(2, 2),
            )
    
            self.ReLU = nn.ReLU()
    
        def forward(self, image_tensor):
            output = self.seq(image_tensor)
            return output
    
    class AEModelDecoder(nn.Module):
    
        def __init__(self):
            super().__init__()
    
            self.seq = nn.Sequential(
                # inverse max pooling 8 -> 16
                nn.ConvTranspose2d(8, 8, 2, 2),
                nn.ReLU(),
                # 16 -> 16
                nn.ConvTranspose2d(8, 16, 3, 1, padding=1),
                nn.ReLU(),
                # inverse max pooling 16 -> 32
                nn.ConvTranspose2d(16, 16, 2, 2),
                nn.ReLU(),
                # 32 -> 32
                nn.ConvTranspose2d(16, 3, 3, 1, padding=1),
                nn.ReLU(),
            )
    
        def forward(self, image_tensor):
            output = self.seq(image_tensor)
            return output.view(-1, 3, 32, 32)
    
    class AEModel(nn.Module):
    
        def __init__(self):
            super().__init__()
            self.Encoder = AEModelEncoder()
            self.Decoder = AEModelDecoder()
    
        def forward(self, image_tensor):
            output = self.Encoder(image_tensor)
            output = self.Decoder(output)
            return output

     

    CIFAR10 데이터 학습.

    아래 코드는 CIFAR10 데이터를 기반으로 AutoEncoder 를 학습합니다.

    Loss Function 은 MSE 를 사용하였구요.

    원본 이미지와 AutoEncoder 가 생성한 이미지 텐서의 MSE 를 측정합니다.

     

    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torchvision import datasets, transforms
    from torch.utils.data import DataLoader
    import matplotlib.pyplot as plt
    
    transform = transforms.ToTensor()
    ToPILImage = transforms.ToPILImage()
    train_data = datasets.CIFAR10("/tmp/data", train=True, download=True, transform=transform)
    dataloader = DataLoader(train_data, batch_size=100, shuffle=True)
    
    class AEModelEncoder(nn.Module):
    
        def __init__(self):
            super().__init__()
    
            self.seq = nn.Sequential(
                # 32 -> 32
                nn.Conv2d(3, 16, 2, 1, padding=1),
                nn.ReLU(),
                # 32 -> 16
                nn.MaxPool2d(2, 2),
                # 16 -> 16
                nn.Conv2d(16, 8, 2, 1, padding=1),
                nn.ReLU(),
                # 16 -> 8
                nn.MaxPool2d(2, 2),
            )
    
            self.ReLU = nn.ReLU()
    
        def forward(self, image_tensor):
            output = self.seq(image_tensor)
            return output
    
    class AEModelDecoder(nn.Module):
    
        def __init__(self):
            super().__init__()
    
            self.seq = nn.Sequential(
                # inverse max pooling 8 -> 16
                nn.ConvTranspose2d(8, 8, 2, 2),
                nn.ReLU(),
                # 16 -> 16
                nn.ConvTranspose2d(8, 16, 3, 1, padding=1),
                nn.ReLU(),
                # inverse max pooling 16 -> 32
                nn.ConvTranspose2d(16, 16, 2, 2),
                nn.ReLU(),
                # 32 -> 32
                nn.ConvTranspose2d(16, 3, 3, 1, padding=1),
                nn.ReLU(),
            )
    
        def forward(self, image_tensor):
            output = self.seq(image_tensor)
            return output.view(-1, 3, 32, 32)
    
    class AEModel(nn.Module):
    
        def __init__(self):
            super().__init__()
            self.Encoder = AEModelEncoder()
            self.Decoder = AEModelDecoder()
    
        def forward(self, image_tensor):
            output = self.Encoder(image_tensor)
            output = self.Decoder(output)
            return output
    
    model = AEModel()
    loss_function = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    
    for epoch in range(20):
        for i, (train, _) in enumerate(dataloader):
            output = model(train)
            loss = loss_function(output, train)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (i % 100 == 0):
                print(epoch, i, loss.item())

     

    반복 횟수에 따른 MSELoss 는 아래와 같습니다.

    이미지 복원해보기.

    아래의 이미지는 첫번째 Epoch 에서 복원한 이미지와 원본 이미지입니다.

    왼쪽의 이미지가 CIFAR10 원본 이미지이며, 오른쪽 이미지가 AutoEncoder 로 Decoding 한 이미지입니다.

     

    그리고 Epoch 이 증가할수록 이미지 복원력의 정확도가 높아지긴합니다.

    다만 약간의 노이즈가 존재하고, 정교한 복원은 어려워보이네요.

     

    반응형
Designed by Tistory.