針對深度學習的「失憶症」,科學家提出基於相似性加權交錯學習,登上PNAS

  • 2022 年 7 月 5 日
  • AI

與人類不同,人工神經網路在學習新事物時會迅速遺忘先前學到的資訊,必須通過新舊資訊的交錯來重新訓練;但是,交錯全部舊資訊非常耗時,並且可能沒有必要。只交錯與新資訊有實質相似性的舊資訊可能就足夠了。

近日,美國科學院院報(PNAS)刊登了一篇論文,「Learning in deep neural networks and brains with similarity-weighted interleaved learning」,由加拿大皇家學會會士、知名神經科學家 Bruce McNaughton 的團隊發表。他們的工作發現,通過將舊資訊與新資訊進行相似性加權交錯訓練,深度網路可以快速學習新事物,不僅降低了遺忘率,而且使用的數據量大幅減少。


論文作者還作出一個假設:通過跟蹤最近活躍的神經元和神經動力學吸引子(attractor dynamics)的持續興奮性軌跡,可以在大腦中實現相似性加權交錯。這些發現可能會促進神經科學和機器學習的進一步發展。

作者 | Rajat Saxena et al.
編譯 | bluemin
編輯 | 陳彩嫻


1

研究背景

了解大腦如何終身學習仍然是一項長期挑戰。
在人工神經網路(ANN)中,過快地整合新資訊會產生災難性干擾,即先前獲得的知識突然丟失。互補學習系統理論 (Complementary Learning Systems Theory,CLST) 表明,通過將新記憶與現有知識交錯,新記憶可以逐漸融入新皮質。
CLST指出,大腦依賴於互補的學習系統:海馬體 (HC) 用於快速獲取新記憶,新皮層 (NC) 用於將新數據逐漸整合到與上下文無關的結構化知識中。在「離線期間」,例如睡眠和安靜的清醒休息期間,HC觸發回放最近在NC中的經歷,而NC自發地檢索和交錯現有類別的表徵。交錯回放允許以梯度下降的方式逐步調整NC突觸權重,以創建與上下文無關的類別表徵,從而優雅地整合新記憶並克服災難性干擾。許多研究已經成功地使用交錯回放實現了神經網路的終身學習。
然而,在實踐中應用CLST時,有兩個重要問題亟待解決。首先,當大腦無法訪問所有舊數據時,如何進行全面的資訊交錯呢?一種可能的解決方案是「偽排練」,其中隨機輸入可以引發內部表徵的生成式回放,而無需顯式訪問先前學習的示例。類吸引子動力學可能使大腦完成「偽排練」,但「偽排練」的內容尚未明確。因此,第二個問題是,每進行新的學習活動之後,大腦是否有充足的時間交織所有先前學習的資訊。
相似性加權交錯學習(Similarity-Weighted Interleaved Learning,SWIL)演算法被認為是第二個問題的解決方案,這表明僅交錯與新資訊具有實質表徵相似性的舊資訊可能就足夠了。實證行為研究表明,高度一致的新資訊可以快速整合到NC結構化知識中,幾乎沒有干擾。這表明整合新資訊的速度取決於其與先驗知識的一致性。受此行為結果的啟發,並通過重新檢查先前獲得的類別之間的災難性干擾分布,McClelland等人證明SWIL可以在具有兩個上義詞類別(例如,「水果」是「蘋果」和「香蕉」的上義詞)的簡單數據集中,每個epoch使用少於2.5倍的數據量學習新資訊,實現了與在全部數據上訓練網路相同的性能。然而,研究人員在使用更複雜的數據集時並沒有發現類似的效果,這引發了對該演算法可擴展性的擔憂。
實驗表明,深度非線性人工神經網路可以通過僅交錯與新資訊共享大量表徵相似性的舊資訊子集來學習新資訊。通過使用SWIL演算法,ANN能夠以相似的精度水平和最小的干擾快速學習新資訊,同時使用的每個時期呈現的舊資訊量少之又少,這意味著數據利用率高且可以快速學習。
同時,SWIL也可應用於序列學習框架。此外,學習一種新類別可以極大地提高數據利用率 。如果舊資訊與之前學習過的類別有著非常少的相似性,那麼呈現的舊資訊數量就會少得多,這很可能是人類學習的實際情況。
最後,作者提出了一個關於SWIL如何在大腦中實現的理論模型,其興奮性偏差與新資訊的重疊成正比。



