0902-用GAN生成動漫頭像

0902-用GAN生成動漫頭像

pytorch完整教程目錄://www.cnblogs.com/nickchen121/p/14662511.html

一、概述

本節將通過 GAN 實現一個生成動漫人物頭像的例子。

在日本的技術部落格網站上有個部落客,利用 DCGAN 從 20 萬張動漫頭像中學習,最終能夠利用程式自動生成動漫頭像。源程式是利用 Chainer 框架實現的,在這裡我們將嘗試利用 Pytorch 實現。

原始的圖片是從網站中採集的,並利用 OpenCV 截取頭像,處理起來非常麻煩。因此我們在這裡通過之乎用戶 何之源 爬取並經過處理的 5 萬張圖片,想要圖片的百度網盤鏈接的可以加我微信:chenyoudea。需要注意的是,這裡圖片的解析度是 3×96×96,而不是論文中的 3×64×64,因此需要相應地調整網路結構,使生成影像的尺寸為 96。

二、程式碼結構

下面我們首先來看下我們未來的一個程式碼結構。

checkpoints/  # 無程式碼,用來保存模型
imgs/  # 無程式碼,用來保存生成的圖片
data/  # 無程式碼,用來保存訓練所需要的圖片
main.py  # 訓練和生成
model.py  # 模型定義
visualize.py  # 可視化工具 visdom 的開發
requirement.txt  # 程式中用到的第三方庫
README.MD  # 說明

三、model.py

model.py 主要是用來定義生成器和判別器的。

3.1 生成器

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Coding by //www.cnblogs.com/nickchen121/
# Datatime:2021/5/10 10:37
# Filename:model.py
# Toolby: PyCharm
from torch import nn


