使用PyTorch進行情侶幸福度測試指南

  • 2019 年 10 月 6 日
  • 筆記

DeepConnection模型框架

電腦視覺–影像和影片數據分析是深度學習目前最火的應用領域之一。因此,在學習深度學習的同時嘗試運用某些電腦視覺技術做些有趣的事情會很有意思,也會讓你發現些令人吃驚的事實。長話短說,我的搭檔(Maximiliane Uhlich)和我決定將深度學習應用於浪漫情侶的形象分類上,因為Maximiliane是一位關係研究員和情感治療師。具體來說,我們想知道我們是否可以準確地判斷影像或影片中描繪的情侶是否對他們的關係感到滿意?事實證明,我們可以!我們的最終模型(我們稱之為DeepConnection)分類準確率接近97%,能夠準確地區分幸福與不幸福的情侶。大家可以在我們的論文預覽鏈接[1]里閱讀完整介紹,上圖是我們為這個任務設計的框架草圖。

在數據集收集方面,我們使用這個Python腳本[2]進行網頁數據抽取(webscraping)來獲取幸福和不幸福的情侶數據。最後,我們整理出了大約包含1000張影像的訓練集。這並不是特別多,所以我們使用數據增強與遷移學習來增強我們模型在數據集上的表現。數據增強–影像方向的微小變化,色調和色彩強度以及許多其他因素都會增強模型的泛化能力,從而避免學習一些不相關資訊。例如,如果數據中幸福夫妻的影像平均比不幸福夫妻的影像更亮,我們並不希望我們的模型映射這種關聯。我們使用了強大的ImgAug庫[3]進行了相當多策略的數據擴充,以確保我們模型的魯棒性。基本上對於每個批次的每個影像,我們都至少應用多種數據增強技術。下圖是一張圖片應用了48種數據增強策略的示例。

影像增強後數據示例

我們決定使用ResNet模型作為DeepConnection的基礎網路,在大型數據集ImageNet上預先訓練。通過預訓練,模型已經具有了一定的識別能力。我們所有的模型都借用PyTorch實現,我們使用Google Colab上的免費GPU資源進行訓練和測試。這個基礎模型本身已經具備了良好的分類能力,但我們決定更進一步,用空間金字塔池化層(SPP)[4] 替換ResNet-34基礎模型的最後一個自適應池模組。這裡,處理後的影像數據被分成不同數量的正方形,並且僅傳遞最大值以進行進一步分析(最大池化)。這使得模型可以專註於重要的特徵,使其對不同大小的影像具有魯棒性,並且不受影像擾動的影響。之後,我們放置了一個均值變換(PMT)層[5],用數學函數轉換數據以引入非線性,使得DeepConnection可以從數據中捕獲更複雜的關係。這兩個模組均提高了我們的分類準確度,我們在單獨的驗證集上得到了大約97%準確率。SPP / PMT和後續分類層的程式碼如下所示:

class SPP(nn.Module):    def __init__(self):      super(SPP, self).__init__()        ## features incoming from ResNet-34 (after SPP/PMT)      self.lin1 = nn.Linear(2*43520, 100)        self.relu = nn.ReLU()      self.bn1 = nn.BatchNorm1d(100)      self.dp1 = nn.Dropout(0.5)      self.lin2 = nn.Linear(100, 2)      def forward(self, x):      # SPP      x = spatial_pyramid_pool(x, x.shape[0], [x.shape[2], x.shape[3]], [8, 4, 2, 1])        # PMT      x_1 = torch.sign(x)*torch.log(1 + abs(x))      x_2 = torch.sign(x)*(torch.log(1 + abs(x)))**2      x = torch.cat((x_1, x_2), dim = 1)        # fully connected classification part      x = self.lin1(x)      x = self.bn1(self.relu(x))        #1      x1 = self.lin2(self.dp1(x))      #2      x2 = self.lin2(self.dp1(x))      #3      x3 = self.lin2(self.dp1(x))      #4      x4 = self.lin2(self.dp1(x))      #5      x5 = self.lin2(self.dp1(x))      #6      x6 = self.lin2(self.dp1(x))      #7      x7 = self.lin2(self.dp1(x))      #8      x8 = self.lin2(self.dp1(x))        x = torch.mean(torch.stack([x1, x2, x3, x4, x5, x6, x7, x8]), dim = 0)        return x  

