論文閱讀——Non-Local Neural Networks

  • 2020 年 6 月 14 日
  • AI

一、摘要

當前深度學習中用於捕獲長距離依賴關係的主要方式——堆疊卷積或者使用循環模塊,均為局部鄰域處理組件。本文提出了另一種捕獲長距離依賴關係的通用組件——非局部處理模塊。
【待完善】

二、Non-Local Neural Network

2.1 回顧Non local means

  Non local means filter,即非局部均值濾波算法,是圖像去噪領域一個非常有名的算法。我讀研一時有個大四的學妹在做畢設,問我說她想找個傳統去噪算法作為她跑的DL代碼的對比方法,我就給她推薦了NLM(當時我也沒聽過幾個傳統去噪算法哈哈,所以應該算是大名鼎鼎了)。

  NLM的思路非常直觀,首先其出發點在於:在去噪任務中,對圖像的每個小塊,如果能用和其相同性質的其他小塊加權平均(去噪任務主流的方法是在不改變原始輸入信號的情況下通過加權平均,或者說濾波,來去除噪聲),會對當前小塊噪聲抑制起到較好的效果。這樣相同性質的小塊越多,加權平均之後的效果就越好。

  因此,NLM算法簡單來說就是兩個步驟:一是確定當前小塊與圖像所有位置的小塊之間的接近程度,即所有小塊與當前小塊加權平均時的權值;二是將各個小塊的像素值與當前小塊按照該權值,加權平均之後作為去噪的結果。算法示意圖如Fig. 1:

Fig. 1 非局部均值濾波算法

  Fig. 2展示了對一副大圖,以中心像素為基準,圖像所有像素的權重可視化結果。可以看出與中心像素所在鄰域性質相似的區域獲得的權重明顯高一些。

Fig. 2 NLM算法中各像素與中心像素的接近程度可視化

  NLM的模型如下:
NL[v](i) = \sum_{j\in{I}}w(i,j)v(j) \tag{1},其中v代表圖像,i,j代表位置下標;w表示權重。非局部均值濾波算法中具體使用的w(i,j)為:
w(i,j)=\frac{1}{Z(i)}e^{-\frac{||v(N_i)-v(N_j)||^2_{2,a}}{h^2}} \tag{2}Z(i)為歸一化因子,其值為當前小塊和圖像所有小塊權重之和:
Z(i)=\sum_je^{-\frac{||v(N_i)-v(N_j)||^2_{2,a}}{h^2}} \tag{3}

  這裡着重記錄一下w(i,j)。求i,j兩位置的權重w也就是衡量圖像在這兩個位置的相似度,最簡單的方法是直接評估兩位置對應像素灰度值的差距。考慮到使用鄰域比單獨的像素可靠一些,鄰域相似度高則認為這兩個像素的相似度高,故論文中用比較兩個像素所在鄰域來代替像素值直接比較。

  衡量兩個圖像塊相似度最常用的方法是計算他們之間的歐氏距離,不過在求歐式距離的時候,不同位置的像素的權重應當是不一樣的,距離塊的中心越近的像素,權重越大,距離中心越遠,權重越小。這個權重可以看作一個符合高斯分佈的kernel。(不過實際計算中考慮到計算量的問題,常常採用均勻分佈的權重。另外原始NLM算法速度很慢,後續又出了很多優化版本,這裡不做討論。)

   聯想:在harris特徵匹配算法中,計算像素之間相似程度也用到鄰域之間的相似度,不過計算的是歸一化互相關係數;sift匹配則是對鄰域提取到的特徵,利用餘弦相似定理(歸一化的dot product)來計算特徵相似程度。

  通過下面一段matlab代碼,可以看出NLM具體計算過程:

 for r=rmin:1:rmax
     for s=smin:1:smax                                
         if(r==i1 && s==j1) continue; end;
         W2= input2(r-f:r+f , s-f:s+f);                
         d = sum(sum(kernel.*(W1-W2).*(W1-W2)));
         w=exp(-d/h);                 
         if w>wmax                
             wmax=w;                   
         end
         sweight = sweight + w;
         average = average + w*input2(r,s);                           
     end 
 end

其中kernel由以下代碼生成:

function [kernel] = make_kernel(f)              
kernel=zeros(2*f+1,2*f+1);   
for d=1:f    
  value= 1 / (2*d+1)^2 ;    
  for i=-d:d
    for j=-d:d
        kernel(f+1-i,f+1-j)= kernel(f+1-i,f+1-j) + value ;
    end
  end