2

應用於影像分類數據集的

DNN動力學模型

McClelland等人的實驗表明,在具有一個隱藏層的深度線性網路中,SWIL可以學習一個新類別,類似於完全交錯學習 (Fully Interleaved Learning,FIL),即將整箇舊類別與新類別交錯,但使用的數據量減少了40%。
然而,網路是在一個非常簡單的數據集上訓練的,只有兩個上義詞類別,這就對演算法的可擴展性提出了疑問。
首先針對更複雜的數據集(如Fashion-MNIST),探索不同類別的學習在具有一個隱藏層的深度線性神經網路中如何演變。移出了「boot」(「靴子」)和「bag」(「紙袋」)類別後,該模型在剩餘的8個類別上的測試準確率達到了87%。然後作者團隊重新訓練模型,在兩種不同的條件下學習(新的)「boot」類,每個條件重複10次:
1)集中學習(Focused Learning ,FoL),即僅呈現新的「boot」類;
2)完全交錯學習 (FIL),即所有類別(新類別+以前學過的類別)以相等的概率呈現。在這兩種情況下,每個epoch總共呈現180張影像,每個epoch中的影像相同。
該網路在總共9000張從未見過的影像上進行了測試,其中測試數據集由每類1000張影像組成,不包括「bag」類別。當網路的性能達到漸近線時,訓練停止。
不出所料,FoL對舊類別造成了干擾,而FIL克服了這一點(圖1第2列)。如上所述,FoL對舊數據的干擾因類別而異,這是SWIL最初靈感的一部分,並表明新「boot」類別和舊類別之間存在分級相似關係。例如,「sneaker」(「運動鞋」)和「sandals」(「涼鞋」)的召回率比「trouser」(「褲子」)下降得更快(圖1第2列),可能是因為整合新的「boot」類會選擇性地改變代表「sneaker」和「sandals」類的突觸權重,從而造成更多的干擾。

圖1:預訓練網路在兩種情況下學習新「boot」類的性能對比分析:FoL(上)和 FIL(下)。從左到右依次為預測新「boot」類別的召回率(橄欖色)、現有類別的召回率(用不同顏色繪製)、總準確度(高分意味著低誤差)和交叉熵損失(總誤差的度量)曲線,是保留的測試數據集上與epoch數有關的函數。



3

計算不同類別之間的相似度

FoL在學習新類別的時候,在相似的舊類別上的分類性能會大幅下降。
之前已經探討了多類別屬性相似度和學習之間的關係,並且表明深度線性網路可以快速獲取已知的一致屬性。相比之下,在現有類別層次結構中添加新分支的不一致屬性,需要緩慢、漸進、交錯的學習。
在當前的工作中,作者團隊使用已提出的方法在特徵級別計算相似度。簡言之,計算目標隱藏層(通常是倒數第二層)現有類別和新類別的平均每類激活向量之間的餘弦相似度。圖2A顯示了基於Fashion MNIST數據集的新「boot」類別和舊類別,作者團隊根據預訓練網路的倒數第二層激活函數計算的相似度矩陣。
類別之間的相似性與我們對物體的視覺感知一致。例如,在層次聚類圖(圖2B)中,我們可以觀察到「boot」類與「sneaker」和「sandal」類之間、以及「shirt」(「襯衫」)和「t-shirt」(「T恤」)類之間具有較高的相似性。相似度矩陣(圖2A)與混淆矩陣(圖2C)完全對應。相似度越高,越容易混淆,例如,「襯衫」類與「T恤」、「套頭衫」和「外套」類影像容易混淆,這表明相似性度量預測了神經網路的學習動態。
在上一節的FoL結果圖(圖1)中,舊類別的召回率曲線中存在相近的類相似度曲線。與不同的舊類別(「trouser」等)相比,FoL學習新「boot」類的時候會快速遺忘相似的舊類別(「sneaker」 和 「sandal」)。

