see: https://hrl.boyuai.com/chapter/2/actor-critic%E7%AE%97%E6%B3%95
-
上一章用 代替 ,现在用时序差分残差公式代替之.
- 因为 .
- 所以训练一个 网路就行
-
原文已经写的很像回忆提纲了
-
训练一个价值网络:
- Input : 可微状态
- Output :
- Loss:
- 其中 不参与梯度计算. 代码中使用
detach()直接实现,不用双网络. - 和 DQN 一样训练数据来源于采样池.
- 训练过程和 Actor 的关系?Actor 产生了采样池,Actor 变强后采样分布会变化
- 采样数据为 .
- 注意采样的 并不影响 梯度下降,乱采样也能训练出正确的 网络
- 其中 不参与梯度计算. 代码中使用
先来看 Actor + Critic 包装器的 update
1class ActorCritic:2 # self.critic = ValueNet(state_dim, hidden_dim).to(device) # 价值网络3 ...4 def update(self, transition_dict):5 states = torch.tensor(transition_dict['states'],6 dtype=torch.float).to(self.device)7 actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(8 self.device)9 rewards = torch.tensor(transition_dict['rewards'],10 dtype=torch.float).view(-1, 1).to(self.device)11 next_states = torch.tensor(transition_dict['next_states'],12 dtype=torch.float).to(self.device)13 dones = torch.tensor(transition_dict['dones'],14 dtype=torch.float).view(-1, 1).to(self.device)15
40 collapsed lines
16 # 时序差分目标17 td_target = rewards + self.gamma * self.critic(next_states) * (1 -18 dones)19 # 时序差分差. 即:当前动作带来的额外奖励期望(在 critic 精准的情况下)20 td_delta = td_target - self.critic(states)21 log_probs = torch.log(self.actor(states).gather(1, actions))22 actor_loss = torch.mean(-log_probs * td_delta.detach())23 # 均方误差损失函数,这里直接 detach() 来实现类似 Double DQN 的效果... (不演了是吧)24 critic_loss = torch.mean(25 F.mse_loss(self.critic(states), td_target.detach()))26 self.actor_optimizer.zero_grad()27 self.critic_optimizer.zero_grad()28 actor_loss.backward() # 计算策略网络的梯度29 critic_loss.backward() # 计算价值网络的梯度30 self.actor_optimizer.step() # 更新策略网络的参数31 self.critic_optimizer.step() # 更新价值网络的参数32
33class PolicyNet(torch.nn.Module):34 def __init__(self, state_dim, hidden_dim, action_dim):35 super(PolicyNet, self).__init__()36 self.fc1 = torch.nn.Linear(state_dim, hidden_dim)37 self.fc2 = torch.nn.Linear(hidden_dim, action_dim)38
39 # 输入状态 states: [batch_size, state_dim]40 # 输出动作概率分布41 def forward(self, x):42 x = F.relu(self.fc1(x))43 return F.softmax(self.fc2(x), dim=1)44
45class ValueNet(torch.nn.Module):46 def __init__(self, state_dim, hidden_dim):47 super(ValueNet, self).__init__()48 self.fc1 = torch.nn.Linear(state_dim, hidden_dim)49 self.fc2 = torch.nn.Linear(hidden_dim, 1)50
51 # 输入状态 states: [batch_size, state_dim]52 # 输出状态价值 V(s): [batch_size, 1]53 def forward(self, x):54 x = F.relu(self.fc1(x))55 return self.fc2(x)- 效果:抖动比基于蒙特卡洛的 REINFORCE 收敛更快,且非常稳定.