Pytorch 中 tensor的維度拼接
- 2022 年 7 月 14 日
- 筆記
torch.stack() 和 torch.cat() 都可以按照指定的維度進行拼接,但是兩者也有區別,torch.satck() 是增加新的維度進行堆疊,即其維度拼接後會增加一個維度;而torch.cat() 是在原維度上進行堆疊,即其維度拼接後的維度個數和原來一致。具體說明如下:
torch.stack(input,dim)
input: 待拼接的張量序列組(list or tuple),拼接的tensor的維度必須要相等,即tensor1.shape = tensor2.shape
dim: 在哪個新增的維度上進行拼接,不能超過拼接後的張量數據的維度大小,默認為 0
import torch
x1 = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
x2 = torch.tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]])
print(torch.stack((x1,x2),dim=0).shape)
print(torch.stack((x1,x2),dim=1).shape)
print(torch.stack((x1,x2),dim=2).shape)
print(torch.stack((x1,x2),dim=0))
print(torch.stack((x1,x2),dim=1))
print(torch.stack((x1,x2),dim=2))
>> torch.Size([2, 3, 3]) # 2 表示是有兩個tensor的拼接,且在第一個維度的位置拼接
>> torch.Size([3, 2, 3])
>> torch.Size([3, 3, 2])
>> tensor([[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]],
[[10, 20, 30],
[40, 50, 60],
[70, 80, 90]]])
>> tensor([[[ 1, 2, 3],
[10, 20, 30]],
[[ 4, 5, 6],
[40, 50, 60]],
[[ 7, 8, 9],
[70, 80, 90]]])
>> tensor([[[ 1, 10],
[ 2, 20],
[ 3, 30]],
[[ 4, 40],
[ 5, 50],
[ 6, 60]],
[[ 7, 70],
[ 8, 80],
[ 9, 90]]])
torch.cat(input, dim)
input: 待拼接的張量序列組(list or tuple),拼接的tensor的維度必須要相等,即tensor1.shape = tensor2.shape
dim: 在哪個已存在的維度上進行拼接,不能超過拼接後的張量數據的維度大小(即原來的維度大小),默認為 0
import torch
x1 = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
x2 = torch.tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]])
print(torch.cat((x1,x2),dim=0).shape)
print(torch.cat((x1,x2),dim=1).shape)
print(torch.cat((x1,x2),dim=0))
print(torch.cat((x1,x2),dim=1))
>> torch.Size([6, 3])
>> torch.Size([3, 6])
>> tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 20, 30],
[40, 50, 60],
[70, 80, 90]])
>> tensor([[ 1, 2, 3, 10, 20, 30],
[ 4, 5, 6, 40, 50, 60],
[ 7, 8, 9, 70, 80, 90]])