圖2:( A ) 作者團隊根據預訓練網路的倒數第二層激活函數,計算的現有類別和新「boot」類的相似度矩陣,其中對角線值(同一類別的相似性繪製為白色)被刪除。( B ) 對A中的相似矩陣進行層次聚類。( C ) FIL演算法在訓練學習「boot」類後生成的混淆矩陣。為了縮放清晰,刪除了對角線值。



4

深度線性神經網路實現快速和

高效學習新事物

接下來在前兩個條件基礎上增加了3種新條件,研究了新的分類學習動態,其中每個條件重複10次:
1)FoL(共計n=6000張影像/epoch);
2) FIL(共計n=54000張影像/epoch,6000張影像/類);
3) 部分交錯學習 (Partial Interleaved Learning,PIL)使用了很小的影像子集(共計n=350張影像/epoch,大約39張影像/類),每一類別(新類別+現有類別)的影像以相等的概率呈現;
4) SWIL,每個epoch使用與PIL 相同的影像總數進行重新訓練,但根據與(新)「boot」類別的相似性對現有類別影像進行加權;
5)等權交錯學習(Equally Weighted Interleaved Learning,EqWIL),使用與SWIL相同數量的「boot」類影像重新訓練,但現有類別影像的權重相同(圖3A)。
作者團隊使用了上述相同的測試數據集(共有n=9000張影像)。當在每種條件下神經網路的性能都達到漸近線時,停止訓練。儘管每個epoch使用的訓練數據較少,預測新「boot」類的準確率需要更長的時間達到漸近線,與FIL(H=7.27,P<0.05)相比,PIL的召回率更低(圖3B第1列和表1「New class」列)。
對於SWIL,相似度計算用於確定要交錯的現有舊類別影像的比例。在此基礎上,作者團隊從每箇舊類別中隨機抽取具有加權概率的輸入影像。與其他類別相比,「sneaker」和「sandal」類最相似,從而導致被交錯的比例更高(圖3A)。
根據樹狀圖(圖2B),作者團隊將「sneaker」和「sandal」類稱為相似的舊類,其餘則稱為不同的舊類。與PIL(H=5.44,P<0.05)相比,使用SWIL時,模型學習新「boot」類的速度更快,對現有類別的干擾也相近。此外,SWIL(H=0.056,P>0.05)的新類別召回率(圖3B第1列和表1「New class」列)、總準確率和損失與FIL相當。EqWIL(H=10.99,P<0.05)中新「boot」類的學習與SWIL相同,但對相近的舊類別有更大程度的干擾(圖3B第2列和表1「Similar old class」列)。
作者團隊使用以下兩種方法比較SWIL和FIL:
1) 記憶體比,即FIL和SWIL中存儲的影像數量之比,表示存儲的數據量減少;
2) 加速比,即在FIL和SWIL中呈現的內容總數的比率,以達到新類別回憶的飽和精度,表明學習新類別所需的時間減少。
SWIL可以在數據需求減少的情況下學習新內容,記憶體比=154.3x (54000/350),並且速度更快,加速比=77.1x (54000/(350×2))。即使和新內容有關的影像數量較少,該模型也可以通過使用SWIL,利用模型先驗知識的層次結構實現相同的性能。SWIL在PIL和EqWIL之間提供了一個中間緩衝區,允許集成一個新類別,並將對現有類別的干擾降到最低。

圖3 ( A ) 作者團隊在五種不同的學習條件下預訓練神經網路學習新的「boot」類(橄欖綠),直到性能平穩:1)FoL(共計n=6000張影像/epoch);2)FIL(共計n=54000張影像/epoch);3) PIL(共計n=350張影像/epoch);4) SWIL(共計n=350張影像/epoch)和 5) EqWIL(共計n=350張影像/epoch)。(B)FoL(黑色)、FIL(藍色)、PIL(棕色)、SWIL(洋紅色)和 EqWIL(金色)預測新類別、相似舊類別(「sneaker」和「sandals」)和不同舊類別的召回率,預測所有類別的總準確率,以及在測試數據集上的交叉熵損失,其中橫坐標都是epoch數。



5

基於CIFAR10使用SWIL

在CNN中學習新類別

