MATLAB Reinforcement Learning: Comparing Sarsa and Sarsa(lambda)

  • October 4, 2019
  • Notes

Sarsa(lambda)

Sarsa(lambda) is aimed at episodic environments: rather than updating only the single step just taken, the credit for reaching the treasure is spread back over all the steps visited during the episode. Because every one of those steps was part of getting the treasure, each of them becomes a little more likely to be chosen in the next episode.

When lambda is 0, this degenerates into the single-step update of plain Sarsa.
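As a minimal runnable sketch of the idea (the variable names here are illustrative and not part of the class shown later in the post), one Sarsa(lambda) update on a toy table looks like this:

    % One Sarsa(lambda) update on a toy 3-state, 2-action table (illustrative only)
    alpha = 0.1; gamma = 0.9; lambda = 0.9;   % learning rate, discount, trace decay
    Q = zeros(3, 2);                          % Q-table
    E = zeros(3, 2);                          % eligibility traces
    s = 1; a = 1;                             % current state and action
    r = 0;                                    % reward observed
    s_next = 2; a_next = 2;                   % next state and the action the policy picked there

    delta = r + gamma * Q(s_next, a_next) - Q(s, a);  % TD error, exactly as in plain Sarsa
    E(s, a) = E(s, a) + 1;                            % mark the visited state-action pair
    Q = Q + alpha * delta * E;                        % every traced pair absorbs a share of the error
    E = gamma * lambda * E;                           % decay; with lambda = 0 only (s, a) gets updated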

% Reinforcement learning: Sarsa(lambda)
ccc

% rng('default');
env = two_dimensional_env(4, 4, 0.01);
two_dimensional_rl = rl_q_table(env.actions, 0.9, 0.1, 0.9, 0.9);
% pause(2)

for episode = 1:env.max_episodes
    show = 1;
    if mod(episode, 10)
        show = 0;
    end
    env = env.reset();
    env.render(show);
    A = two_dimensional_rl.choose_action(env.agent);
    two_dimensional_rl = two_dimensional_rl.reset(); % clear the eligibility traces for the new episode
    while 1
        env = env.step(A);                                         % take the action, observe state and reward
        A_ = two_dimensional_rl.choose_action(env.observation);    % choose the next action from the new state
        two_dimensional_rl = two_dimensional_rl.learn(env, A, A_); % update
        two_dimensional_rl.dump();
        env.agent = env.observation;
        A = A_;
        env.render(show);
        if env.done
            break
        end
    end
end
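The script assumes the two_dimensional_env class from the MatGamer repo linked at the end of the post. For readers without the repo, here is a hypothetical minimal stand-in, reconstructed only from how the script uses it; the action encoding, the meaning of the constructor arguments, and the reward scheme are assumptions, not the repo's actual implementation (rename the class to two_dimensional_env to run the script above unchanged):

    % Hypothetical stand-in for the MatGamer two_dimensional_env class.
    % A rows-by-cols grid: the agent starts at (1,1), the treasure sits at (rows, cols).
    classdef two_dimensional_env_stub
        properties
            rows
            cols
            render_pause            % pause between rendered frames (assumed meaning of the 3rd constructor arg)
            actions = 1:4           % 1 = up, 2 = down, 3 = left, 4 = right
            max_episodes = 100
            agent                   % current position [row col]
            observation             % position reached by the last step
            reward = 0
            done = 0
        end
        methods
            function obj = two_dimensional_env_stub(rows, cols, render_pause)
                obj.rows = rows; obj.cols = cols; obj.render_pause = render_pause;
            end
            function obj = reset(obj)
                obj.agent = [1 1]; obj.observation = [1 1];
                obj.reward = 0; obj.done = 0;
            end
            function obj = step(obj, action)
                pos = obj.agent;
                switch action
                    case 1, pos(1) = max(pos(1) - 1, 1);        % up
                    case 2, pos(1) = min(pos(1) + 1, obj.rows); % down
                    case 3, pos(2) = max(pos(2) - 1, 1);        % left
                    case 4, pos(2) = min(pos(2) + 1, obj.cols); % right
                end
                obj.observation = pos;
                obj.done = isequal(pos, [obj.rows obj.cols]);   % reached the treasure cell
                obj.reward = double(obj.done);                  % +1 only on reaching the treasure
            end
            function render(obj, show)
                if show
                    fprintf('agent at (%d,%d)\n', obj.agent(1), obj.agent(2));
                    pause(obj.render_pause);
                end
            end
        end
    end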

The reinforcement-learning object needs a few more changes than the plain Sarsa version:

classdef rl_q_table
    % Reinforcement-learning logic: Sarsa(lambda) with a Q-table

    properties
        q_table
        actions
        epsilon
        alpha
        gamma
        trace_decay
        eligibility_trace
    end

    methods
        function obj = rl_q_table(actions, epsilon, alpha, gamma, trace_decay)
            % Initialization
            obj.actions = actions;
            obj.epsilon = epsilon;
            obj.alpha = alpha;
            obj.gamma = gamma;
            obj.trace_decay = trace_decay;
            obj.q_table = containers.Map();
            obj.eligibility_trace = containers.Map();
        end

        function dump(obj)
            % Print the current Q-table
            keySet = keys(obj.q_table);
            len = length(keySet);
            if len <= 0
                return
            end
            disp('---------')
            for i = 1:len
                disp([keySet{i} ':' StrHelper.arr2str(obj.q_table(keySet{i}))])
            end
            % keySet_trace = keys(obj.eligibility_trace);
            % len = length(keySet_trace);
            % if len <= 0
            %     return
            % end
            % disp('*******************************')
            % for i = 1:len
            %     disp([keySet_trace{i} ':' StrHelper.arr2str(obj.eligibility_trace(keySet_trace{i}))])
            % end
        end

        function obj = reset(obj)
            % Zero all eligibility traces at the start of an episode
            keySet_trace = keys(obj.eligibility_trace);
            len = length(keySet_trace);
            if len > 0
                for i = 1:len
                    temp = obj.eligibility_trace(keySet_trace{i});
                    temp(:) = 0;
                    obj.eligibility_trace(keySet_trace{i}) = temp;
                end
            end
        end

        function table_line = find_line(obj, state)
            % Return the Q-table row for a state, creating it (and its trace row) if missing
            agent_str = StrHelper.arr2str(state);
            if ~isKey(obj.q_table, agent_str)
                obj.q_table(agent_str) = zeros(1, length(obj.actions));
                obj.eligibility_trace(agent_str) = zeros(1, length(obj.actions));
            end
            table_line = obj.q_table(agent_str);
        end

        function table_line = find_trace_line(obj, state)
            % Return the eligibility-trace row for a state, creating it if missing
            agent_str = StrHelper.arr2str(state);
            if ~isKey(obj.eligibility_trace, agent_str)
                obj.eligibility_trace(agent_str) = zeros(1, length(obj.actions));
            end
            table_line = obj.eligibility_trace(agent_str);
        end

        function obj = learn(obj, env, A, A_)
            q_predict_arr = obj.find_line(env.agent);
            q_predict = q_predict_arr(A);
            if env.done ~= 1
                line = obj.find_line(env.observation);
                q_target = env.reward + obj.gamma * line(A_); % episode not finished yet
            else
                q_target = env.reward;                        % episode has finished
            end

            % Replacing trace: clear this state's row, then mark the action actually taken
            q_trace_arr = obj.find_trace_line(env.agent);
            q_trace_arr(:) = 0;
            q_trace_arr(A) = q_trace_arr(A) + 1;
            obj.eligibility_trace(StrHelper.arr2str(env.agent)) = q_trace_arr;

            % Update the Q-table: spread the TD error over every visited state
            % in proportion to its trace, then decay all traces
            key_cell = keys(obj.q_table);
            td_error = q_target - q_predict;
            for i = 1:length(obj.q_table)
                obj.q_table(key_cell{i}) = obj.q_table(key_cell{i}) + ...
                    obj.eligibility_trace(key_cell{i}) * obj.alpha * td_error;
                obj.eligibility_trace(key_cell{i}) = ...
                    obj.eligibility_trace(key_cell{i}) * obj.gamma * obj.trace_decay;
            end
        end

        function action_name = choose_action(obj, state)
            % Choose an action (epsilon-greedy)
            state_actions = obj.find_line(state); % Q-values for this state
            if (rand() > obj.epsilon) || (all(state_actions == 0))
                % Explore (also whenever this state's row is still all zeros)
                action_name = obj.actions(randi(length(obj.actions)));
            else % Exploit: pick the action with the largest Q-value
                [~, I] = max(state_actions);
                max_index = state_actions == state_actions(I);
                if sum(max_index) > 1
                    % Break ties between equally good actions at random
                    action_name = obj.actions(max_index);
                    action_name = action_name(randi(length(action_name)));
                else
                    action_name = obj.actions(I);
                end
            end
        end
    end
end
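The class also depends on StrHelper.arr2str from the same repo to turn a state vector into a string key for containers.Map. A hypothetical minimal stand-in (not the repo's actual code) could be:

    % Hypothetical stand-in for the repo's StrHelper class: only the arr2str
    % method used above is provided, turning a numeric row vector into a Map key.
    classdef StrHelper
        methods (Static)
            function s = arr2str(arr)
                s = mat2str(arr);   % e.g. [1 2] -> '[1 2]'
            end
        end
    end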

If training feels slow, you can call the render method only once every 10 or 50 episodes.
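For example, the existing show flag in the script above could be set like this to render only every 50th episode:

    % Render only every 50th episode instead of every 10th
    show = (mod(episode, 50) == 0);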

The related utility classes and methods are at https://gitee.com/sickle12138/MatGamer


帮你学MatLab (Helping you learn MATLAB)
WeChat: MatLab_helper