MATLAB Reinforcement Learning: Sarsa vs. Sarsa(lambda)

  • October 4, 2019
  • Notes

Sarsa(lambda)

Sarsa(lambda) is meant for episodic tasks. Once the episode reaches the treasure, the update is propagated back over all the steps taken in that episode; since every one of those steps was part of getting to the treasure, each of them becomes a little more likely to be chosen in the next episode.

When lambda is 0, this reduces to Sarsa's single-step update.
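To make this concrete, here is a minimal sketch of the per-step Sarsa(lambda) update that the code below implements. The matrix-indexed tables Q and E and the names s, a, s_, a_, r are illustrative only; the actual code keys containers.Map objects by a state string.

% One Sarsa(lambda) step (sketch; replacing trace, matrix-indexed for brevity)
delta = r + gamma * Q(s_, a_) - Q(s, a);  % TD error uses the *next* chosen action (Sarsa)
E(s, :) = 0;                              % replacing trace: clear this state's row
E(s, a) = 1;                              % mark the state-action pair just visited
Q = Q + alpha * delta * E;                % credit every recently visited pair
E = E * gamma * lambda;                   % decay traces; lambda = 0 recovers single-step Sarsa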

% Reinforcement learning: Sarsa(lambda)
ccc % toolkit shorthand that clears the workspace, figures, and command window

% rng('default');

env = two_dimensional_env(4,4,0.01);
two_dimensional_rl = rl_q_table(env.actions,0.9,0.1,0.9,0.9);
% pause(2)

for episode = 1:env.max_episodes
    show = 1;
    if mod(episode,10)
        show = 0; % draw only every 10th episode
    end
    env = env.reset();
    env.render(show);
    A = two_dimensional_rl.choose_action(env.agent);
    two_dimensional_rl = two_dimensional_rl.reset(); % clear the eligibility traces for the new episode
    while 1
        env = env.step(A); % take the action, obtain the next state and reward
        A_ = two_dimensional_rl.choose_action(env.observation);
        two_dimensional_rl = two_dimensional_rl.learn(env, A, A_); % update
        two_dimensional_rl.dump();
        env.agent = env.observation;
        A = A_;
        env.render(show);
        if env.done
            break
        end
    end
end

The reinforcement-learning class needs a few more changes:

classdef rl_q_table
    % Reinforcement learning logic: Sarsa(lambda) with an eligibility trace

    properties
        q_table
        actions
        epsilon
        alpha
        gamma
        trace_decay
        eligibility_trace
    end

    methods
        function obj = rl_q_table(actions,epsilon,alpha,gamma,trace_decay)
            % initialize
            obj.actions = actions;
            obj.epsilon = epsilon;
            obj.alpha = alpha;
            obj.gamma = gamma;
            obj.trace_decay = trace_decay;
            obj.q_table = containers.Map();
            obj.eligibility_trace = containers.Map();
        end

        function dump(obj)
            % print the current Q table
            keySet = keys(obj.q_table);
            len = length(keySet);
            if len <= 0
                return
            end
            disp('-------------------------')
            for i = 1:len
                disp([keySet{i} ':' StrHelper.arr2str(obj.q_table(keySet{i}))])
            end
            % keySet_trace = keys(obj.eligibility_trace);
            % len = length(keySet_trace);
            % if len <= 0
            %     return
            % end
            % disp('*******************************')
            % for i = 1:len
            %     disp([keySet_trace{i} ':' StrHelper.arr2str(obj.eligibility_trace(keySet_trace{i}))])
            % end
        end

        function obj = reset(obj)
            % zero the eligibility trace at the start of each episode
            keySet_trace = keys(obj.eligibility_trace);
            len = length(keySet_trace);
            if len > 0
                for i = 1:len
                    temp = obj.eligibility_trace(keySet_trace{i});
                    temp(:) = 0;
                    obj.eligibility_trace(keySet_trace{i}) = temp;
                end
            end
        end

        function table_line = find_line(obj,state)
            % look up (and lazily create) the Q-table row for this state
            agent_str = StrHelper.arr2str(state);
            if ~isKey(obj.q_table,agent_str)
                obj.q_table(agent_str) = zeros(1,length(obj.actions));
                obj.eligibility_trace(agent_str) = zeros(1,length(obj.actions));
            end
            table_line = obj.q_table(agent_str);
        end

        function table_line = find_trace_line(obj,state)
            % look up (and lazily create) the eligibility-trace row for this state
            agent_str = StrHelper.arr2str(state);
            if ~isKey(obj.eligibility_trace,agent_str)
                obj.eligibility_trace(agent_str) = zeros(1,length(obj.actions));
            end
            table_line = obj.eligibility_trace(agent_str);
        end

        function obj = learn(obj,env, A, A_)
            q_predict_arr = obj.find_line(env.agent);
            q_predict = q_predict_arr(A);
            if env.done ~= 1
                q_next_arr = obj.find_line(env.observation);
                q_target = env.reward + obj.gamma * q_next_arr(A_); % episode not finished
            else
                q_target = env.reward; % episode finished
            end

            % replacing trace: clear this state's row, then mark the chosen action
            q_trace_arr = obj.find_trace_line(env.agent);
            q_trace_arr(:) = 0;
            q_trace_arr(A) = q_trace_arr(A) + 1;
            obj.eligibility_trace(StrHelper.arr2str(env.agent)) = q_trace_arr;

            % update every state-action pair in proportion to its trace, then decay the traces
            key_cell = keys(obj.q_table);
            td_error = q_target - q_predict;
            for i = 1:length(obj.q_table)
                obj.q_table(key_cell{i}) = obj.q_table(key_cell{i}) + ...
                    obj.eligibility_trace(key_cell{i})*obj.alpha*td_error;
                obj.eligibility_trace(key_cell{i}) = ...
                    obj.eligibility_trace(key_cell{i})*obj.gamma*obj.trace_decay;
            end
        end

        function action_name = choose_action(obj,state)
            % choose an action (epsilon-greedy)
            state_actions = obj.find_line(state); % values of each action in this state
            if (rand() > obj.epsilon) || (all(state_actions == 0))
                % explore: pick a random action (also used while the row is still untrained)
                action_name = obj.actions(randi(length(obj.actions)));
            else % greedy choice: pick the action with the largest value
                [~,I] = max(state_actions);
                max_index = state_actions == state_actions(I);
                if sum(max_index) > 1
                    % break ties randomly among equally good actions
                    action_name = obj.actions(max_index);
                    action_name = action_name(randi(length(action_name)));
                else
                    action_name = obj.actions(I);
                end
            end
        end
    end
end
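For comparison, when lambda = 0 the learn method above effectively reduces to single-step Sarsa, which touches only the current state-action value. A sketch using the same names as the class above (with q_target computed exactly as in learn):

% Plain Sarsa update (sketch): only the current state-action pair changes
q_predict_arr = obj.find_line(env.agent);
q_predict_arr(A) = q_predict_arr(A) + obj.alpha * (q_target - q_predict_arr(A));
obj.q_table(StrHelper.arr2str(env.agent)) = q_predict_arr;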

If training feels slow, you can have the render method do actual drawing only once every 10 or 50 episodes.
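For example, assuming render keeps accepting the same 0/1 flag, the check at the top of the episode loop could become:

show = (mod(episode, 50) == 0); % draw only every 50th episode
env.render(show);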

The related utility methods are at

https://gitee.com/sickle12138/MatGamer
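The code in this post also assumes a StrHelper class whose arr2str method turns a numeric array into a string key for containers.Map. The real implementation lives in the repository above; a minimal stand-in (the exact formatting is an assumption) could be:

classdef StrHelper
    methods (Static)
        function s = arr2str(arr)
            % stand-in: join the numbers with '_' (the repository version may format differently)
            s = strjoin(arrayfun(@num2str, arr, 'UniformOutput', false), '_');
        end
    end
end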