接下來,為了測試SWIL是否可以在更複雜的環境中工作,作者團隊訓練了一個具有全連接輸出層的6層非線性CNN(圖4A),以識別CIFAR10數據集中剩餘8個不同類別(「cat」和「car」除外)的影像。他們還對模型進行了重新訓練,在之前定義的5種不同訓練條件(FoL、FIL、PIL、SWIL和EqWIL)下學習「cat」(「貓」)類。圖4C顯示了5種情況下每類影像的分布。對於SWIL、PIL和EqWIL條件,每個epoch的總影像數為2400,而對於FIL和FoL,每個epoch的總影像數分別為45000和5000。作者團隊針對每種情況對網路分別進行訓練,直到性能趨於穩定。
他們在之前未見過的總共9000張影像(1000張影像/類,不包括「car」(「轎車」)類)上對該模型進行了測試。圖4B是作者團隊基於CIFAR10數據集計算的相似性矩陣。「cat」類和「dog」(「狗」)類更類似,而其他動物類屬於同一分支(圖4B左)。
根據樹狀圖(圖4B),將「truck」 (「貨車」)、「ship」(「輪船」) 和 「plane」(「飛機」) 類別稱為不同的舊類別,除「cat」類外其餘的動物類別稱為相似的舊類別。對於FoL,模型學習了新的「cat」類,但遺忘了舊類別。與Fashion-MNIST數據集結果類似,「dog」類(與「cat」類相似性最大)和「truck」類(與「cat」類相似性最小)均存在干擾梯度,其中「dog」類的遺忘率最高,而「truck」類遺忘率最低。
如圖4D所示,FIL演算法學習新的「cat」類時克服了災難性的干擾。對於PIL演算法,模型在每個epoch使用18.75倍的數據量學習新的「cat」類,但「cat」類的召回率比FIL(H=5.72,P<0.05)低。對於SWIL,在新類別、相似和不同舊類別上的召回率、總準確率和損失與FIL相當(H=0.42,P>0.05;見表2和圖4D)。SWIL對新「cat」類的召回率高於PIL(H=7.89,P<0.05)。使用EqWIL演算法時,新「cat」類的學習情況與SWIL和FIL相似,但對相似舊類別的干擾較大(H=24.77,P<0.05;見表2)。
FIL、PIL、SWIL和EqWIL這4種演算法預測不同舊類別的性能相當(H=0.6,P>0.05)。SWI比PIL更好地融合了新的「cat」類,並有助於克服EqWIL中的觀測干擾。與FIL相比,使用SWIL學習新類別速度更快,加速比=31.25x (45000×10/(2400×6)),同時使用更少的數據量 (記憶體比=18.75x)。這些結果證明,即使在非線性CNN和更真實的數據集上,SWIL也可以有效學習新類別事物。

圖4:( A ) 作者團隊使用具有全連接輸出層的6層非線性CNN學習CIFAR10數據集中的8類事物。( B ) 相似度矩陣 (右)是在呈現新的「cat」類之後,作者團隊根據最後一個卷積層的激活函數計算獲得。對相似矩陣應用層次聚類(左),在樹狀圖中顯示動物(橄欖綠)和交通工具(藍色)兩個上義詞類別的分組情況。( C ) 作者團隊在5種不同的條件下預訓練CNN學習新的「cat」類(橄欖綠),直到性能平穩:1)FoL(共計n=5000張影像/epoch);2)FIL(共計n=45000張影像/epoch);3) PIL(共計n=2400張影像/epoch);4) SWIL(共計n=2400張影像/epoch);5) EqWIL(共計n=2400張影像/epoch)。每個條件重複10次。(D)FoL(黑色)、FIL(藍色)、PIL(棕色)、SWIL(洋紅色)和 EqWIL(金色)預測新類別、相似舊類別(CIFAR10數據集中的其他動物類)和不同舊類別(「plane」 、「ship」 和 「truck」)的召回率,預測所有類別的總準確率,以及在測試數據集上的交叉熵損失,其中橫坐標都是epoch數。



6

新內容與舊類別的一致性

對學習時間和所需數據的影響