仔細觀察程式碼可以看出,最終分類層上有八個變種。看似浪費了算力實際上恰恰相反。這個概念是最近提出的,叫做multi-sample dropout(多樣本隨機丟棄),它在訓練期間顯著加速了收斂[6]。它基本上是防止模型學習虛假關係(過度擬合)和試圖不丟棄丟失掩碼中的資訊之間的折衷。

我們在項目中對這個方法進行了其他一些調整優化,具體參看我們在GitHub放出的項目程式碼[7]以獲取更多資訊。簡單地提一下:我們使用混合精度(使用Apex庫[8]實現)訓練模型,以大大降低記憶體使用率,使用早停(earlystopping)來防止過度擬合,並根據餘弦函數進行學習率退火。

在達到令人滿意的分類準確度(具有相應高的召回率和精確度)後,我們想知道我們是否可以從DeepConnection執行的分類中學到一些東西。因此,我們嘗試模型解釋性探索並使用梯度加權類激活映射技術(Grad-CAM)進行分析[9]。基本地,Grad-CAM獲取最終卷積層的輸入梯度以確定顯著區域,其可以被視為原始影像之上的上取樣熱圖。具體實現與可視化結果如下:

熱度圖對比

## from https://github.com/eclique/pytorch-gradcam/blob/master/gradcam.ipynb    def GradCAM(img, c, features_fn, classifier_fn):      feats = modulelist_conv(img.cuda().half())      feats = feats.cuda()      _, N, H, W = feats.size()        out = modulelist_fc(feats)      c_score = out[0, c]      grads = torch.autograd.grad(c_score, feats)      w = grads[0][0].mean(-1).mean(-1)        sal = torch.matmul(w, feats.view(N, H*W))      sal = sal.view(H, W).cpu().detach().numpy()      sal = np.maximum(sal, 0)        return sal  

我們在論文中對此進行了進一步討論,並將其嵌入到了現有的心理學研究中,但DeepConnection似乎主要關注面部區域。從研究的角度來看,這很有意義,因為面部表情會傳達溝通和情感。除了Grad-CAM獲得的視覺感知之外,我們還想看看我們是否可以通過模型解釋得出實際特徵。為此,我們創建了激活狀態圖,以顯示最終分類層的哪些神經元被哪些給定影像區域激活。

不同幸福程度代表性激活狀態圖

與其他模型相比,DeepConnection還學習到了代表不幸福的特徵,並不僅僅將缺乏代表幸福的特徵的分類為不幸福。但是,我們需要進一步的研究才能將這些特徵實際映射到人類行為可解釋性方面。我們還嘗試過在未知的情侶影片幀上使用DeepConnection,效果非常好。

總體而言,該模型的穩健性是其強大優勢之一。準確的分類同樣適用於同性戀伴侶不同膚色人種除情侶外包含其他人的影片幀中不能完整顯示情侶人臉的影片幀中等等。對於影像中存在其他人的情況,DeepConnection甚至可以識別其他人是否感到滿意,但仍然將其預測集中在這對情侶身上。

除了進一步的模型解釋之外,下一步的工作將是使用更大的訓練數據集,從而訓練更複雜的模型。使用DeepConnection作為情侶治療師的助手將會很有意思,可以在會話期間或之後對情侶的當前關係狀態進行實時回饋。此外,我建議您與女票/男票一起輸入你們的合照,看看DeepConnection對你們的關係有何看法!希望這會是一個好的開始!

1: https://psyarxiv.com/df25j/ 2: https://github.com/Bribak/DeepConnection 3: https://github.com/aleju/imgaug 4: https://arxiv.org/abs/1406.4729 5: https://www.sciencedirect.com/science/article/pii/S0031320318304503 6: https://arxiv.org/abs/1905.09788 7: https://github.com/Bribak/DeepConnection 8: https://github.com/NVIDIA/apex 9: https://arxiv.org/abs/1610.02391