MATLAB Reinforcement Learning Toolbox


Recent releases of MATLAB ship with the Reinforcement Learning Toolbox, which makes it easy to build a basic 2-D grid-world environment, set the start point, goal, and obstacles, and train various agent models on it.

Here is a simple Q-learning training example.
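For reference, the update that tabular Q-learning performs at every step is

$$Q(s,a) \leftarrow Q(s,a) + \alpha\,\bigl[r + \gamma \max_{a'} Q(s',a') - Q(s,a)\bigr]$$

where $\alpha$ is the learning rate (the LearnRate option set in the script below, here 1, so each update fully replaces the old estimate with the bootstrapped target) and $\gamma$ is the agent's DiscountFactor option. With that in mind, the full script follows.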

clear; close all; clc   % clear workspace, figures, and command window

%% Set up the grid-world environment
GW = createGridWorld(6,6);                  % 6-by-6 grid world
GW.CurrentState = '[6,1]';                  % start state
GW.TerminalStates = '[2,5]';                % goal state
GW.ObstacleStates = ["[2,3]";"[2,4]";"[3,5]";"[4,5]"];

%% Update the state transitions to account for the obstacles
updateStateTranstionForObstacles(GW)        % note: the toolbox spells this method without the second "i"

%% Set up the reward
nS = numel(GW.States);                      % number of states (36)
nA = numel(GW.Actions);                     % number of actions (4: N, S, E, W)
GW.R = -1*ones(nS,nS,nA);                   % every step costs -1
GW.R(:,state2idx(GW,GW.TerminalStates),:) = 10;   % +10 for reaching the goal
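A note on the reward array: GW.R is an nS-by-nS-by-nA array in which R(s,s',a) is the reward for moving from state s to state s' with action a, so the two assignments above give every step a cost of -1 and award +10 to any transition that ends in the goal state. A quick dimension check:

size(GW.R)   % [36 36 4]: 36 states, 4 actions (N, S, E, W)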

%% Create the environment and set the initial position
env = rlMDPEnv(GW);
plot(env)
env.ResetFcn = @() 6;   % always reset to state index 6, i.e. '[6,1]'
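ResetFcn returns the index of the initial observation rather than the state name. createGridWorld numbers the states down the first column first, so '[6,1]' is state 6; a quick sanity check of the mapping:

state2idx(GW,'[6,1]')   % returns 6, the index handed back by ResetFcn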

%% Initialize the Q-learning agent and training options
qTable = rlTable(getObservationInfo(env),getActionInfo(env));
tableRep = rlRepresentation(qTable);        % table-based Q-value representation
tableRep.Options.LearnRate = 1;

agentOpts = rlQAgentOptions;
agentOpts.EpsilonGreedyExploration.Epsilon = 0.04;   % 4% random exploration
qAgent = rlQAgent(tableRep,agentOpts);

trainOpts = rlTrainingOptions;
trainOpts.MaxStepsPerEpisode = 50;
trainOpts.MaxEpisodes = 200;
trainOpts.StopTrainingCriteria = "AverageReward";
trainOpts.StopTrainingValue = 11;           % stop once the average reward reaches 11 ...
trainOpts.ScoreAveragingWindowLength = 30;  % ... over the last 30 episodes

%% Train
rng(0)                                      % fix the random seed for reproducibility
trainingStats = train(qAgent,env,trainOpts);
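train returns a statistics object, so besides the Episode Manager window you can plot the learning curve yourself. A minimal sketch, assuming the EpisodeIndex, EpisodeReward, and AverageReward fields reported by this toolbox version:

figure
plot(trainingStats.EpisodeIndex, trainingStats.EpisodeReward, '.-')               % reward of each episode
hold on
plot(trainingStats.EpisodeIndex, trainingStats.AverageReward, 'LineWidth', 1.5)   % 30-episode moving average
xlabel('Episode'), ylabel('Reward')
legend('Episode reward', 'Average reward', 'Location', 'southeast')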

%% Show the result
plot(env)
env.Model.Viewer.ShowTrace = true;
env.Model.Viewer.clearTrace;
sim(qAgent,env)

Here is the corresponding simple SARSA training example; the environment setup is identical, only the agent construction changes.
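The algorithmic difference is the update rule: SARSA is on-policy and bootstraps from the action $a'$ it actually takes next, rather than from the greedy maximum:

$$Q(s,a) \leftarrow Q(s,a) + \alpha\,\bigl[r + \gamma\, Q(s',a') - Q(s,a)\bigr]$$

In the script, that amounts to swapping rlQAgentOptions/rlQAgent for rlSARSAAgentOptions/rlSARSAAgent.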

clear; close all; clc   % clear workspace, figures, and command window

%% Set up the grid-world environment (same as above)
GW = createGridWorld(6,6);
GW.CurrentState = '[6,1]';
GW.TerminalStates = '[2,5]';
GW.ObstacleStates = ["[2,3]";"[2,4]";"[3,5]";"[4,5]"];

%% Update the state transitions to account for the obstacles
updateStateTranstionForObstacles(GW)

%% Set up the reward
nS = numel(GW.States);
nA = numel(GW.Actions);
GW.R = -1*ones(nS,nS,nA);
GW.R(:,state2idx(GW,GW.TerminalStates),:) = 10;

%% Create the environment and set the initial position
env = rlMDPEnv(GW);
plot(env)
env.ResetFcn = @() 6;

%% Initialize the SARSA agent and training options
rng(0)
qTable = rlTable(getObservationInfo(env),getActionInfo(env));
tableRep = rlRepresentation(qTable);
tableRep.Options.LearnRate = 1;

agentOpts = rlSARSAAgentOptions;
agentOpts.EpsilonGreedyExploration.Epsilon = 0.04;
sarsaAgent = rlSARSAAgent(tableRep,agentOpts);

trainOpts = rlTrainingOptions;
trainOpts.MaxStepsPerEpisode = 50;
trainOpts.MaxEpisodes = 200;
trainOpts.StopTrainingCriteria = "AverageReward";
trainOpts.StopTrainingValue = 11;
trainOpts.ScoreAveragingWindowLength = 30;

%% Train
trainingStats = train(sarsaAgent,env,trainOpts);

%% Show the result
plot(env)
env.Model.Viewer.ShowTrace = true;
env.Model.Viewer.clearTrace;
sim(sarsaAgent,env)
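Since each script starts by clearing the workspace, the two training runs are not directly comparable in one session. A minimal sketch of one way to compare them, assuming you add save('qlearning_stats.mat','trainingStats') at the end of the Q-learning script and save('sarsa_stats.mat','trainingStats') at the end of the SARSA script (file names are illustrative):

% load the saved statistics from both runs
q = load('qlearning_stats.mat');
s = load('sarsa_stats.mat');
figure, hold on
plot(q.trainingStats.EpisodeIndex, q.trainingStats.AverageReward)
plot(s.trainingStats.EpisodeIndex, s.trainingStats.AverageReward)
xlabel('Episode'), ylabel('Average reward (30-episode window)')
legend('Q-learning', 'SARSA', 'Location', 'southeast')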

This article was originally shared on the WeChat public account 帮你学MatLab (MatLab_helper).