ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [pytorch] torch.cat 알아보기
    AI-ML 2023. 3. 12. 21:08
    728x90
    반응형

    - 목차

     

    키워드.

    • - pytorch
    • - concatenate

     

    들어가며.

    numpy 의 concatenate 와 Pandas 의 concat 함수처럼 pytorch 에도 cat 이라는 함수가 제공됩니다.

    위 함수들은 공통적으로 두 객체를 병합하는 기능을 제공합니다.

    예를 들어, 아래의 예시 코드는 numpy, pandas, pytorch 에서 두 객체들을 병합하는 예시입니다.

    import numpy as np
    import pandas as pd
    import torch
    
    a_value = [
        [1, 2, 3, 4]
    ]
    
    b_value = [
        [5, 6, 7, 8]
    ]
    
    print(np.concatenate([ 
    	np.array(a), 
        np.array(b) 
    ], axis=0))
    
    print(pd.concat([ 
        pd.DataFrame(a), 
        pd.DataFrame(b)
    ], axis=0))
    
    print(torch.cat([ 
        torch.tensor(a), 
        torch.tensor(b) 
    ], dim=0))
    -- numpy concatenate --
    [[1 2 3 4]
     [5 6 7 8]]
     
     
    -- pandas concat --
       0  1  2  3
    0  1  2  3  4
    0  5  6  7  8
    
    
    -- pytorch -- 
    tensor([[1, 2, 3, 4],
            [5, 6, 7, 8]])

     

    이처럼 cat 함수를 사용하여 pytorch 의 Tensor 들을 병합할 수 있구요.

    이번 글에서는 pytorch 의 cat 함수에 대해서 알아보려고 합니다.

     

    torch.cat

    pytorch 의 cat 함수는 두 Tensor 들을 병합합니다.

    dim 인자를 활용하여 병합하는 방향을 결정할 수 있는데요.

    다양한 예시와 함께 알아보도록 하겠습니다.

     

    dim = 0.

    cat 함수는 dim 이라는 인자를 가집니다.

    그리고 dim 인자는 병합의 방향을 결정합니다.

    dim 을 0 으로 설정하는 경우에는 Column 의 수 또는 feature 들의 갯수를 유지한 채로 Row 갯수가 증가합니다.

    즉, Tensor 의 Batch Size 가 증가한다고 생각하시면 될 것 같네요.

    예를 들어, 아래와 같이 (1, 4) 사이즈의 두 Tensor 를 병합합니다.

    결과는 (2, 4) 인 Tensor 가 생성되죠.

    즉, 0 차원의 방향으로 Tensor 를 추가하게 됩니다.

    import torch 
    
    a_tensor = torch.tensor([[1, 2, 3, 4]])
    b_tensor = torch.tensor([[5, 6, 7, 8]])
    
    print(a_tensor.shape)
    # torch.Size([1, 4])
    print(b_tensor.shape)
    # torch.Size([1, 4])
    
    a_b_tensor = torch.cat([a_tensor, b_tensor], dim=0)
    print(a_b_tensor)
    # tensor([[1, 2, 3, 4],
    #        [5, 6, 7, 8]])
    print(a_b_tensor.shape)
    # torch.Size([2, 4])

     

     

    dim = 1.

    cat 함수의 dim 인자를 1 로 설정하게 되면, 두번째 차원의 방향으로 Tensor 들이 추가됩니다.

    2차원의 Tensor 라고 한다면 Column 또는 Feature 들이 증가하게 됩니다.

    import torch 
    
    a_tensor = torch.tensor([[1, 2, 3, 4]])
    b_tensor = torch.tensor([[5, 6, 7, 8]])
    
    print(a_tensor.shape)
    # torch.Size([1, 4])
    print(b_tensor.shape)
    # torch.Size([1, 4])
    
    a_b_tensor = torch.cat([a_tensor, b_tensor], dim=1)
    print(a_b_tensor)
    # tensor([[1, 2, 3, 4, 5, 6, 7, 8]])
    print(a_b_tensor.shape)
    # torch.Size([1, 8])

     

     

    dim = -1.

    cat 함수의 dim 인자를 -1 로 설정하게 된다면, 이는 가장 마지막 차원이 증가하는 방향으로 Tensor 들이 병합됩니다.

    즉, 2차원의 Tensor 라고 한다면 dim = -1 과 dim = 1 은 동일한 결과를 얻게 됩니다.

    import torch 
    
    a_tensor = torch.tensor([[1, 2, 3, 4]])
    b_tensor = torch.tensor([[5, 6, 7, 8]])
    
    print(a_tensor.shape)
    # torch.Size([1, 4])
    print(b_tensor.shape)
    # torch.Size([1, 4])
    
    a_b_tensor = torch.cat([a_tensor, b_tensor], dim=-1)
    print(a_b_tensor)
    # tensor([[1, 2, 3, 4, 5, 6, 7, 8]])
    print(a_b_tensor.shape)
    # torch.Size([1, 8])

     

     

    반응형
Designed by Tistory.