end
kernel = kernel ./ f;   

2.2 從NLM到NLNN

  上面一部分回顧了經典的NLM算法。另外之所以去細看NLM的論文,實際上就是因為在看Non-local Neural Networks論文時對Fig. 3這張論文截圖有兩個疑問:

  1. 為什麼作者在選擇f(x_i, x_j)時說「一個自然的選擇是使用高斯函數」
  2. 作者說使用”gaussian function”但是給出的式子只是一個矩陣乘法之後的指數函數,如何代表高斯?
Fig. 3 論文中提到的gaussian function可能並不嚴謹?

  我自己先嘗試對這2個問題進行解答:

  首先使用exponential function並不是gaussian的體現,按照NLM論文的說法,指數函數的作用其實就是為了讓歐式距離較大位置的權重能夠快速衰減到0附近,也即”fast decay”。(值得注意的是,後續還出現了不少針對nlm中的指數kernel進行優化的論文。)

  然後真正體現gaussian的是上面代碼段中在兩個鄰域計算歐式距離時加入的gaussian kernel衰減矩陣,其作用在於,按照鄰域中像素距離中心的遠近來控制求鄰域間歐式距離時的像素權重,即||v(N_i)-v(N_j)||^2_{2,a}代表gaussian版本的歐式距離。

  從這個角度來說, NLNN(Non Local Neural Network)論文中提到的gaussian function貌似並不是真的用了gaussian衰減,因為NLNN在比較i,j兩個位置時,甚至並未通過比較任何形式的鄰域N(i), N(j)來確定該j位置與i的接近程度,而僅僅是直接比較輸入信號在i,j兩個位置對應向量或者向量的embedding的距離測度。

  具體來說,對於NLM,輸入信號為2D圖像,通過計算以i,j像素為中心的鄰域相似度來決定位置j處的權重;對於NLNN,若輸入信號為3D特徵圖,則每個空間位置j對應一個向量,故論文計算的是x_ix_j兩個1D向量的點積再對得到的標量過一個指數函數。

  接下來回到NLNN論文中,首先看一下作者給出的通用non-local operation模型:
y_i=\frac{1}{C(x)}\Sigma_j f(x_i, x_j)g(x_j) \tag{4}

  其中,i代表當前待計算的output position的index,x代表輸入信號(圖像、序列、視頻或者他們的特徵),y為輸出信號,通常和x 保持size相同。二元函數 f 輸出一個標量,如兩個位置i,j的某種接近程度的metric。一元函數負責給出輸入信號在j位置處的一個表示。計算出的輸出信號在i處的響應最後會經過一個歸一化因子C(x)以進行歸一化。

  上式從形式上看與NLM的模型,即式(1)沒有差異。接下來作者在”Instantiations”部分為該通用模版舉了幾個具體的例子,使之更易理解,同時也用於說明non local block的通用性。

  g(x_j)代表位置j處輸入信號的一個表示,為了簡潔起見,作者就直接把其形式固定為原始輸入的一個embedding:g(x_j)=W_gx_jW_jx_j對應的權重矩陣,如最簡單的,對於CNN中的feature,W_j可以是1×1卷積(或者進一步擴展到多個1\times1卷積,即由一個卷積層提取之後的新特徵?)

  然後作者討論了對f(x_i, x_j)函數的選擇。這裡有必要再對NLM和NLNN的區別做一個強調:

  • 輸入信號形式: NLM輸入信號形式固定,為2D圖像,NLNN輸入信號形式多樣,可以是2D(第一層輸入特徵)、3D(中間層輸入特徵)甚至4D(加入時序的中間層特徵)
  • 計算相似程度時有無使用鄰域信息: f(x_i, x_j)需要輸出一個標量代表i,j兩個位置的相似程度。由於NLM輸入形式簡單,x_i, x_j均只代表一個灰度數值,故求相似程度時利用鄰域信息提高魯棒性。NLM在使用鄰域信息時為了突出中心像素坐標,所以要用gaussian kernel。而對於NLNN來說x_i,x_j各自代表一個向量,求兩向量之間的相似程度來表徵x_i, x_j兩個位置的相似程度本身就比較魯棒(當然或許可以把NLNN也擴展為求$K_i,K_j兩個向量鄰域對相似程度?這樣從思想上來說感覺更接近NLM),所以NLNN未利用鄰域信息。

  所以NLNN論文中的gaussian與embedded gaussian或許改成exponential與embedded exponential更貼切?畢竟後面還把式子轉化成softmax,softmax中的exp與gaussian有隱含關係嗎?

