4. Speed Comparison of user_tensordot, torch.tensordot, and torch.einsum in PyTorch (An AI Comparison Experiment in Python)

The complete code is as follows:

# -*- coding: UTF-8 -*-

# Author: Perry
# @Create Time: 2020-04-07 13:53

# Pick a GPU via the managpu helper before any CUDA work
from managpu import GpuManager
my_gpu = GpuManager()
my_gpu.set_by_memory(1)

import timeit
import numpy as np

import torch


def user_tensordot(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # Hand-rolled tensordot: contract A's dims (0, 2) with B's dims (0, 1)
    # by flattening both tensors into matrices and doing a single matmul.
    size = A.shape[0] * A.shape[2]
    out_shape = (A.shape[1], A.shape[3], B.shape[2], B.shape[3])
    A = A.permute(1, 3, 0, 2).reshape(-1, size)  # (22*23, 20*25)
    B = B.permute(0, 1, 2, 3).reshape(size, -1)  # no-op permute here; (20*25, 26*27)
    return A.matmul(B).reshape(out_shape)        # (22, 23, 26, 27)


if __name__ == '__main__':
    # A: (20, 22, 25, 23), B: (20, 25, 26, 27); the contracted dims have sizes 20 and 25
    A = torch.randn(20, 22, 25, 23, device="cuda")
    B = torch.randn(20, 25, 26, 27, device="cuda")

    repeat = 10
    number = 1000

    # A: built-in torch.tensordot
    A_time = timeit.repeat('torch.tensordot(A, B, dims=([0, 2], [0, 1]))', 'from __main__ import torch, A, B', repeat=repeat, number=number)
    A_time = np.mean(A_time)
    print("A_time: ", A_time)

    # B: hand-rolled user_tensordot
    B_time = timeit.repeat('user_tensordot(A, B)', 'from __main__ import user_tensordot, A, B', repeat=repeat, number=number)
    B_time = np.mean(B_time)
    print("B_time: ", B_time)

    # C: torch.einsum with an equivalent subscript specification
    C_time = timeit.repeat('torch.einsum("abcd, acef->bdef", (A, B))', 'from __main__ import torch, A, B',
                           repeat=repeat, number=number)
    C_time = np.mean(C_time)
    print("C_time: ", C_time)

Output:

A_time:  0.06444346878852229
B_time:  0.07555614910379518
C_time:  0.11107458020269405
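
Before reading too much into the numbers, it is worth confirming that the three approaches compute the same tensor. Below is a minimal, self-contained sanity check; the file name correctness_check.py is only illustrative, the shapes match the listing above, and a loose tolerance is used because float32 matmul accumulates rounding error differently in each kernel:

# correctness_check.py -- a minimal sketch; assumes a CUDA device is available
import torch

A = torch.randn(20, 22, 25, 23, device="cuda")
B = torch.randn(20, 25, 26, 27, device="cuda")

ref = torch.tensordot(A, B, dims=([0, 2], [0, 1]))
alt1 = torch.einsum("abcd, acef->bdef", (A, B))

# user_tensordot from the listing above, inlined here for self-containment
size = A.shape[0] * A.shape[2]
alt2 = A.permute(1, 3, 0, 2).reshape(-1, size).matmul(B.reshape(size, -1)).reshape(22, 23, 26, 27)

# loose tolerance: different summation orders give slightly different float32 results
print(torch.allclose(ref, alt1, atol=1e-4), torch.allclose(ref, alt2, atol=1e-4))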

Conclusion: the hand-rolled user_tensordot performs about the same as the built-in torch.tensordot (≈0.076 s vs. ≈0.064 s per 1000 calls), while torch.einsum is clearly slower, taking almost twice as long (≈0.111 s).
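
One caveat: CUDA kernels launch asynchronously, so the timeit numbers above may partly reflect launch and dispatch overhead rather than pure GPU time. A stricter measurement would call torch.cuda.synchronize() inside the timed statement. The sketch below shows one way to do that; the helper name timed and the file name synced_timing.py are hypothetical, not part of the original experiment:

# synced_timing.py -- a minimal sketch of a synchronized variant of the benchmark
import timeit
import numpy as np
import torch

A = torch.randn(20, 22, 25, 23, device="cuda")
B = torch.randn(20, 25, 26, 27, device="cuda")

def timed(stmt, repeat=10, number=1000):
    # torch.cuda.synchronize() blocks until all queued kernels finish,
    # so the measured time includes the GPU work, not just the launches
    runs = timeit.repeat(f"{stmt}; torch.cuda.synchronize()",
                         globals={"torch": torch, "A": A, "B": B},
                         repeat=repeat, number=number)
    return np.mean(runs)

print("tensordot:", timed("torch.tensordot(A, B, dims=([0, 2], [0, 1]))"))
print("einsum:   ", timed('torch.einsum("abcd, acef->bdef", (A, B))'))

The hand-rolled user_tensordot could be timed the same way by adding it to the globals dict; whether the gap between tensordot and einsum stays the same under synchronized timing would need to be re-measured.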