Pytorch 中 tensor的維度拼接

  • 2022 年 7 月 14 日
  • 筆記

torch.stack() 和 torch.cat() 都可以按照指定的維度進行拼接,但是兩者也有區別,torch.satck() 是增加新的維度進行堆疊,即其維度拼接後會增加一個維度;而torch.cat() 是在原維度上進行堆疊,即其維度拼接後的維度個數和原來一致。具體說明如下:

torch.stack(input,dim)

input: 待拼接的張量序列組(list or tuple),拼接的tensor的維度必須要相等,即tensor1.shape = tensor2.shape

dim: 在哪個新增的維度上進行拼接,不能超過拼接後的張量數據的維度大小,默認為 0

import torch 

x1 = torch.tensor([[1, 2, 3],
        		[4, 5, 6],
        		[7, 8, 9]])
x2 = torch.tensor([[10, 20, 30],
        		[40, 50, 60],
        		[70, 80, 90]])

print(torch.stack((x1,x2),dim=0).shape)
print(torch.stack((x1,x2),dim=1).shape)
print(torch.stack((x1,x2),dim=2).shape)

print(torch.stack((x1,x2),dim=0))
print(torch.stack((x1,x2),dim=1))
print(torch.stack((x1,x2),dim=2))

>> torch.Size([2, 3, 3])		# 2 表示是有兩個tensor的拼接,且在第一個維度的位置拼接
>> torch.Size([3, 2, 3])
>> torch.Size([3, 3, 2])
>> tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9]],
         
        [[10, 20, 30],
         [40, 50, 60],
         [70, 80, 90]]])
>> tensor([[[ 1,  2,  3],
         [10, 20, 30]],

        [[ 4,  5,  6],
         [40, 50, 60]],

        [[ 7,  8,  9],
         [70, 80, 90]]])
>> tensor([[[ 1, 10],
         [ 2, 20],
         [ 3, 30]],

        [[ 4, 40],
         [ 5, 50],
         [ 6, 60]],

        [[ 7, 70],
         [ 8, 80],
         [ 9, 90]]])

torch.cat(input, dim)

input: 待拼接的張量序列組(list or tuple),拼接的tensor的維度必須要相等,即tensor1.shape = tensor2.shape

dim: 在哪個已存在的維度上進行拼接,不能超過拼接後的張量數據的維度大小(即原來的維度大小),默認為 0

import torch

x1 = torch.tensor([[1, 2, 3],
        		[4, 5, 6],
        		[7, 8, 9]])
x2 = torch.tensor([[10, 20, 30],
        		[40, 50, 60],
        		[70, 80, 90]])

print(torch.cat((x1,x2),dim=0).shape)
print(torch.cat((x1,x2),dim=1).shape)

print(torch.cat((x1,x2),dim=0))
print(torch.cat((x1,x2),dim=1))

>> torch.Size([6, 3])
>> torch.Size([3, 6])

>> tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 20, 30],
        [40, 50, 60],
        [70, 80, 90]])
>> tensor([[ 1,  2,  3, 10, 20, 30],
        [ 4,  5,  6, 40, 50, 60],
        [ 7,  8,  9, 70, 80, 90]])