如何訓練AI玩飛機大戰遊戲

  • 2019 年 10 月 8 日
  • 筆記

本文轉載自1024開發者社區

雖然沒有Google強大的集和DeepMind變態的演算法的團隊,但基於深度強化學習(Deep Q Network DQN )的自製小遊戲AI效果同樣很贊。先上效果圖:

下面分四個部分,具體給大家介紹。

/1/背景介紹

2013年DeepMind團隊發表論文「Playing Atari with Deep Reinforcement Learning」,用Q-Network模型成功讓AI玩起了Atari系列遊戲。並於2015年在《Nature》上發表了一篇升級版,「Human-level control through deep reinforcement learning」,自此,在這類遊戲領域,人已經無法超過機器了。AI玩遊戲的姿勢是這樣的:

後來的故事大家都很熟悉了,AlphaGo擊敗世界冠軍,星際爭霸2職業選手也被打敗,連大家接觸較多的王者榮耀也不能倖免。

/2/深度強化學習模型

看完了輕鬆的部分,下面簡單介紹一下模型。DQN是DRL的一種演算法,它將卷積神經網路(CNN)和Q-Learning結合起來。

Q-learning是強化學習的一種,原理圖如下:

也就是Agent在觀察得到當前的狀態state和回報reward的基礎上,選取輸出一個動作action,進而影響環境,使環境狀態和回報都產生變化。通過不斷循環讓Agent學習如何在環境中獲得更高的回報。

卷積神經網路CNN是影像處理領域非常經典的神經網路模型,在本模型中,輸入是原始影像數據,輸出為每個動作action對應的評估值。

因此DQN總體結構是這樣的:

圖比較簡單,但原理很清晰,是將Agent中的模型用CNN來代替,環境的State為遊戲介面截圖,輸出為AI的動作,在飛機大戰中就是飛機向左、向右還是不動。回報reward具體為,在一次循環中沒有被擊中為0.1,被擊中為-1,擊中敵機為1。圖中回放記憶單元、當前網路和目標網路都是為了將CNN這種需要大量樣本的監督學習融合在強化學習模型中的手段。篇幅限制這裡只是概述性的介紹,後期會專門講。

/3/模型實現

3.1程式的總體結構

程式主函數在PlaneDQN.py中,與DQN模型相關的函數在BrainDQN_Nature.py中,遊戲模型在game文件夾中,訓練過程保存的訓練值在saved_networks文件夾中。

3.2主函數搭建

大家注意看while循環里的結構,其實非常明確:

  • getaction()為在當前的Q值下選取動作
  • framestep()為運行環境,並輸出觀測值
  • process()為對影像數據進行處理的函數
  • setPerception()根據影像和回報,對網路進行訓練
def playPlane():     # Step 1: 初始化DQN     actions = 3     brain = BrainDQN(actions)     # Step 2: 初始化遊戲     plane = game.GameState()     # Step 3: 玩遊戲     # Step 3.1: 獲取初始動作     action0 = np.array([1,0,0])  # [1,0,0]do nothing,[0,1,0]left,[0,0,1]right     observation0, reward0, terminal = plane.frame_step(action0)     observation0 = cv2.cvtColor(cv2.resize(observation0, (80, 80)), cv2.COLOR_BGR2GRAY)     ret, observation0 = cv2.threshold(observation0,1,255,cv2.THRESH_BINARY)     brain.setInitState(observation0)     # Step 3.2: 開始遊戲     while 1!= 0:        action = brain.getAction()        nextObservation,reward,terminal = plane.frame_step(action)        nextObservation = preprocess(nextObservation)        brain.setPerception(nextObservation,action,reward,terminal)

3.3 遊戲類GameState和framestep

通過pygame實現遊戲介面的搭建,分別建立子彈類、玩家類、敵機類和遊戲類,結構程式碼所示。

class Bullet(pygame.sprite.Sprite):      def __init__(self, bullet_img, init_pos):      def move(self):  # 我方飛機類  class Player(pygame.sprite.Sprite):      def __init__(self, plane_img, player_rect, init_pos):      def shoot(self, bullet_img):           def moveLeft(self):      def moveRight(self):  # 敵方飛機類  class Enemy(pygame.sprite.Sprite):      def __init__(self, enemy_img, enemy_down_imgs, init_pos):      def move(self):    class GameState:      def __init__(self):      def frame_step(self, input_actions):          if input_actions[0] == 1 or input_actions[1]== 1 or input_actions[2]== 1:  # 檢查輸入正常              if input_actions[0] == 0 and input_actions[1] == 1 and input_actions[2] == 0:                  self.player.moveLeft()              elif input_actions[0] == 0 and input_actions[1] == 0 and input_actions[2] == 1:                  self.player.moveRight()              else:                  pass          else:              raise ValueError('Multiple input actions!')          image_data = pygame.surfarray.array3d(pygame.display.get_surface())          pygame.display.update()          clock = pygame.time.Clock()          clock.tick(30)          return image_data, reward, terminal

其中GameState中的framestep()函數,是整個DQN運行一次使環境發生變化的基礎函數,該函數運行一次,會根據inputaction進行動作實施,接著會在該時段對介面上的元素進行移動,並判斷是否撞擊。最後通過get_surface獲取介面影像,最後返迴環境的image_data,reward和遊戲是否停止的terminal。本文遊戲效果圖為:

為提高模型收斂速度,在實際運行時將背景圖片去掉。

3.4 DQN模型類

該部分為DQN模型的核心,主要有根據參數建立CNN網路的createQNetwork(),進行模型訓練的trainQNetwork(),進行動作選擇的getAction()。

class BrainDQN:     def __init__(self,actions):     def createQNetwork(self):        return stateInput,QValue,W_conv1,b_conv1,W_conv2,b_conv2,W_conv3,b_conv3,W_fc1,b_fc1,W_fc2,b_fc2     def copyTargetQNetwork(self):        self.session.run(self.copyTargetQNetworkOperation)     def createTrainingMethod(self):     def trainQNetwork(self):     def getAction(self):        return action     def setInitState(self,observation):        self.currentState = np.stack((observation, observation, observation, observation), axis = 2)     def weight_variable(self,shape):        return tf.Variable(initial)     def bias_variable(self,shape):        return tf.Variable(initial)     def conv2d(self,x, W, stride):        return tf.nn.conv2d(x, W, strides = [1, stride, stride, 1], padding = "SAME")     def max_pool_2x2(self,x):        return tf.nn.max_pool(x, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], padding = "SAME")

3.5影像處理

影像預處理調用cv2庫函數,對影像進行大小和灰度處理。

def preprocess(observation):     observation = cv2.cvtColor(cv2.resize(observation, (80, 80)), cv2.COLOR_BGR2GRAY)#灰度轉化     ret, observation = cv2.threshold(observation,1,255,cv2.THRESH_BINARY)      return np.reshape(observation,(80,80,1))

/4/環境搭建

  • 系統:Ubuntu16.04、win10
  • Python3.5
  • pygame 1.9.4
  • TensorFlow1.11(GPU版)
  • OpenCV-Python

公眾號中回復「AI飛機」,獲取程式碼,包含訓練500000次的結果。本程式對硬體要求不高,顯示記憶體2GB以上就可運行。