2.3 NLNN block公式推導

2.3.1 權重與softmax

  接下來就是這一章最重要的內容:將上述公式(以embedded gaussian為例子)簡化為softmax的形式。為了便於理解,這裡可以以我們比較熟悉的CNN某中間層特徵來作為輸入信號,其shape為HxWx1024。對於i,j兩個位置,計算其embedded gausssian:
z_i = \theta(x_i) = W_ix_i

z_j = \phi(x_j) = W_jx_j

對於輸入信號的某個位置i,因為i確定是C為常數,故可寫成:
y_i=\sum_{\forall{j}}\frac{1}{C}e^{z_i^Tz_j}g(x_j)

然後把C代入進來:
y_i=\sum_{\forall{j}}\frac{e^{z_i^Tz_j}}{\Sigma_je^{z_i^Tz_j}} g(x_j)

可以看出對於固定的i,中間那一塊符合softmax的定義,因此有:
y_i=\sum_{\forall{j}} softmax(z_i^Tz_j)g(x_j)

至此我們推導了對於單個位置i,輸出信號的響應y(i)。接下來需要對這個公式進行向量化。大家在寫反向傳播相關練習時都知道,單個樣本的公式比較好推導,但是如果要在代碼中實現,還是需要向量化之後的公式。

2.3.2 公式向量化

  首先保持i固定,去掉j
y_i = sum(softmax(z_i^TZ)*g(X))

  注意,這裡的ZX本身是3維向量C\times H \times W,但是如果推導時直接用3維的shape,後面會非常麻煩。故這裡的一個技巧是將g(X)Z均看成C×HW的矩陣(假設g(X)使用的embedding也是C個kernel)。這樣sum就是沿着列方向進行。

  可以先檢查下各個向量的維度:z_i^T1×CZC×HW,softmax相當於是activation不改變維度,故經過softmax之後的維度為1×HW,然後g(X)C×HW,故對於每個i,通過廣播得到C×HW的矩陣,然後對列求和得到輸出響應。

  然後嘗試去除i,得到最終的向量化公式:
  為了實現對樣本維度的向量化,一種個人比較喜歡的做法是先將i可能的取值代進去,並排寫到一起觀察規律,比如下面的草稿:

Fig. 4 將yi展開觀察規律

看起來要求Y=[y_1, y_2, ..., y_{hw}]^T,需要做的就是對每一個ix_i^TX向量(注意這張圖中的x相當於上面用的z記號)分別與g(X)廣播相乘並把結果矩陣堆到一起,再對列求和即可,也就是下圖的第一行:

Fig. 5 第一行矩陣運算結果等效於第二行的矩陣乘法(A、B、C代表維度)

(圖示:第一個矩陣的每一行分別與第二個矩陣相乘(利用廣播),得到的A個CxB的矩陣concat到一起,再沿着列方向求和)
  這裡直接給出結論,第一行的結果和第二行的結果等價。感興趣的同學可以證明下?(我挑了幾個位置對比了下是一樣的hh)

  到這裡,我們可以寫出去掉i之後的向量化公式:
Y = g(X)·[softmax(Z^T·Z)]^T

其中”·”代表dot product,即矩陣乘法,其在向量化過程中起到關鍵的一步,並且將讓人頭疼的sum去掉了。這樣我們就可以寫出NLNN關鍵組件的代碼,同時也可以看懂論文中的Figure2,尤其是為什麼右上那個乘法用的是矩陣乘法。(注意,看該圖時可以忽略掉維度T,並將embedding所用權重簡化為1×1,這樣會稍微容易理解一些)

Fig. 6 NLNN論文中的figure 2

  嚴格來說,根據輸出響應的定義,softmax之後那個矩陣乘法的結果就是當f(x_i, x_j)採取embedded gaussia時的輸出信號。但是這張圖後面還有個通過1024個1×1 conv將輸出映射回emedding之前的維度,然後加上原始輸入信號。個人理解這裡最左邊那條通路以及逐元素相加運算應該算是構成了NLNN版「卷積」的shortcut。

【待續】

image.png

四、記錄

對長距離依賴關係的建模方式:

  • 堆疊卷積層
  • 使用循環網絡組件
  • graphic models,如條件隨機場、GNN

詞語記錄:

  • showcase:v. display/exhibit
  • amenable:合適的,易順服的
  • affinity:n. 接近,親和力

參考文獻:
//developers.google.com/machine-learning/clustering/similarity/measuring-similarity
//blog.csdn.net/piaoxuezhong/java/article/details/78345929