Training a Simulink Model with MATLAB Reinforcement Learning

  • February 11, 2020
  • Notes

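Simulink makes it easy to build physical-domain models. The model here is a simple Simscape inverted pendulum on a cart (cart-pole), and, like any other environment, it can be trained with MATLAB's Reinforcement Learning Toolbox. The script below trains it with a DDPG agent.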

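%% Load the environment
clc; clear; close all   % replaces "ccc", a common user-defined shortcut that is not built into MATLAB
mdl = 'rlCartPoleSimscapeModel';
open_system(mdl)        % show the Simscape cart-pole model
env = rlPredefinedEnv('CartPoleSimscapeModel-Continuous');   % continuous-action environment
obsInfo = getObservationInfo(env);
numObservations = obsInfo.Dimension(1);
actInfo = getActionInfo(env);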

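%% Timing and random seed
Ts = 0.02;   % agent sample time (s)
Tf = 25;     % simulation time per episode (s)
rng(0)       % fix the random seed for reproducibility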

%% Initialize the agent
% Critic Q(s,a): the observation and the action go through separate paths,
% are added together, and are mapped to a single Q-value.
statePath = [
    imageInputLayer([numObservations 1 1],'Normalization','none','Name','observation')
    fullyConnectedLayer(128,'Name','CriticStateFC1')
    reluLayer('Name','CriticRelu1')
    fullyConnectedLayer(200,'Name','CriticStateFC2')];
actionPath = [
    imageInputLayer([1 1 1],'Normalization','none','Name','action')
    fullyConnectedLayer(200,'Name','CriticActionFC1','BiasLearnRateFactor',0)];
commonPath = [
    additionLayer(2,'Name','add')
    reluLayer('Name','CriticCommonRelu')
    fullyConnectedLayer(1,'Name','CriticOutput')];
criticNetwork = layerGraph(statePath);
criticNetwork = addLayers(criticNetwork,actionPath);
criticNetwork = addLayers(criticNetwork,commonPath);
criticNetwork = connectLayers(criticNetwork,'CriticStateFC2','add/in1');
criticNetwork = connectLayers(criticNetwork,'CriticActionFC1','add/in2');
figure
plot(criticNetwork)   % visualize the critic graph
criticOptions = rlRepresentationOptions('LearnRate',1e-03,'GradientThreshold',1);
% rlRepresentation is the older (R2019-era) interface; newer releases provide
% rlQValueRepresentation and, later, rlQValueFunction instead.
critic = rlRepresentation(criticNetwork,obsInfo,actInfo, ...
    'Observation',{'observation'},'Action',{'action'},criticOptions);

% Actor mu(s): maps the observation to a force; tanh bounds the output to (-1,1)
% and the scaling layer stretches it to the action limit.
actorNetwork = [
    imageInputLayer([numObservations 1 1],'Normalization','none','Name','observation')
    fullyConnectedLayer(128,'Name','ActorFC1')
    reluLayer('Name','ActorRelu1')
    fullyConnectedLayer(200,'Name','ActorFC2')
    reluLayer('Name','ActorRelu2')
    fullyConnectedLayer(1,'Name','ActorFC3')
    tanhLayer('Name','ActorTanh1')
    scalingLayer('Name','ActorScaling','Scale',max(actInfo.UpperLimit))];
actorOptions = rlRepresentationOptions('LearnRate',5e-04,'GradientThreshold',1);
% Newer releases provide rlDeterministicActorRepresentation (later
% rlContinuousDeterministicActor) in place of rlRepresentation.
actor = rlRepresentation(actorNetwork,obsInfo,actInfo, ...
    'Observation',{'observation'},'Action',{'ActorScaling'},actorOptions);

agentOptions = rlDDPGAgentOptions( ...
    'SampleTime',Ts, ...
    'TargetSmoothFactor',1e-3, ...        % soft-update rate of the target networks
    'ExperienceBufferLength',1e6, ...     % replay buffer size
    'MiniBatchSize',128);
agentOptions.NoiseOptions.Variance = 0.4;            % exploration noise variance
agentOptions.NoiseOptions.VarianceDecayRate = 1e-5;  % decay rate of that variance
agent = rlDDPGAgent(actor,critic,agentOptions);
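To make the two network diagrams concrete, here is a standalone base-MATLAB sketch (no toolboxes, not part of the training script) of what the critic and actor compute. The weights are small random placeholders rather than trained parameters, and `nObs = 5` and `aMax = 15` are hypothetical stand-ins for `numObservations` and `max(actInfo.UpperLimit)`; read the real values from `obsInfo` and `actInfo`. The action-path bias is left out on the assumption that `'BiasLearnRateFactor',0` keeps it at its default zero initialization.

```matlab
% Standalone sketch: forward pass of the critic Q(s,a) and actor mu(s) above.
% Weights are random placeholders; nObs and aMax are hypothetical values.
rng(1);                          % reproducible placeholder weights
nObs = 5;                        % hypothetical stand-in for numObservations
aMax = 15;                       % hypothetical stand-in for max(actInfo.UpperLimit)
relu = @(z) max(z, 0);

% Critic: state path (128 -> 200) and action path (1 -> 200) meet at an addition
W1 = 0.1*randn(128, nObs);  b1 = zeros(128, 1);   % CriticStateFC1
W2 = 0.1*randn(200, 128);   b2 = zeros(200, 1);   % CriticStateFC2
Wa = 0.1*randn(200, 1);                           % CriticActionFC1 (bias frozen, assumed zero)
Wo = 0.1*randn(1, 200);     bo = 0;               % CriticOutput
criticQ = @(s, a) Wo * relu((W2*relu(W1*s + b1) + b2) + Wa*a) + bo;

% Actor: 128 -> 200 -> 1, then tanh bounds to (-1,1) and scaling stretches to +/-aMax
V1 = 0.1*randn(128, nObs);  c1 = zeros(128, 1);   % ActorFC1
V2 = 0.1*randn(200, 128);   c2 = zeros(200, 1);   % ActorFC2
V3 = 0.1*randn(1, 200);     c3 = 0;               % ActorFC3
actorMu = @(s) aMax * tanh(V3*relu(V2*relu(V1*s + c1) + c2) + c3);

s = randn(nObs, 1);              % one random observation
a = actorMu(s);                  % deterministic action (force)
q = criticQ(s, a);               % critic's value of taking a in s
fprintf('action = %.3f, Q(s,a) = %.3f\n', a, q);
```

Because the action enters the critic only through an addition before the shared ReLU, the critic stays differentiable in `a`, which is what DDPG needs to push the actor's output toward higher Q-values.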

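%% Set training options
maxepisodes = 2000;
maxsteps = ceil(Tf/Ts);   % 1250 agent steps per episode
trainingOptions = rlTrainingOptions( ...
    'MaxEpisodes',maxepisodes, ...
    'MaxStepsPerEpisode',maxsteps, ...
    'ScoreAveragingWindowLength',5, ...
    'Verbose',false, ...
    'Plots','training-progress', ...
    'StopTrainingCriteria','AverageReward', ...   % stop when the 5-episode average reward reaches the value below
    'StopTrainingValue',-400, ...
    'SaveAgentCriteria','EpisodeReward', ...      % save any agent whose episode reward reaches the value below
    'SaveAgentValue',-400);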
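The noise settings and the episode length together determine how long exploration stays strong. The following standalone base-MATLAB sketch works this out, assuming that the toolbox multiplies the noise variance by `(1 - VarianceDecayRate)` once per agent step and keeps decaying it across episodes; treat the numbers as an estimate under that assumption, not toolbox output.

```matlab
% Exploration-noise decay implied by the settings above.
% Assumption: variance is multiplied by (1 - VarianceDecayRate) once per agent step,
% and the decay carries over from one episode to the next.
Ts = 0.02;                            % agent sample time (s), as in the script
Tf = 25;                              % episode duration (s), as in the script
stepsPerEpisode = ceil(Tf/Ts);        % 1250 steps if the episode is not terminated early

v0 = 0.4;                             % NoiseOptions.Variance
decay = 1e-5;                         % NoiseOptions.VarianceDecayRate
varAfter = @(k) v0 * (1 - decay).^k;  % variance after k agent steps

stepsToHalve = log(0.5) / log(1 - decay);           % about 6.93e4 steps
episodesToHalve = stepsToHalve / stepsPerEpisode;   % counted in full-length episodes

fprintf('steps per episode: %d\n', stepsPerEpisode);
fprintf('variance after one full episode: %.4f\n', varAfter(stepsPerEpisode));
fprintf('steps to halve the variance: %.0f (about %.1f full-length episodes)\n', ...
        stepsToHalve, episodesToHalve);
```

Early in training the pendulum usually falls before 25 s, so episodes are shorter than 1250 steps and it takes more episodes than this estimate for the noise to halve.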

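%% Parallel training settings (requires Parallel Computing Toolbox; skip this section to train serially)
trainingOptions.UseParallel = true;
trainingOptions.ParallelizationOptions.Mode = "async";
% DDPG agents use experience-based parallelization; "gradients" is meant for AC/PG agents.
trainingOptions.ParallelizationOptions.DataToSendFromWorkers = "experiences";
trainingOptions.ParallelizationOptions.StepsUntilDataIsSent = -1;   % -1: send data at the end of each episode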

%% Train
trainingStats = train(agent,env,trainingOptions);
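During `train`, the DDPG agent keeps target copies of its actor and critic and moves them toward the learned networks by a small step at each update, controlled by `TargetSmoothFactor` (1e-3 above). This standalone base-MATLAB sketch applies that soft (Polyak) update to a hypothetical parameter vector held fixed, so the effect of the factor is visible in isolation.

```matlab
% Standalone sketch of the soft (Polyak) target update used by DDPG:
%   thetaTarget <- tau*theta + (1 - tau)*thetaTarget
tau = 1e-3;                 % TargetSmoothFactor from the script
theta = [1; -2; 0.5];       % hypothetical learned parameters (held fixed here)
thetaTarget = zeros(3, 1);  % hypothetical target parameters

nUpdates = 1000;
for k = 1:nUpdates
    thetaTarget = tau*theta + (1 - tau)*thetaTarget;   % one soft update
end

% With theta fixed, the remaining fraction of the gap is (1 - tau)^nUpdates
remaining = (theta - thetaTarget) ./ theta;
fprintf('fraction of the gap left after %d updates: %.4f\n', nUpdates, remaining(1));
```

With `tau = 1e-3`, about 1/tau = 1000 updates are needed to close roughly 63% (1 - 1/e) of a change in the learned networks; this slow tracking is what keeps the critic's learning targets stable.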

%% Show results
simOptions = rlSimulationOptions('MaxSteps',500);
experience = sim(env,agent,simOptions);   % simulate the trained agent
totalReward = sum(experience.Reward);     % cumulative reward of the simulation
% bdclose(mdl)                            % close the Simulink model
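`totalReward` sums the per-step rewards of one simulation, the same kind of quantity as an episode reward during training. The standalone base-MATLAB sketch below shows how the stop rule (`'AverageReward'` over a 5-episode window, threshold -400) and the save rule (`'EpisodeReward'`, threshold -400) interact. The reward sequence is hypothetical, and the sketch assumes a trailing window that averages over whatever episodes exist at the start and a "greater than or equal" comparison.

```matlab
% Standalone sketch of the stop/save rules configured in rlTrainingOptions above.
% episodeReward is a HYPOTHETICAL sequence, not real training output.
episodeReward = [-900 -750 -620 -480 -390 -410 -350 -380 -360 -340];

windowLen = 5;      % ScoreAveragingWindowLength
stopValue = -400;   % StopTrainingValue  ('AverageReward' criterion)
saveValue = -400;   % SaveAgentValue     ('EpisodeReward' criterion)

% Trailing average over the last windowLen episodes (fewer episodes at the start)
avgReward = movmean(episodeReward, [windowLen-1 0]);

% Training stops at the first episode whose average reaches the threshold
stopEpisode = find(avgReward >= stopValue, 1);

% Agents are saved from every episode (up to the stop) whose own reward reaches saveValue
savedEpisodes = find(episodeReward(1:stopEpisode) >= saveValue);

fprintf('training stops after episode %d\n', stopEpisode);
fprintf('agents saved at episodes: %s\n', mat2str(savedEpisodes));
```

Because the stop rule looks at the 5-episode average while the save rule looks at single episodes, agents can be saved (episode 5 here) well before training stops.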

Download link for the related files:

https://pan.baidu.com/s/1O1O1PaloLpaOFde1PNI_1w

Extraction code: ngou


帮你学MatLab ("Help You Learn MATLAB")