強化學習入門(五)連續動作空間內,使用DDPG求解強化學習問題

本文內容源自百度強化學習 7 日入門課程學習整理
感謝百度 PARL 團隊李科澆老師的課程講解

一、離散動作 VS 連續動作

1.1 區別

離散動作:

  • 動作都是可數的
  • 比如在 CartPole 環境中,向左推動小車,或者向右推動(力量是不變的)
  • 比如在 FrozenLake 環境中,向上下左右 4 個方向移動
  • 在 Atari 的 Pong 環境中,球拍上下移動

連續動作:

  • 動作是連續的浮點數
  • 比如在另一種 CartPole 環境中,設定推力為 -1~1,包含了方向又包含了力度
  • 比如開車,方向盤的旋轉角度:-180~180
  • 四軸飛行器,控制馬達的電壓:0~15

1.2 神經網路修改

在連續動作空間中,Sarsa,Q-learning,DQN,Policy Gradient 都無法處理

所以我們要用一個替代方案:

  • 在 Policy Gradient 中,輸入狀態 s,使用策略網路輸出不同動作的概率
    • 隨機性策略:π_θ(a_t|s_t)
  • 在連續動作環境下,輸出的是一個具體的浮點數,這個浮點數代表具體的動作(比如包含了方向和力的大小)
    • 確定性策略:μ_θ(s_t)

1.3 激活函數選擇

離散動作環境:

  • 輸出層使用 softmax,確保每個動作輸出概率加總為 1

連續動作環境:

  • 輸出層使用 tanh,即輸出為 -1~1 之間的浮點數
  • 經過縮放,對應到實際動作

二、DDPG(Deep Deterministic Policy Gradient)

DDPG 演算法可以理解為 DQN 在連續動作網路中的修正版本

  • Deterministic:代表直接輸出確定性動作 a=μ(s)
  • Policy Gradient:是策略網路,但是是單步更新的策略網路

該演算法借鑒了 DQN 的兩個工程上的技巧:

  • 目標網路:target network
  • 經驗回放:replay memory

2.1 從 DQN 到 DDPG

在 DQN 的基礎上,加了一個策略網路 Policy Gradient,用來直接輸出動作值

  • 在 DQN 中,只有 Q 網路輸出 不同動作對應的 Q 值
  • 所以 DDPG 需要同時學習 2 個網路:Q 網路 和 策略網路
  • Q 網路:Q_w(s,a),其中參數為 w
  • 策略網路:a=μ_θ(s),其中參數為 θ
  • 這個結構叫做 Actor-Critic

2.2 Actor-Critic 結構

  • 策略網路:扮演 Actor 的角色,負責對外展示輸出
  • Q 網路:評論家,每個 step 對輸出動作打分,估計該動作的未來總收益(預期 Q 值)
  • 策略網路 Actor 根據評委打分,來調整策略,即更新網路參數 θ,爭取下次獲得高分
  • Critic 要根據觀眾的回饋來調整自己的打分策略,更新Q網路參數 w,目標是讓每一步獲得儘可能多的 reward(最大化未來總收益)

這種結構下,由於網路開始時候是隨機的,所以一開始評委亂打分,演員亂表演,然後根據觀眾的回饋 reward ,Critic 的打分會越來越準確,進一步推動 Actor 的表現越來越好

2.3 DDPG 的優化目標和最佳策略

在 DQN 中,我們希望網路優化後可以求解最大的 Q 值

在 DDPG 中,我們希望網路優化後可以求解最大 Q 值對應的 action

  • 策略網路的優化:最大化 Q 值,即 Loss = -Q
  • Q 網路的優化:預測 Q 值 和 目標 Q 值 之間的差別最小化
    • Q_target 是未來收益的總和
    • Q_target \approx\ r+γQ’
    • Loss = MES(Q估計, Q_{target})

2.4 借鑒 DQN 中的目標網路 target network 和經驗回放 ReplayMemory

做演算法更新最重要的一步,就是怎麼計算

我們需要優化策略網路(參數 θ):

  • 這裡的 Loss 是個複合函數
  • Loss=-Q_w(s,a)
  • 其中 a=μ_θ(s),代入上面的 Loss 中
  • 最終我們要優化的是參數 θ

