強化學習入門(五)連續動作空間內,使用DDPG求解強化學習問題
本文內容源自百度強化學習 7 日入門課程學習整理
感謝百度 PARL 團隊李科澆老師的課程講解
1.1 區別
1.2 神經網路修改
1.3 激活函數選擇
二、DDPG(Deep Deterministic Policy Gradient)
2.1 從 DQN 到 DDPG
2.2 Actor-Critic 結構
2.3 DDPG 的優化目標和最佳策略
2.4 借鑒 DQN 中的目標網路 target network 和經驗回放 ReplayMemory
三、 PARL 庫中 DDPG 的結構
3.1 網路結構
3.2 核心函數
3.3 Model
3.4 Critic 網路(Q網路)更新
3.5 Actor 網路(策略網路)更新
3.6 Target network 參數軟更新
四、程式碼詳解
一、離散動作 VS 連續動作
1.1 區別
離散動作:
- 動作都是可數的
- 比如在 CartPole 環境中,向左推動小車,或者向右推動(力量是不變的)
- 比如在 FrozenLake 環境中,向上下左右 4 個方向移動
- 在 Atari 的 Pong 環境中,球拍上下移動
連續動作:
- 動作是連續的浮點數
- 比如在另一種 CartPole 環境中,設定推力為 -1~1,包含了方向又包含了力度
- 比如開車,方向盤的旋轉角度:-180~180
- 四軸飛行器,控制馬達的電壓:0~15
1.2 神經網路修改
在連續動作空間中,Sarsa,Q-learning,DQN,Policy Gradient 都無法處理
所以我們要用一個替代方案:
- 在 Policy Gradient 中,輸入狀態 s,使用策略網路輸出不同動作的概率
- 隨機性策略:π_θ(a_t|s_t)
- 在連續動作環境下,輸出的是一個具體的浮點數,這個浮點數代表具體的動作(比如包含了方向和力的大小)
- 確定性策略:μ_θ(s_t)
1.3 激活函數選擇
離散動作環境:
- 輸出層使用 softmax,確保每個動作輸出概率加總為 1
連續動作環境:
- 輸出層使用 tanh,即輸出為 -1~1 之間的浮點數
- 經過縮放,對應到實際動作
二、DDPG(Deep Deterministic Policy Gradient)
DDPG 演算法可以理解為 DQN 在連續動作網路中的修正版本
- Deterministic:代表直接輸出確定性動作 a=μ(s)
- Policy Gradient:是策略網路,但是是單步更新的策略網路
該演算法借鑒了 DQN 的兩個工程上的技巧:
- 目標網路:target network
- 經驗回放:replay memory
2.1 從 DQN 到 DDPG
在 DQN 的基礎上,加了一個策略網路 Policy Gradient,用來直接輸出動作值
- 在 DQN 中,只有 Q 網路輸出 不同動作對應的 Q 值
- 所以 DDPG 需要同時學習 2 個網路:Q 網路 和 策略網路
- Q 網路:Q_w(s,a),其中參數為 w
- 策略網路:a=μ_θ(s),其中參數為 θ
- 這個結構叫做 Actor-Critic
2.2 Actor-Critic 結構
- 策略網路:扮演 Actor 的角色,負責對外展示輸出
- Q 網路:評論家,每個 step 對輸出動作打分,估計該動作的未來總收益(預期 Q 值)
- 策略網路 Actor 根據評委打分,來調整策略,即更新網路參數 θ,爭取下次獲得高分
- Critic 要根據觀眾的回饋來調整自己的打分策略,更新Q網路參數 w,目標是讓每一步獲得儘可能多的 reward(最大化未來總收益)
這種結構下,由於網路開始時候是隨機的,所以一開始評委亂打分,演員亂表演,然後根據觀眾的回饋 reward ,Critic 的打分會越來越準確,進一步推動 Actor 的表現越來越好
2.3 DDPG 的優化目標和最佳策略
在 DQN 中,我們希望網路優化後可以求解最大的 Q 值
在 DDPG 中,我們希望網路優化後可以求解最大 Q 值對應的 action
- 策略網路的優化:最大化 Q 值,即 Loss = -Q
- Q 網路的優化:預測 Q 值 和 目標 Q 值 之間的差別最小化
- Q_target 是未來收益的總和
- Q_target \approx\ r+γQ’
- Loss = MES(Q估計, Q_{target})
2.4 借鑒 DQN 中的目標網路 target network 和經驗回放 ReplayMemory
做演算法更新最重要的一步,就是怎麼計算
我們需要優化策略網路(參數 θ):
- 這裡的 Loss 是個複合函數
- Loss=-Q_w(s,a)
- 其中 a=μ_θ(s),代入上面的 Loss 中
- 最終我們要優化的是參數 θ
還要優化 Q 網路(參數 w):
- 要優化預測 Q 和 目標 Q 的均方差
- 這裡的 Loss 也是個複合函數
- Loss = MSE[Q_w(s,a),\ \ \ \ r+γQ_\overline{w}(s’,a’)]
- 其中 a’=μ_\overline{θ}(s’)
- 這裡存在和 DQN 網路一樣的問題,即 Q_target 的不穩定問題
- 所以我們需要固定 Q_target
為了固定 Q_target
- 我們需要給 Q 網路和 策略網路都搭建一個 target network
- target_Q 和 target_P
- 專門用於固定 Q_target
- 其中 target_P 用於固定 a’=μ_\overline{θ}(s’)
- 其中 target_Q 用於固定 Q_\overline{w}(s’,a’)
- 為了用作區分,參數上面加了個橫線:\overline{θ} 和 \overline{w}
訓練網路所需的數據是:s,a,r,s’
所以經驗池 ReplayMemory 中存儲的就是這 4 組數據
三、 PARL 庫中 DDPG 的結構
3.1 網路結構
-
model:定義 Q 網路和 策略網路,及其對應的額 target_model
-
algorithm:定義損失函數,優化 Q 網路和 策略網路
-
agent:負責演算法和環境的交互
3.2 核心函數
- model:
- value() 函數:形成 Q 網路的輸出
- policy() 函數:表示 策略網路 的輸出(動作)
- a lgorithm:
- _critic_learn():計算 Q 網路的 Loss
- _actor_learn():計算 策略網路 的 Loss
- target_model:
- 使用 deepcopy 實現
3.3 Model
包含 3 個類:
- Model 類
- ActorModel 類
- CriticModel 類
調用的時候只需要調用 Model 這個類
- Model 下的 policy() 會自動調用 ActorModel 下的 policy() 函數
- Model 下的 value() 會自動調用 CriticModel 下的 value() 函數
注意點:
- value() 函數的輸入是 obs 和 act,輸出的是 Q 值
- 所以需要 layers.concat() 函數把這兩個值進行拼接,然後才可以輸入FC 網路
Model 類下的 get_actor_params() 方法:
- 用來獲取 ActorModel 網路的 「參數名稱」
- 這個在更新 Actor 網路的時候會需要使用
- 其中的 parameters() 函數已經由 PARL 在底層實現好
- 返回一個包含模型所有參數名稱的 list
3.4 Critic 網路(Q網路)更新
Algorithm 中的 _critic_learn() 函數
- 這裡的輸入,是從經驗池中 sample 出來的一個 batch 的數據
- 首先,通過 target_P 網路計算 next_action,這裡輸入的是 next_obs
- 然後,把 next_action 輸入 target_Q 網路,計算 next_Q
- 加上 reward 以後就可以求的 Q_target
- 然後通過 Q 網路計算 預測 Q值,其與 Q_target 的均方差即為 cost
3.5 Actor 網路(策略網路)更新
這裡的計算非常簡潔,不需要用到 target model
- 首先通過策略網路輸出 action
- 然後通過 Q 網路輸出 Q
- 計算 cost,即為 -Q
在這裡需要注意的是,我們只希望這裡跟新的是策略網路的參數 θ,而不是 Q 網路的參數 w
- 所以在最小化損失函數的過程中,要設定需要優化的參數是哪些
- 這裡就用到了 get_actor_params() 函數
- 獲得 actor 網路的參數名稱,僅更新 actor 網路的參數
3.6 Target network 參數軟更新
在 DQN 中,target 網路採用的是硬跟新,而在 DDPG 中,採用更平緩的軟更新
- 設置參數 τ 來控制更新幅度
- 其中 w 和 θ 表示新參數
- 若 τ 取 0.001,即每次新參數只取 0.1% 的權重
- 這是工程上的一點小技巧
- PARL 庫中的 sync_weights_to() 函數可以進行參數更新
四、程式碼詳解
強化學習演算法 DDPG 解決 CartPole 問題,程式碼逐條詳解