如果一項新內容可以添加到先前學習過的類別中,而不需要對網路進行較大更改,則稱二者具有一致性。基於此框架,與干擾多個現有類別(低一致性)的新類別相比,學習干擾更少現有類別(高一致性)的新類別可以更容易地集成到網路中。
為了測試上述推斷,作者團隊使用上一節中經過預訓練的CNN,在前面描述的所有5種學習條件下,學習了一個新的「car」類別。圖5A顯示了「car」類別的相似性矩陣,與其他現有類別相比,「car」和「truck」、「ship」和「plane」在同一層次節點下,說明它們更相似。為了進一步確認,作者團隊在用於相似性計算的激活層上進行了t-SNE降維可視化分析(圖5B)。研究發現「car」類與其他交通工具類(「truck」、「ship」和「plane」)有顯著重疊,而「cat」類與其他動物類(「dog」、 「frog」(「青蛙」)、「horse」(「馬」)、「bird」(「鳥」)和「deer」(「鹿」))有重疊。
和作者團隊預期相符,FoL學習「car」類別時會產生災難性干擾,對相近的舊類別干擾性更強,而使用FIL克服了這一點(圖5D)。對於PIL、SWIL和EqWIL,每個epoch總共有n=2000張影像(圖5C)。使用SWIL演算法,模型學習新的「car」類別可以達到和FIL(H=0.79,P>0.05)相近的精度,而對現有類別(包括相似和不同類別)的干擾最小。如圖5D第2列所示,使用EqWIL,模型學習新「car」類的方式與SWIL相同,但對其他相似類別(例如「truck」)的干擾程度更高(H=53.81,P<0.05)。
與FIL相比,SWIL可以更快地學習新內容加速比=48.75x(45000×12/(2000×6)),記憶體需求減少,記憶體比=22.5x。與「cat」(48.75x vs.31.25x)相比,「car」可以通過交錯更少的類(如「truck」、「ship」和「plane」)更快地學習,而「cat」與更多的類別(如「dog」 、「frog」 、「horse」 、「frog」 和「deer」)重疊。這些模擬實驗表明,交叉和加速學習新類別所需的舊類別數據量,取決於新資訊與先驗知識的一致性。

圖 5:( A ) 作者團隊根據倒數第二層激活函數計算獲得相似度矩陣(左),以及呈現新的「car」類別後對相似度矩陣進行層次聚類後的結果圖(右)。( B ) 模型分別學習新的「car」類別和「cat」類別,經過最後一個卷積層過激活函數後,作者團隊進行t-SNE降維可視化的結果圖。( C ) 作者團隊在5種不同的條件下預訓練CNN學習新的「car」類(橄欖綠),直到性能平穩:1)FoL(共計n=5000張影像/epoch);2)FIL(共計n=45000張影像/epoch);3) PIL(共計n=2000張影像/epoch);4) SWIL(共計n=2000張影像/epoch);5) EqWIL(共計n=2000張影像/epoch)。(D)FoL(黑色)、FIL(藍色)、PIL(棕色)、SWIL(洋紅色)和 EqWIL(金色)預測新類別、相似舊類別(「plane」 、「ship」 和 「truck」)和不同舊類別(CIFAR10數據集中的其他動物類)的召回率,預測所有類別的總準確率,以及在測試數據集上的交叉熵損失,其中橫坐標都是epoch數。每張圖顯示的是重複10次後的平均值,陰影區域為±1 SEM。



7

利用SWIL進行序列學習

