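% 任务:尽量使用 SARSA 或 Q-learning 算法,基本上就是实现示例中的内容(MATLAB 实现)。
% Task: use SARSA or Q-learning where possible, essentially reproducing the worked example, in MATLAB.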
clc;
clear;   % 'clear' resets the workspace; 'clear all' would also purge compiled functions (slow, discouraged in scripts)
% --- define the environment
% R(s,a) is the immediate reward for taking action a in state s. The learning
% loop moves to state = a, so the action index is also the next state.
% A -inf entry would mark a forbidden move; here every entry is finite, so every
% transition is allowed. The matrix is antisymmetric: R == -R.'
R = [  0    0    0   28    0    0
       0    0  -25    0    0    0
       0   25    0    0    0   26
     -28    0    0    0  -53    0
       0    0    0   53    0   26
       0    0  -26    0  -26    0 ];
% --- Q-learning hyper-parameters and bookkeeping
gamma = 0.9;            % discount factor
q     = zeros(size(R)); % Q matrix: rows = current state, columns = chosen action (= next state)
q1    = inf(size(R));   % last snapshot of q, compared with q to detect convergence
count = 0;              % consecutive steps in which q barely changed
% --- figure canvas
% (the obstacle outlines of the original grid-world example stay disabled)
axis([0,10,0,6]);
hold on;
% --- starting state: uniformly random among the 6 states
perm  = randperm(6);
state = perm(1);
% --- Q-learning main loop: each iteration is one environment step.
% Reads R, q, q1, gamma, count and state from the setup section; leaves the
% learned table in q and the index of the last step in episode.
nStates    = size(R,1);
alpha      = 1;      % learning rate; alpha = 1 reduces to q(s,a) = r + gamma*max q(s',:)
epsilon    = 0.1;    % probability of taking a random valid action (exploration)
goalReward = 100;    % a transition with this reward ends an episode -> restart at a random state
tol        = 1e-4;   % convergence tolerance on sum(|q - q1|)
patience   = 500;    % required number of consecutive small-change steps before stopping
minQSum    = 190;    % stop only once sum(q(:)) exceeds this; value inherited from the template example - TODO confirm for this R
validMask  = isfinite(R);   % allowed moves; R == -inf marks a forbidden one (R>=0 would drop negative-reward moves)
tic
for episode = 0:50000
    % --- choose an action: epsilon-greedy over the valid actions of this state
    actions = find(validMask(state,:));
    if isempty(actions)
        state = randi(nStates);   % dead end: restart at a random state
        continue
    end
    qRow = q(state,actions);
    if rand < epsilon || all(qRow == 0)
        % explore; also used while this row is still untrained
        x1 = actions(randi(numel(actions)));
    else
        % exploit; ties between equally good actions are broken at random
        best = actions(qRow == max(qRow));
        x1 = best(randi(numel(best)));
    end
    % --- update q matrix, bootstrapping from the best valid action in the next state
    nextQ = q(x1, validMask(x1,:));
    if isempty(nextQ)
        nextQ = 0;
    end
    target = R(state,x1) + gamma*max(nextQ);
    q(state,x1) = q(state,x1) + alpha*(target - q(state,x1));
    % --- convergence: q moved less than tol from its snapshot for more than
    %     'patience' consecutive steps (snapshot is refreshed on every larger change)
    if sum(abs(q1(:) - q(:))) < tol && sum(q(:)) > minQSum
        if count > patience
            fprintf('Converged after %d steps.\n', episode);
            break
        end
        count = count + 1;
    else
        q1 = q;
        count = 0;
    end
    % --- move on; after reaching a goal start again from a random state
    if R(state,x1) == goalReward
        state = randi(nStates);
    else
        state = x1;
    end
end
toc
% --- normalisation: rescale q so that its largest entry equals 100
% (q is left untouched when no entry is positive)
qPeak = max(q(:));
if qPeak > 0
    q = (100*q)/qPeak;
end
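% 备注:目前只有一个 bug 很多的程序,一直没能改好。
% Note: the only available version of this program was full of bugs and could not be fixed as-is.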