GPU隨機取樣速度比較

技術背景

隨機取樣問題,不僅僅只是一個統計學/離散數學上的概念,其實在工業領域也都有非常重要的應用價值/潛在應用價值,具體應用場景我們這裡就不做贅述。本文重點在於在不同平台上的取樣速率,至於另外一個重要的參數檢驗速率,這裡我們先不做評估。因為在Jax中直接支援vmap的操作,而numpy的原生函數大多也支援了向量化的運算,兩者更像是同一種演算法的不同實現。所以對於檢驗的場景,兩者的速度區別更多的也是在硬體平台上。

隨機取樣示例

關於Jax的安裝和基本使用方法,讀者可以自行參考Jax的官方文檔,需要注意的是,Jax有CPU、GPU和TPU三個版本,如果需要使用其GPU版本的功能,還需要依賴於jaxlib,另外最好是指定安裝對應的CUDA版本,這都是安裝過程中所踩過的一些坑。最後如果安裝的不是GPU的版本,運行Jax腳本的時候會有相關的提示說明。

隨機取樣,可以是針對一個給定的連續函數,也可以針對一個離散化的列表,但是為了更好的擴展性,一般問題都會轉化成先獲取均勻的隨機分布,再轉化成其他函數形式的分布,如正態分布等。所以這裡我們更加的是關注下均勻分布函數的效率:

import numpy as np
import time
import jax.random as random
key = random.PRNGKey(0)

print ('An small example of numpy sampler: \n{}'.format(np.random.uniform(low=0,high=1,size=5)))
print ('An small example of jax sampler: \n{}'.format(random.uniform(key,shape=(5,),minval=0, maxval=1)))

data_size = 400000000
time0 = time.time()
s = np.random.uniform(low=0,high=1,size=data_size)
print ('The numpy time cost is: {}s'.format(time.time()-time0))

time1 = time.time()
v = random.uniform(key,shape=(data_size,),minval=0, maxval=1)
print ('The jax time cost is: {}s'.format(time.time()-time1))

執行結果如下:

An small example of numpy sampler: 
[0.33654613 0.20267496 0.86859762 0.14940831 0.30321738]
An small example of jax sampler: 
[0.57450044 0.09968603 0.39316022 0.8941783  0.59656656]
The numpy time cost is: 3.6664984226226807s
The jax time cost is: 0.10985755920410156s

同樣是在生成雙精度浮點數的情況下,我們可預期的GPU的速率在數據長度足夠大的情況下一定是會更快的,這個運算結果也佐證了這個說法。

總結概要

關於工業領域中可能使用到的隨機取樣,更多的是這樣的一個場景:給定一個連續或者離散的分布,然後進行大規模的連續取樣,取樣的同時需要對每一個得到的樣點進行分析打分,最終在這大規模的取樣過程中,有可能被使用到的樣品可能只有其中的幾份。那麼這樣的一個抽象問題,就非常適合使用分散式的多GPU硬體架構來實現。

版權聲明

本文首發鏈接為://www.cnblogs.com/dechinphy/p/sampler.html

作者ID:DechinPhy

更多原著文章請參考://www.cnblogs.com/dechinphy/

打賞專用鏈接://www.cnblogs.com/dechinphy/gallery/image/379634.html

騰訊雲專欄同步://cloud.tencent.com/developer/column/91958