三四行程式碼打造元學習核心,PyTorch元學習庫L2L現已開源

  • 2019 年 10 月 4 日
  • 筆記

元學習似乎一直比較「高級」,畢竟學習如何學習這個概念聽起來就很難實現。在本文中,我們介紹了這兩天新開源的元學習庫 learn2learn,它是用 PyTorch 寫的,只需要三四行程式碼就能構建元學習最為核心的部分。

learn2learn 是一個用於實現元學習的 Pytorch 庫,我們只需要加幾行高層 API,就能為一般的機器學習流程添加元學習能力。例如在元學習 MNIST 案例中,我們可以用 PyTorch 構建整個流程,但只要加上三行 L2L 程式碼就能打造元學習模型。這三行程式碼只干三件事:獲取元數據集、生成元學習任務、定義元學習模型。

  • 項目地址:https://github.com/learnables/learn2learn

元學習的目標是讓智慧體學習如何學習,也就是說,我們希望智慧體能夠在解決更多問題的過程中成為更好的學習器。例如,下圖展示的智慧體正在學習如何跑步,儘管它只會更新一個參數。

L2L 有什麼特性

L2L 是一個元學習庫,可以為用戶提供 3 個級別的功能。在最高級別上,它有很多使用元學習演算法在大量數據集/環境上訓練的示例。在中間級別上,它為若干流行的元學習演算法提供了功能介面以及便於載入其他數據集的數據載入器。在最低級別上,它為模組提供了可擴展功能。

L2L 的一些特性包括:

  • 模組化 API:使用這個庫中的底層工具實現你自己的訓練循環;
  • 提供多個元學習演算法(如 MAML、FOMAML、MetaSGD、ProtoNets、DiCE);
  • 具有統一 API 的任務生成器,兼容 torchvision、torchtext、torchaudio 和 cherry;
  • 提供標準化的視覺(Omniglot、mini-ImageNet)、強化學習(Particles、Mujoco)甚至文本(新聞分類)元學習任務;
  • 100% 兼容 PyTorch——使用你自己的模組、數據集或庫。

最後,整個 L2L 庫都是由 PyTorch 寫的,因此它的源程式碼並不難理解,我們可以通過項目的源碼學習怎樣從底層實現元學習演算法。

L2L 實現 MAML 元學習演算法的局部源程式碼,它的源碼擁有大量的注釋,可以幫助理解實現過程。

示例程式碼

下面我們來看看 learn2learn 到底該如何學習一個能實現 MNIST 分類任務的模型,它使用非常高層的應用,因此理解起來很容易。

如下程式碼所示,總體而言,整個過程可以分為導入數據、定義元學習任務、定義元學習模型與最優化方法、在元學習任務內不同的學習器適配不同的數據,最後就是標準的損失計算與模型更新了。

import learn2learn as l2l    mnist = torchvision.datasets.MNIST(root="/tmp/mnist", train=True)    mnist = l2l.data.MetaDataset(mnist)  task_generator = l2l.data.TaskGenerator(mnist,                                          ways=3,                                          classes=[0, 1, 4, 6, 8, 9],                                          tasks=10)  model = Net()  maml = l2l.algorithms.MAML(model, lr=1e-3, first_order=False)  opt = optim.Adam(maml.parameters(), lr=4e-3)    for iteration in range(num_iterations):      learner = maml.clone()  # Creates a clone of model      adaptation_task = task_generator.sample(shots=1)        # Fast adapt      for step in range(adaptation_steps):          error = compute_loss(adaptation_task)          learner.adapt(error)        # Compute evaluation loss      evaluation_task = task_generator.sample(shots=1,                                              task=adaptation_task.sampled_task)      evaluation_error = compute_loss(evaluation_task)        # Meta-update the model parameters      opt.zero_grad()      evaluation_error.backward()      opt.step()  

整個 API 非常高層,只需要很少的程式碼量就能完成模型。但與此同時,L2L 庫還提供了中層和底層方面的 API,它允許我們做更多訂製化的修改。更多的例子讀者可以在 GitHub 中查閱,其示例模型分為強化學習、文本處理和視覺模型三方面:

如果讀者也想要試試這個庫,那麼直接在命令行中運行 pip install learn2learn 就行了,剩下的再看看文檔和教程,就可以快速學會怎樣使用元學習。

  • 文檔地址:http://learn2learn.net/docs/learn2learn/
  • 教程地址:http://learn2learn.net/tutorials/getting_started/