class NetG(nn.Module):
    """
    生成器定義
    """

    def __init__(self, opt):
        super(NetG, self).__init__()
        ngf = opt.ngf  # 生成器 feature map 數
        self.main = nn.Sequential(
            # 輸入是 nz 維度的雜訊,可以認識它是一個 nz*1*1 的 feature map
            # H_{out} = (H_{in}-1)*stride - 2*padding + kernel_size
            # 以下面一行程式碼的ConvTranspose2d舉例(初始 H_{in}=1):H_{out} = (1-1)*1-2*0+4 = 4
            nn.ConvTranspose2d(opt.nz, ngf * 8, (4, 4), (1, 1), (0, 0), bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # 上一步的輸出形狀:(ngf*8)*4*4,其中(ngf*8)是輸出通道數,4 為 H_{out} 是通過上述公式計算出來的

            # 以下面一行程式碼的ConvTranspose2d舉例(初始 H_{in}=4):H_{out} = (4-1)*2-2*1+4 =8
            nn.ConvTranspose2d(ngf * 8, ngf * 4, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # 上一步的輸出形狀:(ngf*4)*8*8

            nn.ConvTranspose2d(ngf * 4, ngf * 2, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # 上一步的輸出形狀是:(ngf*2)*16*16

            nn.ConvTranspose2d(ngf * 2, ngf, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # 上一步的輸出形狀:(ngf)*32*32

            nn.ConvTranspose2d(ngf, 3, (5, 5), (3, 3), (1, 1), bias=False),
            nn.Tanh()
            # 輸出形狀:3*96*96
        )

    def forward(self, inp):
        return self.main(inp)

從上述生成器的程式碼可以看出生成器的構建比較簡單,直接用 nn.Sequential 把上卷積、激活等操作拼接起來就行了。這裡稍微注意下 ConvTranspose2d 的使用,當 kernel size 為 4、stride 為 2、padding 為 1 時,根據公式 \(H_{out} = (H_{in}-1)*stride – 2*padding + kernel_size\),輸出尺寸剛好變成輸入的兩倍。

最後一層我們使用了 tanh 把輸出圖片的像素歸一化至 -1~1,如果希望歸一化到 0~1,可以使用 sigimoid 方法。

3.2 判別器

class NetD(nn.Module):
    """
    判別器定義
    """

    def __init__(self, opt):
        super(NetD, self).__init__()
        ndf = opt.ndf
        self.main = nn.Sequential(
            # 輸入 3*96*96
            nn.Conv2d(3, ndf, (5, 5), (3, 3), (1, 1), bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 輸出 (ndf)*32*32

            nn.Conv2d(ndf, ndf * 2, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 輸出 (ndf*2)*16*16

            nn.Conv2d(ndf * 2, ndf * 4, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 輸出 (ndf*4)*8*8

            nn.Conv2d(ndf * 4, ndf * 8, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # 輸出 (ndf*8)*4*4

            nn.Conv2d(ndf * 8, 1, (4, 4), (1, 1), (0, 0), bias=False),
            nn.Sigmoid()  # 輸出一個數:概率
        )

    def forward(self, inp):
        return self.main(inp).view(-1)

從上述程式碼可以看到判別器和生成器的網路結構幾乎是對稱的,從卷積核大小到 padding、stride 等設置,幾乎一模一樣。

需要注意的是,生成器的激活函數用的是 ReLU,而判別器使用的是 LeakyReLU,兩者其實沒有太大的區別,這種選擇更多的是經驗的總結。

判別器的最終輸出是一個 0~1 的數,表示這個樣本是真圖片的概率。

四、參數配置

在開始寫訓練函數前,我們可以先配置模型參數

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Coding by //www.cnblogs.com/nickchen121/
# Datatime:2021/5/11 15:14
# Filename:config.py
# Toolby: PyCharm
class Config(object):
    data_path = 'data/'  # 數據集存放路徑
    num_workers = 4  # 多進程載入數據所用的進程數
    image_size = 96  # 圖片尺寸
    batch_size = 256
    max_epoch = 200
    lr1 = 2e-4  # 生成器的學習率
    lr2 = 2e-4  # 判別器的學習率
    beta1 = 0.5  # Adam 優化器的 beta1 參數
    use_gpu = False  # 是否使用 GPU
    nz = 100  # 雜訊維度
    ngf = 64  # 生成器的 feature map 數
    ndf = 64  # 判別器的 feature map 數

    save_path = 'imgs/'  # 生成圖片保存路徑

    vis = True  # 是否使用 visdom 可視化
    env = 'GAN'  # visdom 的 env
    plot_every = 20  # 每隔 20 個 batch,visdom 畫圖一次

    debug_file = '/tmp/debuggan'  # 存在該文件則進入 debug 模式
    d_every = 1  # 每 1 個 batch 訓練一次判別器
    g_every = 5  # 每 5 個 batch 訓練一次生成器
    decay_everty = 10  # 每 10 個 epoch 保存一次模型
    netd_path = 'checkpoints/netd_211.pth'  # 預訓練模型
    netg_path = 'checkpoints/netg_211.pth'

    # 測試時用的參數
    gen_img = 'result.png'
    # 從 512 張生成的圖片路徑中保存最好的 64 張
    gen_num = 64
    gen_search_num = 512
    gen_mean = 0  # 雜訊的均值
    gen_std = 1  # 雜訊的方差
    
opt = Config()

上述這些都只是模型的默認參數,還可以利用 Fire 等工具通過命令行傳入,覆蓋默認值。

除此之外,還可以使用 opt.atrr,還可以利用 IDE/Python 提供的自動補全功能,十分方便。

上述的超參數大多是照搬 DCGAN 論文的默認值,這些默認值都是坐著經過大量的實驗,發現這些參數能夠更快地去訓練出一個不錯的模型。

五、數據處理

當我們下載完數據之後,需要把所有圖片放在一文件夾,然後把文件夾移動到 data 目錄下(並且要確保 data 下沒有其他的文件夾)。使用這種方法是為了能夠直接使用 pytorchvision 自帶的 ImageFolder 讀取圖片,而沒有必要自己寫一個 Dataset。

數據讀取和載入的程式碼如下所示。

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Coding by //www.cnblogs.com/nickchen121/
# Datatime:2021/5/12 09:43
# Filename:dataset.py
# Toolby: PyCharm
import torch as t
import torchvision as tv
from torch.utils.data import DataLoader

from config import opt

# 數據處理,輸出規模為 -1~1
transforms = tv.transforms.Compose([
    tv.transforms.Scale(opt.image_size),
    tv.transforms.CenterCrop(opt.image_size),
    tv.transforms.ToTensor(),
    tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 載入數據集
dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
dataloader = DataLoader(
    dataset,
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.num_workers,
    drop_last=True
)

從上述程式碼中可以發現,用 ImageFolder 配合 DataLoader 載入圖片十分方便。

六、訓練

在訓練之前,我們還需要定義幾個變數:模型、優化器、雜訊等。

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Coding by //www.cnblogs.com/nickchen121/
# Datatime:2021/5/10 10:37
# Filename:main.py
# Toolby: PyCharm
import os
import ipdb
import tqdm
import fire
import torch as t
import torchvision as tv
from visualize import Visualizer
from torch.autograd import Variable
from torchnet.meter import AverageValueMeter

from config import opt
from dataset import dataloader
from model import NetD, NetG



def train(**kwargs):
    # 定義模型
    netd = NetD()
    netg = NetG()
    # 定義網路
    map_location = lambda storage, loc: storage
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))

    # 定義優化器和損失
    optimizer_g = t.optim.Adam(netg.parameters(), opt.lr1, betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(), opt.lr2, betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss()

    # 真圖片 label 為 1,假圖片 label 為 0,noises 為生成網路的輸入雜訊
    true_labels = Variable(t.ones(opt.batch_size))
    fake_labels = Variable(t.zeros(opt.batch_size))
    fix_noises = Variable(t.randn(opt.batch_size, opt.nz, 1, 1))
    noises = vars(t.randn(opt.batch_size, opt.nz, 1, 1))

    # 如果使用 GPU 訓練,把數據轉移到 GPU 上
    if opt.use_gpu:
        netd.cuda()
        netg.cuda()
        criterion.cuda()
        true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()
        fix_noises, noises = fix_noises.cuda(), noises.cuda()

在載入預訓練模型的時候,最好指定 map_location。因為如果程式之前在 GPU 上運行,那麼模型就會被存成 torch.cuda.Tensor,這樣載入的時候會默認把數據載入到顯示記憶體上。如果運行該程式的電腦中沒有 GPU,則會報錯,因此指定 map_location 把 Tensor 默認載入到記憶體上,等有需要的時候再載入到顯示記憶體中。

下面開始訓練網路,訓練的步驟如下所示:

  1. 訓練判別器:
    • 固定生成器
    • 對於真圖片,判別器的輸出概率值儘可能接近 1
    • 對於生成器生成的圖片,判別器儘可能輸出 0
  2. 訓練生成器
    • 固定判別器
    • 生成器生成圖片,儘可能讓判別器輸出 1
  3. 返回第一步,循環交替訓練
    epochs = range(opt.max_epoch)
    for epoch in iter(epochs):

        for ii, (img, _) in tqdm.tqdm(enumerate(dataloader)):
            real_img = Variable(img)
            if opt.use_gpu:
                real_img = real_img.cuda()

            # 訓練判別器
            if (ii + 1) % opt.d_every == 0:
                optimizer_d.zero_grad()
                # 儘可能把真圖片判別為 1
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()

                # 儘可能把假圖片判別為 0
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises).detach()  # 根據照片生成假圖片
                fake_ouput = netd(fake_img)
                error_d_fake = criterion(fake_ouput, fake_labels)
                error_d_fake.backward()
                optimizer_d.step()

            # 訓練生成器
            if (ii + 1) % opt.g_every == 0:
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                fake_output = netd(fake_img)
                # 儘可能讓判別器把假圖片也判別為 1
                error_g = criterion(fake_output, true_labels)
                error_g.backward()
                optimizer_g.step()

            # 可視化

            if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
                # 定義可視化窗口
                vis = Visualizer(opt.env)

                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
                global fix_fake_imgs
                fix_fake_imgs = netg(fix_noises)
                vis.images(fix_fake_imgs.detach().cpu().numpy()[:64] * 0.5 + 0.5, win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5, win='real')
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])

        if (epoch + 1) % opt.save_every == 0:
            # 保存模型、圖片
            tv.utils.save_image(fix_fake_imgs.data[:64], '%s/%s.png' % (opt.save_path, epoch), normalize=True,
                                range=(-1, 1))
            t.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            t.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
            errord_meter.reset()
            errorg_meter.reset()

在上述訓練程式碼中,需要注意以下幾點:

  • 訓練生成器的時候,不需要調整判別器的參數;訓練判別器的時候,也不需要調整生成器的參數
  • 在訓練判別器的時候,需要對生成器生成的圖片用 detach 操作進行計算圖截斷,避免反向傳播把梯度傳到生成器中。因為在訓練判別器的時候我們不需要訓練生成器,也就不需要生成器的梯度。
  • 在訓練分類器的時候,需要反向傳播兩次,一次是希望把真圖片判為 1,一次是希望把假圖片判為 0.也可以把這個兩者的數據放到一個 batch 中,進行一次前向傳播和一次反向傳播即可。但是人們發現,在一個 batch 中只包含真圖片或者只包含假圖片的做法最好。
  • 對於假圖片,在訓練判別器的時候,我們希望它輸出為 0;而在訓練生成器的時候,我們希望它輸出為 1.因此可以看到一堆相互矛盾的程式碼:error_d_fake = criterion(fake_output,fake_labels)error_g = criterion(fake_output, true_labels)。其實這也很好理解,判別器希望能夠把假圖片判別為 fake_label,而生成器希望能把它判別為 true_label,判別器和生成器相互對抗提升。
  • 其中的 Visualize 模組類似於上一章自己的寫的模組,可以直接複製粘貼源碼中的程式碼。

七、隨機生成圖片

除了上述所示的程式碼外,還提供了一個函數,能載入預訓練好的模型,並且利用雜訊隨機生成圖片。

@t.no_grad()
def generate():
    # 定義雜訊和網路
    netg, netd = NetG(opt), NetD(opt)
    noises = t.randn(opt.gen_search_num, opt.nz, 1, 1).normal_(opt.gen_mean, opt.gen_std)
    noises = Variable(noises)

    # 載入預訓練的模型
    netd.load_state_dict(t.load(opt.netd_path))
    netg.load_state_dict(t.load(opt.netg_path))

    # 是否使用 GPU
    if opt.use_gpu:
        netd.cuda()
        netg.cuda()
        noises = noises.cuda()

    # 生成圖片,並計算圖片在判別器的分數
    fake_img = netg(noises)
    scores = netd(fake_img).data

    # 挑選最好的某幾張
    indexs = scores.topk(opt.gen_num)[1]
    result = []
    for ii in indexs:
        result.append(fake_img.data[ii])

    # 保存圖片
    tv.utils.save_image(t.stack(result), opt.gen_num, normalize=True, range=(-1, 1))

八、訓練模型並測試

完整的程式碼可以添加我微信:chenyoudea,其實上述程式碼已經很完整了,或者去github //github.com/chenyuntc/pytorch-book/tree/master/chapter07-AnimeGAN下載。

這裡假設你是擁有完整的程式碼,那麼準備好數據後,可以用下面的命令開始訓練:

python main.py train --gpu=True --vis=True --batch-size=256 --max-epoch=200

如果使用了 visdom,此時打開 //localhost:8097 就能看到生成的影像。

訓練完成後,我們就可以利用生成網路隨機生成動漫頭像,輸入命令如下:

python main.py generate --gen-img='result.5w.png' --gen-search-num=15000