-
[PyTorch] squeeze, unsqueeze 알아보기AI-ML 2024. 5. 21. 06:22728x90반응형
- 목차
들어가며.
PyTorch 는 Tensor 의 차원을 늘리거나 줄일 수 있는 함수를 제공합니다.
이러한 함수의 이름은 squeeze 와 unsqueeze 함수인데요.
squeeze 라는 영단어의 뜻이 "(특히 손가락으로 꼭) 짜다[쥐다]", "(무엇에서 액체를) 짜내다[짜다]" 라고 하는데요.
이런 의미처럼 squeeze 함수를 통해서 Tensor 의 차원을 줄이고, unsqueeze 함수를 활용하여 차원을 늘릴 수 있습니다.
squeeze 함수와 unsqueeze 함수는 무턱대로 차원을 변형하지는 않습니다.
차원의 Length 가 1 인 경우가 늘리거나 줄어드는 대상이 됩니다.
예를 들어, 흑백의 이미지와 RGB 컬러의 이미지가 존재한다고 가정하겠습니다.
일반적인 이미지는 Width 와 Height 인 크기 정보가 존재합니다.
즉, 이미지는 기본적으로 2차원의 데이터입니다.
여기에 흑백의 이미지의 경우에는 Gray Scale 의 차원이 하나가 추가됩니다.
반면 컬러 이미지의 경우에는 Red-Green-Blue 인 차원이 추가되죠.
이를 수치적으로 표현하면 아래와 같습니다.
28x28 크기의 MNIST 손글씨 이미지는 흑백의 이미지로써 1x28x28 의 3차원으로 표현됩니다.
32x32 크기의 CIFAR10 이미지는 컬러 이미지로써 3x32x32 의 3차원으로 표현됩니다.
이러한 경우에 흑백의 이미지는 1x28x28 이므로 Gray Color 을 뜻하는 첫번째 차원인 1 은 Squeeze 의 대상이 됩니다.
그래서 흑백 이미지에 Squeeze 함수를 적용하게 되면 1x28x28 인 Tensor 는 28x28 로 변경되게 됩니다.
반면, 3x32x32 인 컬러 이미지는 squeeze 의 대상이 되지 않죠.
squeeze, unsqueeze 함수 사용해보기.
squeeze 함수의 사용법은 간단합니다.
아래와 같이 torch.squeeze 함수의 인자로써 Tensor 를 입력하면 됩니다.
아래 예시는 1x3x3 인 Tensor 에 squeeze 함수를 적용합니다.
이 경우에 첫번째 차원인 1 이 Squeeze 함수가 적용되어 사라지게 됩니다.
import torch origin_tensor = torch.tensor([[[1,2,3], [4,5,6], [7,8,9]]]) origin_tensor.shape # torch.Size([1, 3, 3]) squeezed_tensor = torch.squeeze(origin_tensor) squeezed_tensor.shape #torch.Size([3, 3])
Length 가 1인 차원이 여러개 존재한다면, Squeeze 함수에 의해서 Length 가 1 인 모든 차원은 소멸됩니다.
극단적인 예시를 작성해보도록 하겠습니다.
아래와 같이 10개의 차원을 가지는 테스트 텐서가 존재하고,
이 텐서에 squeeze 함수를 적용하게 되면, 아래와 같이 크기가 1인 모든 차원이 제거됨을 알 수 있습니다.
import torch tensor = torch.ones(1, 2, 1, 2, 1, 3, 4, 1, 1, 2) print(tensor.shape) # torch.Size([1, 2, 1, 2, 1, 3, 4, 1, 1, 2]) print(tensor.squeeze().shape) # torch.Size([2, 2, 3, 4, 2])
unsqueeze.
unsqueeze 는 squeeze 의 동작 방식과 반대로 동작합니다.
unsqueeze 는 크기가 1인 새로운 차원을 추가하는 방식으로 동작을 합니다.
예를 들어 아래와 같은 코드에서 dim=0 인자를 통해서 unsqueeze 을 수행하게 되면 1번째 차원이 추가됩니다.
이와 유사하게 dim = 1 을 적용하게 되면, 새로운 마지막 차원이 추가되죠.
import torch one_tensor = torch.ones(2) # tensor([1., 1.]) one_tensor.shape # torch.Size([2]) torch.unsqueeze(one_tensor, dim=0).shape # torch.Size([1, 2]) torch.unsqueeze(one_tensor, dim=1).shape # torch.Size([2, 1])
CNN 에서 흑백 이미지를 다루는 경우에 Channel Size 가 1이 되게되는데, 이러한 경우에 squeeze 를 통해서 Channel 에 해당하는 차원을 제거할 수 있고,
RNN 에서 Batch Size 나 Sequence Size 에 해당하는 차원을 추가하거나 지우는 방식으로 주로 사용되곤 합니다.
반응형'AI-ML' 카테고리의 다른 글
[PyTorch] requires_grad 알아보기 (0) 2024.05.26 [scikit-surprise] SVD 모델 추론하기 (0) 2024.05.26 [ scikit-surprise ] SVD Regularization Terms 알아보기 (0) 2024.05.18 [scikit-surprise] SVD 모델 생성하기 (0) 2024.05.18 [scikit-surprise] Dataset 이해하기 (0) 2024.05.14