Pytorch之permute函數

1、主要作用:變換tensor維度

example:

import torch  x = torch.randn(2, 3, 5)  print(x.size())  print(x.permute(2, 0, 1).size())    >>>torch.Size([2, 3, 5])  >>>torch.Size([5, 2, 3])

2、介紹一下transpose與permute的異同:

同:都是對tensor維度進行轉置;

異:permute函數可以對任意高維矩陣進行轉置,但沒有torch.permute()這個調用方式

torch.randn(2,3,4,5).permute(3,2,0,1).shape    >>>torch.Size([5, 4, 2, 3])

transpose只能操作2D矩陣的轉置,無法操作超過2個維度,所以要想實現多個維度的轉置,既可以用一次性的

permute,也可以多次使用transpose;

torch.randn(2,3,4,5).transpose(3,0).transpose(2,1).transpose(3,2).shape  >>>torch.Size([5, 4, 2, 3])

3、permute函數與contiguous、view函數的關聯

contiguous: view只能作用在contiguous的variable上,如果在view之前調用了transpose、permute等,就需要調用

contiguous()來返回一個contiguous的copy;

也就是說transpose、permute等操作會讓tensor變得在記憶體上不連續,因此要想view,就得讓tensor先連續;

解釋如下:有些tensor並不是佔用一整塊記憶體,而是由不同的數據塊組成,而tensor的view()操作依賴於記憶體是整塊的,這時只需要執行contiguous()這個函數,把tensor變成在記憶體中連續分布的形式;

判斷ternsor是否為contiguous,可以調用torch.Tensor.is_contiguous()函數:

import torch  x = torch.ones(10, 10)  x.is_contiguous()                                 # True  x.transpose(0, 1).is_contiguous()                 # False  x.transpose(0, 1).contiguous().is_contiguous()    # True

另:在pytorch的最新版本0.4版本中,增加了torch.reshape(),與 numpy.reshape() 的功能類似,大致相當於 tensor.contiguous().view(),這樣就省去了對tensor做view()變換前,調用contiguous()的麻煩;

3、permute與view函數功能

import torch  import numpy as np    a=np.array([[[1,2,3],[4,5,6]]])  unpermuted=torch.tensor(a)  print(unpermuted.size())              #  ——>  torch.Size([1, 2, 3])    permuted=unpermuted.permute(2,0,1)  print(permuted.size())                #  ——>  torch.Size([3, 1, 2])    view_test = unpermuted.view(1,3,2)  print(view_test.size())    >>>torch.Size([1, 2, 3])  torch.Size([3, 1, 2])  torch.Size([1, 3, 2])