pytorch訓練GAN時的detach()
我最近在學使用Pytorch寫GAN程式碼,發現有些程式碼在訓練部分細節有略微不同,其中有的人用到了detach()函數截斷梯度流,有的人沒用detch(),取而代之的是在損失函數在反向傳播過程中將backward(retain_graph=True),本文通過兩個 gan 的程式碼,介紹它們的作用,並分析,不同的更新策略對程式效率的影響。
這兩個 GAN 的實現中,有兩種不同的訓練策略:
- 先訓練判別器(discriminator),再訓練生成器(generator),這是原始論文Generative Adversarial Networks 中的演算法
- 先訓練generator,再訓練discriminator
為了減少網路垃圾,GAN的原理網上一大堆,我這裡就不重複贅述了,想要詳細了解GAN原理的朋友,可以參考我專題文章:神經網路結構:生成式對抗網路(GAN)。
需要了解的知識:
detach():截斷node反向傳播的梯度流,將某個node變成不需要梯度的Varibale,因此當反向傳播經過這個node時,梯度就不會從這個node往前面傳播。
更新策略
我們直接下面進入本文正題,即,在 pytorch 中,detach 和 retain_graph 是幹什麼用的?本文將藉助三段 GAN 的實現程式碼,來舉例介紹它們的作用。
先訓練判別器,再訓練生成器
策略一
我們分析循環中一個 step 的程式碼:
valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device) # 真實標籤,都是1 fake = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device) # 假標籤,都是0 # ######################## # 訓練判別器 # # ######################## real_imgs = imgs.to(device) # 真實圖片 z = torch.randn((imgs.shape[0], 100)).to(device) # 雜訊 gen_imgs = generator(z) # 從雜訊中生成假數據 pred_gen = discriminator(gen_imgs) # 判別器對假數據的輸出 pred_real = discriminator(real_imgs) # 判別器對真數據的輸出 optimizer_D.zero_grad() # 把判別器中所有參數的梯度歸零 real_loss = adversarial_loss(pred_real, valid) # 判別器對真實樣本的損失 fake_loss = adversarial_loss(pred_gen, fake) # 判別器對假樣本的損失 d_loss = (real_loss + fake_loss) / 2 # 兩項損失相加取平均 # 下面這行程式碼十分重要,將在正文著重講解 d_loss.backward(retain_graph=True) # retain_graph=True 十分重要,否則計算圖記憶體將會被釋放 optimizer_D.step() # 判別器參數更新 # ######################## # 訓練生成器 # # ######################## g_loss = adversarial_loss(pred_gen, valid) # 生成器的損失函數 optimizer_G.zero_grad() # 生成器參數梯度歸零 g_loss.backward() # 生成器的損失函數梯度反向傳播 optimizer_G.step() # 生成器參數更新
程式碼講解
鑒別器的損失函數d_loss是由real_loss和fake_loss組成的,而fake_loss又是noise經過generator來的。這樣一來我們對d_loss進行反向傳播,不僅會計算discriminator 的梯度還會計算generator 的梯度(雖然這一步optimizer_D.step()只更新 discriminator 的參數),因此下面在更新generator參數時,要先將generator參數的梯度清零,避免受到discriminator loss 回傳過來的梯度影響。
generator 的 損失在回傳時,同樣要經過 discriminator 網路才能傳遞迴自身(系統從輸入雜訊到 Discriminator 輸出,從頭到尾只有一次前向傳播,而有兩次反向傳播,故在第一次反向傳播時,鑒別器要設置 backward(retain graph=True),保持計算圖不被釋放。因為 pytorch 默認一個計算圖只計算一次反向傳播,反向傳播後,這個計算圖的記憶體就會被釋放,所以用這個參數控制計算圖不被釋放。因此,在回傳梯度時,同樣也計算了一遍 discriminator 的參數梯度,只不過這次 discriminator 的參數不更新,只更新 generator 的參數,即 optimizer_G.step()。同時,我們看到,下一個 step 首先將 discriminator 的梯度重置為 0,就是為了防止 generator loss 反向傳播時順帶計算的梯度對其造成影響(還有上一步 discriminator loss 回傳時累積的梯度)。
綜上,我們看到,為了完成一步參數更新,我們進行了兩次反向傳播,第一次反向傳播為了更新 discriminator 的參數,但多餘計算了 generator 的梯度。第二次反向傳播為了更新 generator 的參數,但是計算了 discriminator 的梯度,因此在寫一個step,需要立即清零discriminator梯度。
如果你實在看不懂,就照著這個形式寫程式碼就行了,反正形式都幫你們寫好了。
策略二
這種策略我遇到的比較多,也是先訓練鑒別器,再訓練生成器
鑒別器訓練階段,noise 從 generator 輸入,輸出 fake data,然後 detach 一下,隨著 true data 一起輸入 discriminator,計算 discriminator 損失,並更新 discriminator 參數。生成器訓練階段,把沒經過 detach 的 fake data 輸入到discriminator 中,計算 generator loss,再反向傳播梯度,更新 generator 的參數。這種策略,計算了兩次 discriminator 梯度,一次 generator 梯度。感覺這種比較符合先更新 discriminator 的習慣。缺點是,之前的 generator 生成的計算圖得保留著,直到 discriminator 更新完,再釋放。
valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device) # 真實標籤,都是1 fake = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device) # 假標籤,都是0 # ######################## # 訓練判別器 # # ######################## real_imgs = imgs.to(device) # 真實圖片 z = torch.randn((imgs.shape[0], 100)).to(device) # 雜訊 gen_imgs = generator(z) # 從雜訊中生成假數據 pred_gen = discriminator(gen_imgs.detach()) # 假數據detach(),判別器對假數據的輸出 pred_real = discriminator(real_imgs) # 判別器對真數據的輸出 optimizer_D.zero_grad() # 把判別器中所有參數的梯度歸零 real_loss = adversarial_loss(pred_real, valid) # 判別器對真實樣本的損失 fake_loss = adversarial_loss(pred_gen, fake) # 判別器對假樣本的損失 d_loss = (real_loss + fake_loss) / 2 # 兩項損失相加取平均 # 下面這行程式碼十分重要,將在正文著重講解 d_loss.backward() # retain_graph=True 十分重要,否則計算圖記憶體將會被釋放 optimizer_D.step() # 判別器參數更新 # ######################## # 訓練生成器 # # ######################## g_loss = adversarial_loss(pred_gen, valid) # 生成器的損失函數 optimizer_G.zero_grad() # 生成器參數梯度歸零 g_loss.backward() # 生成器的損失函數梯度反向傳播 optimizer_G.step() # 生成器參數更新
先訓練生成器,再訓練判別器
我們分析循環中一個 step 的程式碼:
valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False) # 真實樣本的標籤,都是 1 fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False) # 生成樣本的標籤,都是 0 z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))) # 雜訊 real_imgs = Variable(imgs.type(Tensor)) # 真實圖片 # ######################## # 訓練生成器 # # ######################## optimizer_G.zero_grad() # 生成器參數梯度歸零 gen_imgs = generator(z) # 根據雜訊生成虛假樣本 g_loss = adversarial_loss(discriminator(gen_imgs), valid) # 用真實的標籤+假樣本,計算生成器損失 g_loss.backward() # 生成器梯度反向傳播,反向傳播經過了判別器,故此時判別器參數也有梯度 optimizer_G.step() # 生成器參數更新,判別器參數雖然有梯度,但是這一步不能更新判別器 # ######################## # 訓練判別器 # # ######################## optimizer_D.zero_grad() # 把生成器損失函數梯度反向傳播時,順帶計算的判別器參數梯度清空 real_loss = adversarial_loss(discriminator(real_imgs), valid) # 真樣本+真標籤:判別器損失 fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake) # 假樣本+假標籤:判別器損失 d_loss = (real_loss + fake_loss) / 2 # 判別器總的損失函數 d_loss.backward() # 判別器損失回傳 optimizer_D.step() # 判別器參數更新
為了更新生成器參數,用生成器的損失函數計算梯度,然後反向傳播,傳播圖中經過了判別器,根據鏈式法則,不得不順帶計算一下判別器的參數梯度,雖然在這一步不會更新判別器參數。反向傳播過後,noise 到 fake image 再到 discriminator 的輸出這個前向傳播的計算圖就被釋放掉了,後面也不會再用到。
接著更新判別器參數,此時注意到,我們輸入判別器的是兩部分,一部分是真實數據,另一部分是生成器的輸出,也就是假數據。注意觀察細節,在判別器前向傳播過程,輸入的假數據被 detach 了,detach 的意思是,這個數據和生成它的計算圖「脫鉤」了,即梯度傳到它那個地方就停了,不再繼續往前傳播(實際上也不會再往前傳播了,因為 generator 的計算圖在第一次反向傳播過後就被釋放了)。因此,判別器梯度反向傳播,就到它自己身上為止。
因此,比起第一種策略,這種策略要少計算一次 generator 的所有參數的梯度,同時,也不必刻意保存一次計算圖,佔用不必要的記憶體。
但需要注意的是,在第一種策略中,noise 從 generator 輸入,到 discriminator 輸出,只經歷了一次前向傳播,discriminator 端的輸出,被用了兩次,一次是計算 discriminator 的損失函數,另一次是計算 generator 的損失函數。
而在第這種策略中,noise 從 generator 輸入,到discriminator 輸出,計算 generator 損失,回傳,這一步更新了 generator 的參數,並釋放了計算圖。下一步更新 discriminator 的參數時,generator 的輸出經過 detach 後,又通過了一遍 discriminator,相當於,generator 的輸出前後兩次通過了 discriminator ,得到相同的輸出。顯然,這也是冗餘的。
總結
綜上,這兩段程式碼各有利弊:
第一段程式碼,好處是 noise 只進行了一次前向傳播,缺點是,更新 discriminator 參數時,多計算了一次 generator 的梯度,同時,第一次更新 discriminator 需要保留計算圖,保證算 generator loss 時計算圖不被銷毀。
第三段程式碼,好處是通過先更新 generator ,使更新後的前向傳播計算圖可以放心被銷毀,因此不用保留計算圖佔用記憶體。同時,在更新 discriminator 的時候,也不會像上面的那段程式碼,計算冗餘的 generator 的梯度。缺點是,在 discriminator 上,對 generator 的輸出算了兩次前向傳播,第二次又產生了新的計算圖(但比第一次的小)。
一個多計算了一次 generator 梯度,一個多計算一次 discriminator 前向傳播。因此,兩者差別不大。如果 discriminator 比generator 複雜,那麼應該採取第一種策略,如果 discriminator 比 generator 簡單,那麼應該採取第三種策略,通常情況下,discriminator 要比 generator 簡單,故如果效果差不多盡量採取第三種策略。
但是第三種先更新generator,再更新 discriminator 總是給人感覺怪怪得,因為 generator 的更新需要 discriminator 提供準確的 loss 和 gradient,否則豈不是在瞎更新?
但是策略三,馬上用完馬上釋放。綜合來說,還是策略三最好,策略二其次,策略一最差(差在多計算一次 generator gradient 上,而通常多計算一次 generator gradient 的運算量比多計算一次 discriminator 前向傳播的運算量大),因此,detach 還是很有必要的。
參考
Pytorch: detach 和 retain_graph
使用PyTorch進行GAN訓練時對於梯度截斷的思考.detach()