ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [PyTorch] squeeze, unsqueeze 알아보기
    AI-ML 2024. 5. 21. 06:22
    728x90
    반응형

    - 목차

     

    들어가며.

    PyTorch 는 Tensor 의 차원을 늘리거나 줄일 수 있는 함수를 제공합니다.

    이러한 함수의 이름은 squeeze 와 unsqueeze 함수인데요.

    squeeze 라는 영단어의 뜻이 "(특히 손가락으로 꼭) 짜다[쥐다]", "(무엇에서 액체를) 짜내다[짜다]" 라고 하는데요.

    이런 의미처럼 squeeze 함수를 통해서 Tensor 의 차원을 줄이고, unsqueeze 함수를 활용하여 차원을 늘릴 수 있습니다.

     

    squeeze 함수와 unsqueeze 함수는 무턱대로 차원을 변형하지는 않습니다.

    차원의 Length 가 1 인 경우가 늘리거나 줄어드는 대상이 됩니다.

    예를 들어, 흑백의 이미지와 RGB 컬러의 이미지가 존재한다고 가정하겠습니다.

    일반적인 이미지는 Width 와 Height 인 크기 정보가 존재합니다.

    즉, 이미지는 기본적으로 2차원의 데이터입니다.

    여기에 흑백의 이미지의 경우에는 Gray Scale 의 차원이 하나가 추가됩니다.

    반면 컬러 이미지의 경우에는 Red-Green-Blue 인 차원이 추가되죠.

    이를 수치적으로 표현하면 아래와 같습니다.

     

    28x28 크기의 MNIST 손글씨 이미지는 흑백의 이미지로써 1x28x28 의 3차원으로 표현됩니다.

     

    32x32 크기의 CIFAR10 이미지는 컬러 이미지로써 3x32x32 의 3차원으로 표현됩니다.

     

     

    이러한 경우에 흑백의 이미지는 1x28x28 이므로 Gray Color 을 뜻하는 첫번째 차원인 1 은 Squeeze 의 대상이 됩니다.

    그래서 흑백 이미지에 Squeeze 함수를 적용하게 되면 1x28x28 인 Tensor 는 28x28 로 변경되게 됩니다.

    반면, 3x32x32 인 컬러 이미지는 squeeze 의 대상이 되지 않죠.

     

    squeeze, unsqueeze 함수 사용해보기.

     

    squeeze 함수의 사용법은 간단합니다.

    아래와 같이 torch.squeeze 함수의 인자로써 Tensor 를 입력하면 됩니다.

    아래 예시는 1x3x3 인 Tensor 에 squeeze 함수를 적용합니다.

    이 경우에 첫번째 차원인 1 이 Squeeze 함수가 적용되어 사라지게 됩니다.

    import torch 
    
    origin_tensor = torch.tensor([[[1,2,3], [4,5,6], [7,8,9]]])
    origin_tensor.shape
    # torch.Size([1, 3, 3])
    
    squeezed_tensor = torch.squeeze(origin_tensor)
    squeezed_tensor.shape
    #torch.Size([3, 3])

     

    Length 가 1인 차원이 여러개 존재한다면, Squeeze 함수에 의해서 Length 가 1 인 모든 차원은 소멸됩니다.

     

    극단적인 예시를 작성해보도록 하겠습니다.

    아래와 같이 10개의 차원을 가지는 테스트 텐서가 존재하고,

    이 텐서에 squeeze 함수를 적용하게 되면, 아래와 같이 크기가 1인 모든 차원이 제거됨을 알 수 있습니다.

    import torch
    
    tensor = torch.ones(1, 2, 1, 2, 1, 3, 4, 1, 1, 2)
    print(tensor.shape)
    # torch.Size([1, 2, 1, 2, 1, 3, 4, 1, 1, 2])
    
    print(tensor.squeeze().shape)
    # torch.Size([2, 2, 3, 4, 2])

     

     

    unsqueeze.

    unsqueeze 는 squeeze 의 동작 방식과 반대로 동작합니다.

    unsqueeze 는 크기가 1인 새로운 차원을 추가하는 방식으로 동작을 합니다.

    예를 들어 아래와 같은 코드에서 dim=0 인자를 통해서 unsqueeze 을 수행하게 되면 1번째 차원이 추가됩니다.

    이와 유사하게 dim = 1 을 적용하게 되면, 새로운 마지막 차원이 추가되죠.

    import torch
    
    one_tensor = torch.ones(2)
    # tensor([1., 1.])
    one_tensor.shape
    # torch.Size([2])
    
    torch.unsqueeze(one_tensor, dim=0).shape
    # torch.Size([1, 2])
    
    torch.unsqueeze(one_tensor, dim=1).shape
    # torch.Size([2, 1])

     

    CNN 에서 흑백 이미지를 다루는 경우에 Channel Size 가 1이 되게되는데, 이러한 경우에 squeeze 를 통해서 Channel 에 해당하는 차원을 제거할 수 있고,

    RNN 에서 Batch Size 나 Sequence Size 에 해당하는 차원을 추가하거나 지우는 방식으로 주로 사용되곤 합니다.

     

    반응형
Designed by Tistory.