接下來,作者團隊測試是否可以使用SWIL學習序列化形式呈現的新內容(序列學習框架)。為此他們採用了圖4中經過訓練的CNN模型,在FIL和SWIL條件下學習CIFAR10數據集中的「cat」類(任務1),只在CIFAR10的剩餘9個類別上訓練,然後在每個條件下訓練模型學習新的「car」類(任務2)。圖6第1列顯示了SWIL條件下學習「car」類別時,其他各項類別的影像數量分布情況(共計n=2500張影像/epoch)。需要注意的是,預測「cat」類時也交叉學習新的「car」類。由於在FIL條件下模型性能最佳,SWIL僅與FIL進行了結果比較。
如圖6所示,SWIL預測新、舊類別的能力與FIL相當(H=14.3,P>0.05)。模型使用SWIL演算法可以更快地學習新的「car」類別,加速比為45x(50000×20/(2500×8)),每個epoch的記憶體佔用比FIL少20倍。模型學習「cat」和「car」類別時,在SWIL條件下每個epoch使用的影像數量(記憶體比和加速比分別為18.75x 和 20x),少於在FIL條件下每個epoch使用的整個數據集(記憶體比和加速比分別為31.25x 和45x),並且仍然可以快速學習新類別。擴展這一思想,隨著學過的類別數目不斷增加,作者團隊預期模型的學習時間和數據存儲會成倍減少,從而更高效地學習新類別,這或許反映了人類大腦實際學習時的情況。
實驗結果表明,SWIL可在序列學習框架中集成多個新類,使神經網路能夠在不受干擾的情況下持續學習。
圖6:作者團隊訓練6層CNN學習新的「cat」類(任務1),然後學習「car」類(任務2),直到性能在以下兩種情況下趨於穩定:1)FIL:包含所有舊類別(以不同顏色繪製)和以相同概率呈現的新類別(「cat」/「car」)影像;2) SWIL:根據與新類別(「cat」/「car」)的相似性進行加權並按比例使用舊類別示例。同時將任務1中學習的「cat」類包括在內,並根據任務2中學習「car」類的相似性進行加權。第1張子圖表示每個epoch使用的影像數量分布情況,其餘各子圖分別表示FIL(藍色)和SWIL(洋紅色)預測新類別、相似舊類別和不同舊類別的召回率,預測所有類別的總準確率,以及在測試數據集上的交叉熵損失,其中橫坐標都是epoch數。



8

利用SWIL擴大類別間的距離,

減少學習時間和數據量

作者團隊最後測試了SWIL演算法的泛化性,驗證其是否可以學習包括更多類別的數據集,以及是否適用於更複雜的網路架構。
他們在CIFAR100數據集(訓練集500張影像/類,測試集100張影像/類)上訓練了一個複雜的CNN模型-VGG19(共有19層),學習了其中的90個類別。然後對網路進行再訓練,學習新類別。圖7A顯示了基於CIFAR100數據集,作者團隊根據倒數第二層的激活函數計算的相似性矩陣。如圖7B所示,新「train」(「火車」)類與許多現有的交通工具類別(如「bus」 (「公共汽車」)、「streetcar」 (「有軌電車」)和「tractor」(「拖拉機」)等)很相似。
與FIL相比,SWIL可以更快地學習新事物(加速比=95.45x (45500×6/(1430×2)))並且使用的數據量 (記憶體比=31.8x) 顯著減少,而性能基本相同(H=8.21, P>0.05) 。如圖7C所示,在PIL(H=10.34,P<0.05)和EqWIL(H=24.77,P<0.05)條件下,模型預測新類別的召回率較低並且產生的干擾較大,而SWIL克服了上述不足。
同時,為了探索不同類別表徵之間的較大距離是否構成了加速模型學習的基本條件,作者團隊另外訓練了兩種神經網路模型:
1)6層CNN(與基於CIFAR10的圖4和圖5相同);
2)VGG11(11層)學習CIFAR100數據集中的90個類別,僅在FIL和SWIL兩個條件下對新的「train」類進行訓練。
如圖7B所示,對於上述兩種網路模型,新的「train」類和交通工具類別之間的重疊度更高,但與VGG19模型相比,各類別的分離度較低。與FIL相比,SWIL學習新事物的速度與層數的增加大致呈線性關係(斜率=0.84)。該結果表明,類別間表徵距離的增加可以加速學習並減少記憶體負載。

