MATLAB Reinforcement Learning: Sarsa vs. Sarsa(lambda)
- October 4, 2019
- Notes

Sarsa(lambda)
Sarsa(lambda) is suited to episodic tasks. When lambda is close to 1 it behaves like an episode-level update: by the time the episode ends, every step taken along the way has received credit, and since all of those steps were on the path to the treasure, each of them becomes a little more likely to be chosen in later episodes. When lambda is 0 it degenerates back into single-step Sarsa; values in between fade the credit along the trail.
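The mechanism that carries this credit is the eligibility trace: after every step, the TD error is shared with all previously visited state-action pairs in proportion to their trace, and the trace then decays by gamma * lambda. A minimal sketch of one such update, where Q and E are illustrative state-by-action matrices (terminal-step handling omitted; this is not the containers.Map representation used in the class below):

delta = r + gamma * Q(s_, a_) - Q(s, a);   % TD error for the current transition
E(s, a) = E(s, a) + 1;                     % strengthen the trace of the visited pair
Q = Q + alpha * delta * E;                 % every traced pair shares the update
E = gamma * lambda * E;                    % traces fade; with lambda = 0 only (s, a) is updated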
% Reinforcement learning: Sarsa(lambda)
ccc
% rng('default');
env = two_dimensional_env(4, 4, 0.01);
two_dimensional_rl = rl_q_table(env.actions, 0.9, 0.1, 0.9, 0.9);
% pause(2)
for episode = 1:env.max_episodes
    show = 1;
    if mod(episode, 10)
        show = 0;   % render only every 10th episode
    end
    env = env.reset();
    env.render(show);
    A = two_dimensional_rl.choose_action(env.agent);
    two_dimensional_rl = two_dimensional_rl.reset();   % clear the eligibility trace for the new episode
    while 1
        env = env.step(A);                                          % take the action, get next state and reward
        A_ = two_dimensional_rl.choose_action(env.observation);     % choose the next action
        two_dimensional_rl = two_dimensional_rl.learn(env, A, A_);  % update
        two_dimensional_rl.dump();
        env.agent = env.observation;
        A = A_;
        env.render(show);
        if env.done
            break
        end
    end
end
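The script relies on a two_dimensional_env class that lives in the repository linked at the end of the post. Its implementation is not reproduced here; the interface the script assumes, with names inferred from the calls above, looks roughly like the following sketch (not the real class):

classdef two_dimensional_env
    % Sketch of the interface assumed by the training script (illustrative only)
    properties
        actions        % available action indices
        agent          % current state (before the step)
        observation    % state reached after the last step
        reward         % reward returned by the last step
        done           % 1 once the episode has ended
        max_episodes   % number of training episodes
        fresh_time     % pause between rendered frames (assumed meaning of the 0.01 argument)
    end
    methods
        function obj = two_dimensional_env(rows, cols, fresh_time)
            % rows/cols give the grid size (assumed meaning of the 4, 4 arguments)
            obj.actions = 1:4;            % up / down / left / right (assumed)
            obj.fresh_time = fresh_time;
            obj.max_episodes = 100;       % assumed default
            obj.agent = [1 1];            % assumed start cell
            obj.done = 0;
        end
        function obj = reset(obj)
            % put the agent back at the start of the grid
        end
        function obj = step(obj, A)
            % apply action A and fill observation, reward and done
        end
        function render(obj, show)
            % draw the grid world when show is 1
        end
    end
end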
The reinforcement-learning class needs slightly more changes.
classdef rl_q_table
    % Reinforcement learning logic: Sarsa(lambda) with a Q table and an eligibility trace
    properties
        q_table
        actions
        epsilon
        alpha
        gamma
        trace_decay
        eligibility_trace
    end
    methods
        function obj = rl_q_table(actions, epsilon, alpha, gamma, trace_decay)
            % Initialization
            obj.actions = actions;
            obj.epsilon = epsilon;
            obj.alpha = alpha;
            obj.gamma = gamma;
            obj.trace_decay = trace_decay;
            obj.q_table = containers.Map();
            obj.eligibility_trace = containers.Map();
        end
        function dump(obj)
            % Print the current Q table
            keySet = keys(obj.q_table);
            len = length(keySet);
            if len <= 0
                return
            end
            disp('—————————')
            for i = 1:len
                disp([keySet{i} ':' StrHelper.arr2str(obj.q_table(keySet{i}))])
            end
            % keySet_trace = keys(obj.eligibility_trace);
            % len = length(keySet_trace);
            % if len <= 0
            %     return
            % end
            % disp('*******************************')
            % for i = 1:len
            %     disp([keySet_trace{i} ':' StrHelper.arr2str(obj.eligibility_trace(keySet_trace{i}))])
            % end
        end
        function obj = reset(obj)
            % Zero the eligibility trace at the start of each episode
            keySet_trace = keys(obj.eligibility_trace);
            len = length(keySet_trace);
            if len > 0
                for i = 1:len
                    temp = obj.eligibility_trace(keySet_trace{i});
                    temp(:) = 0;
                    obj.eligibility_trace(keySet_trace{i}) = temp;
                end
            end
        end
        function table_ling = find_line(obj, state)
            % Look up (and lazily create) the Q-table row for a state
            agent_str = StrHelper.arr2str(state);
            if ~isKey(obj.q_table, agent_str)
                obj.q_table(agent_str) = zeros(1, length(obj.actions));
                obj.eligibility_trace(agent_str) = zeros(1, length(obj.actions));
            end
            table_ling = obj.q_table(agent_str);
        end
        function table_ling = find_trace_line(obj, state)
            % Look up (and lazily create) the eligibility-trace row for a state
            agent_str = StrHelper.arr2str(state);
            if ~isKey(obj.eligibility_trace, agent_str)
                obj.eligibility_trace(agent_str) = zeros(1, length(obj.actions));
            end
            table_ling = obj.eligibility_trace(agent_str);
        end
        function obj = learn(obj, env, A, A_)
            q_predict_arr = obj.find_line(env.agent);
            q_predict = q_predict_arr(A);
            if env.done ~= 1
                line = obj.find_line(env.observation);
                q_target = env.reward + obj.gamma * line(A_);   % episode not finished
            else
                q_target = env.reward;                          % episode finished
            end
            % Replacing trace: clear this state's row, then mark the chosen action
            q_trace_arr = obj.find_trace_line(env.agent);
            q_trace_arr(:) = 0;
            q_trace_arr(A) = q_trace_arr(A) + 1;
            obj.eligibility_trace(StrHelper.arr2str(env.agent)) = q_trace_arr;
            % Update every row of the Q table in proportion to its trace,
            % then decay the trace by gamma * trace_decay (lambda)
            key_cell = keys(obj.q_table);
            td_error = q_target - q_predict;
            for i = 1:length(obj.q_table)
                obj.q_table(key_cell{i}) = obj.q_table(key_cell{i}) + ...
                    obj.eligibility_trace(key_cell{i}) * obj.alpha * td_error;
                obj.eligibility_trace(key_cell{i}) = ...
                    obj.eligibility_trace(key_cell{i}) * obj.gamma * obj.trace_decay;
            end
        end
        function action_name = choose_action(obj, state)
            % Choose an action (epsilon-greedy)
            state_actions = obj.find_line(state);   % Q values for this state
            if (rand() > obj.epsilon) || (all(state_actions == 0))
                % Explore: pick a random action (also taken while this row is still all zeros)
                action_name = obj.actions(randi(length(obj.actions)));
            else
                % Exploit: pick the action with the largest Q value
                [~, I] = max(state_actions);
                max_index = state_actions == state_actions(I);
                if sum(max_index) > 1
                    % Break ties among equally good actions at random
                    action_name = obj.actions(max_index);
                    action_name = action_name(randi(length(action_name)));
                else
                    action_name = obj.actions(I);
                end
            end
        end
    end
end
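For the comparison in the title: plain single-step Sarsa needs no eligibility trace at all, so its learn method only touches the one state-action entry that was just visited. Reduced to single-step Sarsa, learn would look roughly like this (a sketch for contrast, not necessarily the exact class from the earlier Sarsa post):

function obj = learn(obj, env, A, A_)
    % Single-step Sarsa: only the visited (state, action) entry is updated
    q_predict_arr = obj.find_line(env.agent);
    q_predict = q_predict_arr(A);
    if env.done ~= 1
        line = obj.find_line(env.observation);
        q_target = env.reward + obj.gamma * line(A_);   % episode not finished
    else
        q_target = env.reward;                          % episode finished
    end
    q_predict_arr(A) = q_predict_arr(A) + obj.alpha * (q_target - q_predict);
    obj.q_table(StrHelper.arr2str(env.agent)) = q_predict_arr;
end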
If training feels slow, you can have render actually draw only once every 10 or 50 episodes; the show flag in the script above does exactly that.
The related utility classes are at
https://gitee.com/sickle12138/MatGamer
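The code above also depends on StrHelper.arr2str, which serializes a state vector into a string so it can be used as a containers.Map key. The real helper lives in that repository; a minimal stand-in, assuming the state is a numeric row vector, could be:

classdef StrHelper
    % Minimal stand-in for the repository's string helper (sketch only)
    methods (Static)
        function s = arr2str(arr)
            % Turn a numeric vector into a plain string, e.g. [1 2] -> '1  2'
            s = num2str(arr(:)');
        end
    end
end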