Stop Just Reading the Theory! A Hands-On PyTorch Implementation of Actor-Critic for CartPole (Complete Code Included)

张开发
2026/4/19 16:55:10 · 15 min read

A hands-on guide to implementing Actor-Critic from scratch with PyTorch and conquering CartPole.

In reinforcement learning there is often a wide gap between theoretical derivation and working code. Many learners can follow the mathematical proof of the policy gradient theorem, yet feel lost when asked to implement it. This article bridges that gap: we build a complete Actor-Critic algorithm from scratch with PyTorch and verify it on the classic CartPole environment.

## 1. Environment Setup and Core Concepts

CartPole (the inverted pendulum) is one of the most classic test environments in reinforcement learning. The goal is to keep the pole upright by moving the cart left or right. This seemingly simple task captures the core challenge of RL: making decisions in a continuous state space and improving the policy from nothing but a scalar reward signal.

Preparing the key components:

```python
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import matplotlib.pyplot as plt

env = gym.make("CartPole-v1")
state_dim = env.observation_space.shape[0]   # 4 state variables
action_dim = env.action_space.n              # 2 discrete actions (left / right)
```

The Actor-Critic architecture combines the strengths of policy gradients (the Actor) and value functions (the Critic):

- Actor (policy network): chooses an action given the current state.
- Critic (value network): estimates the value of the current state.

Tip: in CartPole-v1 the agent earns a reward of 1 for every step the pole stays upright, with a cap of 500 steps per episode. Compared with v0 (capped at 200 steps), v1 requires the policy to balance the pole much longer, so it is a more demanding test.

## 2. Network Architecture Design

### 2.1 Actor Network

The Actor outputs a probability distribution over actions. For a discrete action space like CartPole's we use a softmax output:

```python
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_size=128):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, action_dim)

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        x = torch.softmax(self.fc3(x), dim=-1)
        return x
```

### 2.2 Critic Network

The Critic estimates the value of a state, which guides the direction of the Actor's updates:

```python
class Critic(nn.Module):
    def __init__(self, state_dim, hidden_size=128):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        value = self.fc3(x)
        return value
```

Network design at a glance:

| Component | Input dim | Output dim | Output activation | Role |
| --- | --- | --- | --- | --- |
| Actor | state dim | action dim | Softmax | produces action probabilities |
| Critic | state dim | 1 | none | estimates the state value |

## 3. Training Pipeline

### 3.1 Data Collection and Preprocessing

We update online, collecting trajectory data as we interact with the environment:

```python
def collect_trajectory(env, actor, max_steps=1000):
    # Classic gym API (gym < 0.26): reset() returns the observation,
    # step() returns (next_state, reward, done, info).
    states, actions, rewards, next_states, dones = [], [], [], [], []
    state = env.reset()
    for _ in range(max_steps):
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        action_probs = actor(state_tensor)
        action = torch.multinomial(action_probs, 1).item()  # sample from the policy
        next_state, reward, done, _ = env.step(action)
        states.append(state)
        actions.append(action)
        rewards.append(reward)
        next_states.append(next_state)
        dones.append(done)
        state = next_state
        if done:
            break
    return states, actions, rewards, next_states, dones
```

### 3.2 Core Training Loop

A complete Actor-Critic update has three key steps: compute the TD error, update the Critic, and update the Actor.

```python
def train(env, actor, critic, actor_optimizer, critic_optimizer, gamma=0.99, epochs=1000):
    reward_history = []
    for epoch in range(epochs):
        # Collect one episode of data
        states, actions, rewards, next_states, dones = collect_trajectory(env, actor)

        # Convert to tensors
        states = torch.FloatTensor(np.array(states))
        actions = torch.LongTensor(actions).unsqueeze(1)
        rewards = torch.FloatTensor(rewards).unsqueeze(1)
        next_states = torch.FloatTensor(np.array(next_states))
        dones = torch.FloatTensor(dones).unsqueeze(1)

        # Compute TD targets
        with torch.no_grad():
            next_values = critic(next_states)
            td_targets = rewards + gamma * next_values * (1 - dones)

        # Update the Critic
        values = critic(states)
        critic_loss = nn.MSELoss()(values, td_targets)
        critic_optimizer.zero_grad()
        critic_loss.backward()
        critic_optimizer.step()

        # Update the Actor
        action_probs = actor(states)
        selected_probs = action_probs.gather(1, actions)
        advantages = td_targets - values.detach()
        actor_loss = -(torch.log(selected_probs) * advantages).mean()
        actor_optimizer.zero_grad()
        actor_loss.backward()
        actor_optimizer.step()

        # Log progress
        total_reward = rewards.sum().item()
        reward_history.append(total_reward)
        if epoch % 50 == 0:
            print(f"Epoch {epoch}, Reward: {total_reward}")
    return reward_history
```

Note: the advantage is computed from the Critic's value estimates, but `detach()` cuts the gradient flow there, so the Actor update does not interfere with the Critic's training.
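To make the pipeline above runnable end to end, here is a minimal sketch that wires the pieces together. The hidden size, learning rates, and episode count are the values recommended in section 4 below; the choice of the Adam optimizer is my assumption, since the article does not name a specific optimizer.

```python
# Minimal entry-point sketch (assumption: Adam as the optimizer).
if __name__ == "__main__":
    env = gym.make("CartPole-v1")
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    actor = Actor(state_dim, action_dim, hidden_size=64)
    critic = Critic(state_dim, hidden_size=64)
    actor_optimizer = optim.Adam(actor.parameters(), lr=0.001)
    critic_optimizer = optim.Adam(critic.parameters(), lr=0.005)

    reward_history = train(env, actor, critic, actor_optimizer,
                           critic_optimizer, gamma=0.99, epochs=500)
    plot_rewards(reward_history)  # plotting helper defined in section 4.2
```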
## 4. Hyperparameter Tuning and Training Tips

### 4.1 Key Hyperparameters

After repeated experiments, the following combination works well on CartPole:

```python
hyperparams = {
    "hidden_size": 64,    # hidden layer width
    "gamma": 0.99,        # discount factor
    "actor_lr": 0.001,    # Actor learning rate
    "critic_lr": 0.005,   # Critic learning rate (usually set larger)
    "epochs": 500         # number of training episodes
}
```

Why these learning rates:

- The Critic should converge faster so that it provides accurate value estimates.
- The Actor's step size should be smaller so that the policy improves steadily.

### 4.2 Monitoring Training

Plot the reward curve (and optionally render the environment) to get a direct view of training progress:

```python
def plot_rewards(reward_history, window_size=50):
    # Smooth the raw episode rewards with a moving average
    moving_avg = []
    for i in range(len(reward_history) - window_size + 1):
        window = reward_history[i:i + window_size]
        moving_avg.append(sum(window) / window_size)
    plt.plot(reward_history, alpha=0.3, label="Raw")
    plt.plot(moving_avg, label=f"Moving Avg ({window_size} eps)")
    plt.xlabel("Episode")
    plt.ylabel("Total Reward")
    plt.legend()
    plt.show()
```

### 4.3 Troubleshooting Common Problems

| Problem | Possible fixes |
| --- | --- |
| Reward does not grow | Check whether the network is expressive enough; try a larger discount factor gamma; adjust the learning-rate combination |
| Policy converges too early | Add an entropy regularization term to encourage exploration (see the example below) |
| Biased Critic estimates | Use a target network to stabilize training; an experience replay buffer can help, but mind the on-policy constraint |

```python
# Entropy regularization example: subtracting beta * entropy from the loss
# rewards more exploratory (higher-entropy) policies.
def actor_loss_with_entropy(probs, actions, advantages, beta=0.01):
    selected_probs = probs.gather(1, actions)
    policy_loss = -(torch.log(selected_probs) * advantages).mean()
    entropy = -(probs * torch.log(probs + 1e-8)).sum(1).mean()  # 1e-8 avoids log(0)
    return policy_loss - beta * entropy
```

## 5. Advanced Optimizations and Extensions

### 5.1 Target Network

Adding a target network for the Critic noticeably improves training stability:

```python
class ActorCritic:
    def __init__(self, state_dim, action_dim):
        self.actor = Actor(state_dim, action_dim)
        self.critic = Critic(state_dim)
        self.target_critic = Critic(state_dim)
        self.update_target(tau=1.0)  # fully synchronize at initialization

    def update_target(self, tau=0.01):
        """Soft-update the target network toward the online Critic."""
        for target_param, param in zip(self.target_critic.parameters(),
                                       self.critic.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
```

### 5.2 Multi-Step TD Learning

One-step TD updates can introduce noticeable bias; an n-step TD target improves on this:

```python
def compute_nstep_td(rewards, next_values, dones, gamma, n_step=5):
    # rewards and dones are the Python lists from collect_trajectory;
    # next_values is the Critic's output for the next states.
    n_step_returns = []
    R = 0
    for i in reversed(range(len(rewards))):
        # Accumulate the discounted return backwards...
        R = rewards[i] + gamma * R * (1 - dones[i])
        # ...and trim the contribution lying more than n_step steps ahead.
        if i + n_step < len(rewards):
            R = R - (gamma ** n_step) * rewards[i + n_step]
        n_step_returns.insert(0, R)
    returns = torch.FloatTensor(n_step_returns).unsqueeze(1)
    dones_t = torch.FloatTensor(dones).unsqueeze(1)
    with torch.no_grad():
        next_n_values = next_values[-len(returns):]
        td_targets = returns + (gamma ** n_step) * next_n_values * (1 - dones_t[-len(returns):])
    return td_targets
```

### 5.3 Parallel Environment Sampling

The ultimate way to speed up data collection is to sample from multiple environments in parallel:

```python
from multiprocessing import Process, Queue

def worker(env_name, actor, queue, num_episodes):
    # Each worker owns its own environment instance
    env = gym.make(env_name)
    for _ in range(num_episodes):
        states, actions, rewards, next_states, dones = collect_trajectory(env, actor)
        queue.put((states, actions, rewards, next_states, dones))
    queue.put(None)  # termination signal

def parallel_collect(env_name, actor, num_workers=4, episodes_per_worker=2):
    queue = Queue()
    workers = []
    for _ in range(num_workers):
        p = Process(target=worker, args=(env_name, actor, queue, episodes_per_worker))
        p.start()
        workers.append(p)
    trajectories = []
    finished_workers = 0
    while finished_workers < num_workers:
        data = queue.get()
        if data is None:
            finished_workers += 1
        else:
            trajectories.append(data)
    for p in workers:
        p.join()
    return trajectories
```

In real projects I have found that parallel sampling improves training efficiency significantly, especially when each environment interaction is expensive. One practical trick is to give each worker a slightly different exploration rate, which increases the diversity of the collected samples.
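As a rough sketch of that last trick (the helper below and its `temperature` parameter are hypothetical additions, not part of the code above), each worker could sample actions through its own temperature, so that some workers explore more aggressively than others:

```python
# Hypothetical sketch: per-worker sampling temperature.
# temperature > 1 flattens the action distribution (more exploration),
# temperature < 1 sharpens it (more exploitation).
def sample_action_with_temperature(actor, state, temperature=1.0):
    state_tensor = torch.FloatTensor(state).unsqueeze(0)
    probs = actor(state_tensor)
    scaled = probs ** (1.0 / temperature)
    scaled = scaled / scaled.sum(dim=-1, keepdim=True)  # renormalize
    return torch.multinomial(scaled, 1).item()

# Example: spread slightly different temperatures across four workers
worker_temperatures = [0.9, 1.0, 1.1, 1.2]
```

Each worker would then use its assigned temperature inside its own copy of the trajectory-collection loop.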
