% Training script: runs N_obs observation episodes followed by N_train
% training episodes of a cart-pole controlled by a Q-network agent, logging
% every trajectory and live-plotting the 10-episode average control time.
%
% NOTE(review): QNet_eval, QNet_target, FcTable, opts, Rmemo, Memopointer,
% S_memo, N_renew, N_gap, nBatch, X_threshold and Theta_threshold are not
% defined in this script; presumably they are created by the two
% initializer scripts below -- confirm.
clear all;
run('CartPoleInitializer');
run('AgentInitializer');

% Training schedule: observation episodes, then training episodes.
N_obs=300;
N_train=3000;
N_total=N_obs+N_train;
T_episode=60; % maximum duration of one episode

% Episode stepping.
T_step=0.1;
n_step=1; % global step counter across all episodes
% Number of control steps per episode. Steps are counted as integers
% instead of accumulating T_step in floating point: repeatedly adding 0.1
% drifts away from the exact multiple, so a "T1<=T_episode" test can let
% one extra step through and overflow the preallocated trajectory buffer.
N_maxstep=round(T_episode/T_step);

% Record initialisation.
TimeRecord=zeros(1,N_total);        % control time achieved in each episode
AveTimeRecord=zeros(2,N_total/10);  % [episode index; mean time of last 10]
ATRpointer=1;

% Live plot initialisation. The 'EraseMode' property was removed from
% MATLAB graphics objects in R2014b, so it is no longer passed here;
% drawnow alone refreshes the line.
Plotset=zeros(2,1);
p = plot(Plotset(1,:),Plotset(2,:),'MarkerSize',5);
axis([0 N_total 0 T_episode]);

for Ns=1:N_total
    CPstate=CartPoleReset(); % 4x1 state: x, dotx, theta, dottheta
    T1=0;                    % elapsed time within this episode
    % Per-episode trajectory log: one column per step holding
    % [time; state (4 rows); control force].
    TrackPointer=1;
    TrackRecord(Ns).Track=zeros(6,N_maxstep);

    while TrackPointer<=N_maxstep
        % Select an action with the tcegreedy policy.
        [act,Qnow]=tcegreedy(Ns,CPstate,QNet_eval);
        Fc=FcTable(act);

        % Integrate the dynamics over one step with ode45; the force is
        % appended to the state vector handed to CartPole_Eqs.
        OdeInput=[CPstate;Fc];
        [t,y]=ode45(@CartPole_Eqs,[0,T_step],OdeInput,opts);
        Newstate=y(end,1:4)'; % state at the end of the step, as a column

        % Store the transition in replay memory and advance its pointer.
        Rmemo(:,Memopointer)=[CPstate;act;Newstate];
        Memopointer=PointerMove(Memopointer,S_memo);

        % Log the step; T1 here is still the time at the start of the step.
        TrackRecord(Ns).Track(:,TrackPointer)=[T1;CPstate;Fc];

        % Advance time and state. T1 is derived from the integer step
        % count so that it does not accumulate rounding error.
        T1=TrackPointer*T_step;
        TrackPointer=TrackPointer+1;
        n_step=n_step+1;
        CPstate=Newstate;

        % Every N_renew steps (once Ns>=N_obs) copy the evaluation
        % network into the target network.
        if (mod(n_step,N_renew)==0)&&(Ns>=N_obs)
            QNet_target=QNet_eval;
        end

        % Every N_gap steps (once Ns>=N_obs) train the evaluation network.
        if (mod(n_step,N_gap)==0)&&(Ns>=N_obs)
            % 1. Draw a minibatch from replay memory. Rows 1-9 mirror a
            %    replay-memory column; row 10 holds the Q target.
            Trainset=zeros(10,nBatch);
            i=1;
            while i<=nBatch
                num1=unidrnd(S_memo);
                % Row 5 (the action) is non-zero for every slot that has
                % been written, so this skips still-empty slots.
                if Rmemo(5,num1)>0
                    Trainset(1:9,i)=Rmemo(:,num1);
                    i=i+1;
                end
            end
            % 2. Compute the Q targets with the target network.
            Trainset(10,:)=CalculationQtarget(Trainset(1:9,:),QNet_target);
            % 3. Train the evaluation network on (state, action) -> Q target.
            QNet_eval=train(QNet_eval,Trainset(1:5,:),Trainset(10,:));
        end

        % End the episode early once the cart position or the pole angle
        % leaves its allowed range.
        if (abs(CPstate(1))>X_threshold)||(abs(CPstate(3))>Theta_threshold)
            break;
        end
    end
    % Control time of this episode (failure time, or the full episode).
    TimeRecord(Ns)=T1;

    % Live plot: every 10 episodes append the mean control time of the
    % last 10 episodes.
    if mod(Ns,10)==0
        Ave1=mean(TimeRecord(Ns-9:Ns));
        AveTimeRecord(:,ATRpointer)=[Ns;Ave1];
        ATRpointer=ATRpointer+1;
        TempP=[Ns;Ave1];
        Plotset=[Plotset,TempP];
        set(p,'XData',Plotset(1,:),'YData',Plotset(2,:));
        drawnow
        axis([0 N_total 0 T_episode]);
    end
end