Pytorch之permute函數
- 2020 年 4 月 1 日
- 筆記
1、主要作用:變換tensor維度
example:
import torch x = torch.randn(2, 3, 5) print(x.size()) print(x.permute(2, 0, 1).size()) >>>torch.Size([2, 3, 5]) >>>torch.Size([5, 2, 3])
2、介紹一下transpose與permute的異同:
同:都是對tensor維度進行轉置;
異:permute函數可以對任意高維矩陣進行轉置,但沒有torch.permute()這個調用方式
torch.randn(2,3,4,5).permute(3,2,0,1).shape >>>torch.Size([5, 4, 2, 3])
transpose只能操作2D矩陣的轉置,無法操作超過2個維度,所以要想實現多個維度的轉置,既可以用一次性的
permute,也可以多次使用transpose;
torch.randn(2,3,4,5).transpose(3,0).transpose(2,1).transpose(3,2).shape >>>torch.Size([5, 4, 2, 3])
3、permute函數與contiguous、view函數的關聯
contiguous: view只能作用在contiguous的variable上,如果在view之前調用了transpose、permute等,就需要調用
contiguous()來返回一個contiguous的copy;
也就是說transpose、permute等操作會讓tensor變得在記憶體上不連續,因此要想view,就得讓tensor先連續;
解釋如下:有些tensor並不是佔用一整塊記憶體,而是由不同的數據塊組成,而tensor的view()操作依賴於記憶體是整塊的,這時只需要執行contiguous()這個函數,把tensor變成在記憶體中連續分布的形式;
判斷ternsor是否為contiguous,可以調用torch.Tensor.is_contiguous()函數:
import torch x = torch.ones(10, 10) x.is_contiguous() # True x.transpose(0, 1).is_contiguous() # False x.transpose(0, 1).contiguous().is_contiguous() # True
另:在pytorch的最新版本0.4版本中,增加了torch.reshape(),與 numpy.reshape() 的功能類似,大致相當於 tensor.contiguous().view(),這樣就省去了對tensor做view()變換前,調用contiguous()的麻煩;
3、permute與view函數功能
import torch import numpy as np a=np.array([[[1,2,3],[4,5,6]]]) unpermuted=torch.tensor(a) print(unpermuted.size()) # ——> torch.Size([1, 2, 3]) permuted=unpermuted.permute(2,0,1) print(permuted.size()) # ——> torch.Size([3, 1, 2]) view_test = unpermuted.view(1,3,2) print(view_test.size()) >>>torch.Size([1, 2, 3]) torch.Size([3, 1, 2]) torch.Size([1, 3, 2])