圖7:( A ) VGG19學習新的「train」類後,作者團隊根據倒數第二層激活函數計算的相似性矩陣。「truck」 、「streetcar」 、「bus」 、「house」 和 「tractor」5種類別與「train」的相似性最大。從相似度矩陣中排除對角元素(相似度 =1)。(B,左)作者團隊針對6層CNN、VGG11和VGG19網路,經過倒數第二層激活函數後,進行t-SNE降維可視化的結果圖。(B,右)縱軸表示加速比(FIL/SWIL),橫軸表示3個不同網路的層數相對於6層CNN的比率。黑色虛線、紅色虛線和藍色實線分別代表斜率 =1的標準線、最佳擬合線和模擬結果。( C ) VGG19模型的學習情況:FoL(黑色)、FIL(藍色)、PIL(棕色)、SWIL(洋紅色)和 EqWIL(金色)預測新「train」類、相似舊類別(交通工具類別)和不同舊類別(除了交通工具類別)的召回率,預測所有類別的總準確率,以及在測試數據集上的交叉熵損失,其中橫坐標都是epoch數。每張圖顯示的是重複10次後的平均值,陰影區域為±1 SEM。( D ) 從左到右依次表示模型預測Fashion-MNIST「boot」類(圖3)、CIFAR10「cat」類(圖4)、CIFAR10「car」類(圖5)和CIFAR100「train」類的召回率,是SWIL(洋紅色)和FIL(藍色)使用的影像總數(對數比例)的函數。「N」表示每種學習條件下每個epoch使用的影像總數(包括新、舊類別)。
如果在更多非重疊類上訓練網路,並且各表徵之間的距離更大,速度是否會進一步提升?
為此,作者團隊採用了一個深度線性網路(用於圖1-3中的Fashion-MNIST示例),並對其進行訓練,以學習由8個Fashion-MNIST類別(不包括「bags」和「boot」類)和10個Digit-MNIST類別形成的組合數據集,然後訓練網路學習新的「boot」類別。
和作者團隊的預期相符,「boot」與舊類別「sandals」和「sneaker」相似度更高,其次是其餘的Fashion-MNIST類(主要包括服飾類影像),最後Digit-MNIST類(主要包括數字類影像)。
基於此,作者團隊首先交織了更多相似的舊類別樣本,再交織Fashion-MNIST和Digit-MNIST類樣本(共計n=350張影像/epoch)。實驗結果表明,與FIL類似,SWIL可以快速學習新類別內容而不受干擾,但使用的數據子集要小得多,記憶體比為325.7x (114000/350) ,加速比為162.85x (228000/1400)。作者團隊在當前結果中觀察到的加速比為2.1x (162.85/77.1),與Fashion-MNIST數據集相比,類別數目增加了 2.25倍 (18/8)。
本節的實驗結果有助於確定SWIL可以適用於更複雜的數據集 (CIFAR100) 和神經網路模型(VGG19),證明了該演算法的泛化性。同時證明了擴大類別之間的內部距離或增加非重疊類別的數量,可能會進一步提高學習速度並降低記憶體負載。



9

總結

人工神經網路在持續學習方面面臨重大挑戰,通常表現出災難性干擾。為了克服此問題,許多研究都使用了完全交錯學習(FIL),即新舊內容交叉學習,聯合訓練網路。FIL需要在每次學新資訊時交織所有現有資訊,使其成為一個生物學意義上不可信且耗時的過程。最近,有研究表明FIL可能並非必需,僅交錯與新內容具有實質表徵相似性的舊內容,即採用相似性加權交錯學習(SWIL)的方法可以達到相同的學習效果。然而,有人對SWIL的可擴展性表示了擔憂。
本文擴展了SWIL演算法,並基於不同的數據集(Fashion-MNIST、CIFAR10 和 CIFAR100)和神經網路模型(深度線性網路和CNN)對其進行了測試。在所有條件下,與部分交錯學習(PIL)相比,相似性加權交錯學習(SWIL)和等權交錯學習(EqWIL)在學習新類別方面的表現更好。這和作者團隊的預期相符,因為與舊類別相比,SWIL和EqWIL增加了新類別的相對頻率。
本文同時還證明,與同等子抽樣現有類別(即EqWIL方法)相比,仔細選擇和交織相似內容減少了對相近舊類別的災難性干擾。在預測新類別和現有類別方面,SWIL的性能與FIL類似,卻顯著加快了學習新內容的速度(圖7D),同時大大減少了所需的訓練數據。SWIL可以在序列學習框架中學習新類別,進一步證明了其泛化能力。
最後,與許多舊類別具有相似性的新類別相比,如果其與之前學過的類別重疊更少(距離更大),可以縮短集成時間,並且數據效率更高。總體來說,實驗結果提供了一種可能的見解,即大腦事實上通過減少不切實際的訓練時間,克服了原始CLST模型的一項主要弱點。
原文鏈接:

//www.pnas.org/doi/10.1073/pnas.2115229119

雷峰網