還要優化 Q 網路(參數 w):

  • 要優化預測 Q 和 目標 Q 的均方差
  • 這裡的 Loss 也是個複合函數
  • Loss = MSE[Q_w(s,a),\ \ \ \ r+γQ_\overline{w}(s’,a’)]
  • 其中 a’=μ_\overline{θ}(s’)
  • 這裡存在和 DQN 網路一樣的問題,即 Q_target 的不穩定問題
  • 所以我們需要固定 Q_target

為了固定 Q_target

  • 我們需要給 Q 網路和 策略網路都搭建一個 target network
  • target_Q 和 target_P
  • 專門用於固定 Q_target
  • 其中 target_P 用於固定 a’=μ_\overline{θ}(s’)
  • 其中 target_Q 用於固定 Q_\overline{w}(s’,a’)
  • 為了用作區分,參數上面加了個橫線:\overline{θ}\overline{w}

訓練網路所需的數據是:s,a,r,s’

所以經驗池 ReplayMemory 中存儲的就是這 4 組數據

三、 PARL 庫中 DDPG 的結構

3.1 網路結構

  • model:定義 Q 網路和 策略網路,及其對應的額 target_model

  • algorithm:定義損失函數,優化 Q 網路和 策略網路

  • agent:負責演算法和環境的交互

3.2 核心函數

  • model:
    • value() 函數:形成 Q 網路的輸出
    • policy() 函數:表示 策略網路 的輸出(動作)
  • a lgorithm:
    • _critic_learn():計算 Q 網路的 Loss
    • _actor_learn():計算 策略網路 的 Loss
  • target_model:
    • 使用 deepcopy 實現

3.3 Model

包含 3 個類:

  • Model 類
  • ActorModel 類
  • CriticModel 類

調用的時候只需要調用 Model 這個類

  • Model 下的 policy() 會自動調用 ActorModel 下的 policy() 函數
  • Model 下的 value() 會自動調用 CriticModel 下的 value() 函數

注意點:

  • value() 函數的輸入是 obs 和 act,輸出的是 Q 值
  • 所以需要 layers.concat() 函數把這兩個值進行拼接,然後才可以輸入FC 網路

Model 類下的 get_actor_params() 方法:

  • 用來獲取 ActorModel 網路的 「參數名稱」
  • 這個在更新 Actor 網路的時候會需要使用
  • 其中的 parameters() 函數已經由 PARL 在底層實現好
  • 返回一個包含模型所有參數名稱的 list

3.4 Critic 網路(Q網路)更新

Algorithm 中的 _critic_learn() 函數

  • 這裡的輸入,是從經驗池中 sample 出來的一個 batch 的數據
  • 首先,通過 target_P 網路計算 next_action,這裡輸入的是 next_obs
  • 然後,把 next_action 輸入 target_Q 網路,計算 next_Q
  • 加上 reward 以後就可以求的 Q_target
  • 然後通過 Q 網路計算 預測 Q值,其與 Q_target 的均方差即為 cost

3.5 Actor 網路(策略網路)更新

這裡的計算非常簡潔,不需要用到 target model

  • 首先通過策略網路輸出 action
  • 然後通過 Q 網路輸出 Q
  • 計算 cost,即為 -Q

在這裡需要注意的是,我們只希望這裡跟新的是策略網路的參數 θ,而不是 Q 網路的參數 w

  • 所以在最小化損失函數的過程中,要設定需要優化的參數是哪些
  • 這裡就用到了 get_actor_params() 函數
  • 獲得 actor 網路的參數名稱,僅更新 actor 網路的參數

3.6 Target network 參數軟更新

在 DQN 中,target 網路採用的是硬跟新,而在 DDPG 中,採用更平緩的軟更新

  • 設置參數 τ 來控制更新幅度
  • 其中 w 和 θ 表示新參數
  • 若 τ 取 0.001,即每次新參數只取 0.1% 的權重
  • 這是工程上的一點小技巧
  • PARL 庫中的 sync_weights_to() 函數可以進行參數更新

四、程式碼詳解

強化學習演算法 DDPG 解決 CartPole 問題,程式碼逐條詳解