4. Speed comparison of user_tensordot, torch.tensordot and torch.einsum in PyTorch (AI comparison experiments in Python)
The complete code is as follows:
# -*- coding: UTF-8 -*-
# Author: Perry
# @Create Time: 2020-04-07 13:53
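# Optional: the third-party managpu package is used here to pick a GPU by free memory.
# It is kept before `import torch`; without it, the script runs on the default GPU.
try:
    from managpu import GpuManager
    my_gpu = GpuManager()
    my_gpu.set_by_memory(1)
except ImportError:
    pass

import timeit

import numpy as np
import torch


def user_tensordot(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Contract dims (0, 2) of A with dims (0, 1) of B using permute + reshape + matmul."""
    size = A.shape[0] * A.shape[2]  # number of elements in the contracted dims
    out_shape = (A.shape[1], A.shape[3], B.shape[2], B.shape[3])
    # Free dims of A first, contracted dims last -> (free, contracted) matrix
    A = A.permute(1, 3, 0, 2).reshape(-1, size)
    # Contracted dims of B already lead, so no permute is needed -> (contracted, free) matrix
    B = B.reshape(size, -1)
    return A.matmul(B).reshape(out_shape)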
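

if __name__ == '__main__':
    # Contract dim a (size 20) and dim c (size 25); the result has shape (22, 23, 26, 27)
    A = torch.randn(20, 22, 25, 23, device="cuda")
    B = torch.randn(20, 25, 26, 27, device="cuda")
    repeat = 10    # timeit.repeat takes 10 measurements
    number = 1000  # each measurement runs the statement 1000 times
    # Note: CUDA kernels are launched asynchronously and torch.cuda.synchronize() is never
    # called here, so these timings may not capture the full GPU compute time.
    A_time = timeit.repeat('torch.tensordot(A, B, dims=([0, 2], [0, 1]))', 'from __main__ import torch, A, B',
                           repeat=repeat, number=number)
    A_time = np.mean(A_time)  # mean total seconds for 1000 calls
    print("A_time: ", A_time)
    B_time = timeit.repeat('user_tensordot(A, B)', 'from __main__ import user_tensordot, A, B',
                           repeat=repeat, number=number)
    B_time = np.mean(B_time)
    print("B_time: ", B_time)
    C_time = timeit.repeat('torch.einsum("abcd,acef->bdef", A, B)', 'from __main__ import torch, A, B',
                           repeat=repeat, number=number)
    C_time = np.mean(C_time)
    print("C_time: ", C_time)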
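All three approaches compute the same contraction, C[b,d,e,f] = sum over a, c of A[a,b,c,d] · B[a,c,e,f], and user_tensordot hardcodes which dims are contracted. As a sketch of how the permute + reshape + matmul idea generalizes, the helper below, tensordot_via_matmul (a hypothetical name, not a PyTorch API), moves the free dims of A to the front and the contracted dims of B to the front, then runs one matrix product. It also checks that the three approaches agree. It runs on the CPU in float64 so it works without a GPU.

```python
import torch


def tensordot_via_matmul(A: torch.Tensor, B: torch.Tensor, dims_a, dims_b) -> torch.Tensor:
    # Hypothetical helper (not a PyTorch API): contracts A's dims_a with B's dims_b,
    # pairing dims_a[i] with dims_b[i], via permute + reshape + matmul.
    free_a = [d for d in range(A.dim()) if d not in dims_a]
    free_b = [d for d in range(B.dim()) if d not in dims_b]
    size = 1
    for d in dims_a:
        size *= A.shape[d]  # number of elements in the contracted dims
    a_mat = A.permute(*free_a, *dims_a).reshape(-1, size)  # (free_a, contracted)
    b_mat = B.permute(*dims_b, *free_b).reshape(size, -1)  # (contracted, free_b)
    out_shape = [A.shape[d] for d in free_a] + [B.shape[d] for d in free_b]
    return (a_mat @ b_mat).reshape(out_shape)


# float64 on the CPU keeps rounding differences far below allclose's tolerance
torch.manual_seed(0)
A = torch.randn(20, 22, 25, 23, dtype=torch.float64)
B = torch.randn(20, 25, 26, 27, dtype=torch.float64)

ref = torch.tensordot(A, B, dims=([0, 2], [0, 1]))
mine = tensordot_via_matmul(A, B, [0, 2], [0, 1])
ein = torch.einsum("abcd,acef->bdef", A, B)

print(ref.shape)  # torch.Size([22, 23, 26, 27])
print(torch.allclose(ref, mine), torch.allclose(ref, ein))
```

The pairing order matters: dims_a[i] must have the same size as dims_b[i], just as with the dims argument of torch.tensordot.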
Output (each value is the mean over 10 repeats of the total time, in seconds, for 1000 calls):
A_time: 0.06444346878852229
B_time: 0.07555614910379518
C_time: 0.11107458020269405
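Conclusion: the custom user_tensordot and the library torch.tensordot perform similarly (user_tensordot takes about 17% longer), while torch.einsum is clearly slower: about 1.7× the time of torch.tensordot and about 1.5× that of user_tensordot. These numbers come from a single GPU setup and were measured without torch.cuda.synchronize(); because CUDA kernels run asynchronously, they may partly reflect launch and dispatch overhead rather than pure compute time.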
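Because PyTorch launches CUDA kernels asynchronously, the timeit loop above can return before the GPU has finished its work. Below is a minimal re-measurement sketch. It assumes a PyTorch version that includes torch.utils.benchmark, whose Timer synchronizes CUDA when taking timestamps and whose blocked_autorange chooses how many calls to time per block. The user_tensordot below is a copy of the article's function, with the output shape derived from the inputs. The code falls back to the CPU when no GPU is available.

```python
import torch
from torch.utils import benchmark


def user_tensordot(A, B):
    # Copy of the article's function: contract dims (0, 2) of A with dims (0, 1) of B
    size = A.shape[0] * A.shape[2]
    out_shape = (A.shape[1], A.shape[3], B.shape[2], B.shape[3])
    return A.permute(1, 3, 0, 2).reshape(-1, size).matmul(B.reshape(size, -1)).reshape(out_shape)


device = "cuda" if torch.cuda.is_available() else "cpu"
A = torch.randn(20, 22, 25, 23, device=device)
B = torch.randn(20, 25, 26, 27, device=device)

stmts = {
    "torch.tensordot": "torch.tensordot(A, B, dims=([0, 2], [0, 1]))",
    "user_tensordot": "user_tensordot(A, B)",
    "torch.einsum": 'torch.einsum("abcd,acef->bdef", A, B)',
}
results = {}
for name, stmt in stmts.items():
    timer = benchmark.Timer(
        stmt=stmt,
        globals={"torch": torch, "A": A, "B": B, "user_tensordot": user_tensordot},
    )
    # Timer synchronizes CUDA around each measured block; blocked_autorange sizes the blocks
    m = timer.blocked_autorange(min_run_time=0.2)
    results[name] = m.median  # median seconds per call
    print(f"{name}: {m.median * 1e6:.1f} us per call ({device})")
```

The original timeit numbers are totals for 1000 calls, so dividing by 1000 gives per-call figures comparable to what this sketch prints. In the run above, that works out to about 64 µs for torch.tensordot, 76 µs for user_tensordot and 111 µs for torch.einsum. Synchronized results depend on the GPU and PyTorch version and need not reproduce the ordering reported here.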