Xiaohei's Blog
headpicBlur image

前言#

第 1 篇我们把“回合训练骨架”固定了,并用 Q-Learning/Sarsa 验证:很多算法的工程结构可以复用。

这篇开始进入深度强化学习:DQN(Deep Q-Network)。它的本质很简单:

  • 用神经网络近似 Q(s,a)Q(s,a),替代原来的 Q-table
  • 为了让训练稳定、样本利用率更高,引入两件关键装备:
    • 经验回放(Replay Buffer)
    • 目标网络(Target Network)

我下面会按我自己做实验时的“演化路线”把它们串起来:先把最朴素的 DQN 跑通,再逐个替换模块,最后得到一套更稳、更省心的版本(Double / Dueling / Noisy / PER)。你会发现这些改进并没有把 DQN 变得面目全非,它们大多只是在某个关键环节多加了一块垫片:要么让 target 更可信,要么让探索更自然,要么让 replay 更“记重点”。

我自己真正开始“理解 DQN”是从一次失败开始的:代码能跑、loss 也在降,但 reward 纹丝不动;甚至有时候 reward 还会先上去一小段,然后突然崩得一干二净。后来我才意识到,DQN 的难点从来不在“会不会写反传”,而在“你的训练信号到底稳不稳”。如果 target 网络没更新好、replay 的采样没有打散相关性、或者奖励尺度和学习率不匹配,你得到的梯度就像在噪声里摸黑,越走越偏。

所以这篇我会尽量用工程语言把它讲清楚:DQN 为什么要 Replay、为什么要 Target、update 的 target 到底在干什么,以及你可以按什么顺序迭代,把一个“能跑的 DQN”逐步打磨成一个“训练稳定的 DQN”。

DQN:相对 Q-Learning 改了什么?#

如果把 Q-Learning 看作“边走边把 Q 表格填完整”,那 DQN 就是“我不填表了,我训练一个函数去拟合这张表”。为了让这个函数(神经网络)训练得稳定,我一般会把 DQN 的变化拆成三件事来记:

  1. 用网络替代表Q(s,a;θ)Q(s,a;\theta),输入是 state,输出是每个动作的 Q 值。
  2. Replay Buffer:把交互数据存起来,更新时随机采样一批(batch),提升样本效率并打破相关性。
  3. 策略网络 + 目标网络:用 θ\theta 在线更新策略网络,用 θ\theta^- 的目标网络去计算 target,定期拷贝参数来稳住训练。

一个可复用的 DQN Agent 接口#

和第 1 篇一致:

  • sample(state):训练用(通常 epsilon-greedy)
  • predict(state):测试用(argmax)
  • update():从 Replay Buffer 采样 batch 更新网络

注意:从 DQN 开始,update() 往往不再接受单条 transition,而是:

  • push(transition) 到 replay
  • 再在合适时机(buffer 足够、间隔到达)调用 update()

我个人很推荐把 “push 交互数据” 和 “update 如果条件满足就更新” 分开写。因为你后面无论是加 PER(采样变了)、还是加 Noisy(探索变了)、还是加 Double(target 算法变了),这一层的训练主循环都不需要大改,改动会被限制在 buffer 或 update 的内部。

Replay Buffer:push + sample 就够用#

我一般把经验回放(Replay Buffer)的实现压缩成两个方法:

  • push:按顺序存 transition,满了就挤掉最旧的
  • sample:随机采样出一个 batch

最小结构通常是这样的:

import random
from collections import deque


class ReplayBuffer:
	def __init__(self, capacity: int):
		self.buffer = deque(maxlen=capacity)

	def push(self, transition):
		self.buffer.append(transition)

	def sample(self, batch_size: int):
		batch = random.sample(self.buffer, batch_size)
		# 真实代码会在这里做 unzip + tensor 化
		return batch

	def __len__(self):
		return len(self.buffer)
python

DQN 的 update:损失怎么来的?#

我第一次实现 DQN 时,最容易卡住的点其实不是反传,而是“target 到底应该怎么算”。把它想清楚后,整个 update 就很机械:拿一批数据算出 target,再让网络的输出去贴近这个 target。

在最基础的 DQN 里,损失就是“期望值 yiy_i”和“实际值 Q(si,ai;θ)Q(s_i,a_i;\theta)”的均方差。

target(期望值)#

基础 DQN 常用:

yi=ri+γmaxaQ(si,a;θ)y_i = r_i + \gamma \max_{a'} Q(s'_i, a'; \theta^-)

并且要处理终止状态:如果 terminated==True,没有下一个状态,就直接 yi=riy_i=r_i

loss(均方差)#

L(θ)=1Ni(yiQ(si,ai;θ))2\mathcal{L}(\theta)=\frac{1}{N}\sum_i (y_i - Q(s_i,a_i;\theta))^2

然后照常:定义 optimizer,loss.backward()optimizer.step()

Dueling DQN:把 Q 网络拆成 V + A#

有些环境里(尤其状态复杂、但动作影响没那么明显的阶段),我会发现“学哪个动作更好”这件事很难,反而“这个状态整体值不值得继续待”更重要。Dueling 的直觉就是把这两件事分开学:先学状态价值 V(s)V(s),再学动作相对优势 A(s,a)A(s,a)

  • Value:估计状态价值 V(s)V(s)
  • Advantage:估计每个动作相对优势 A(s,a)A(s,a)

然后组合得到 Q(s,a)Q(s,a)

工程上你只需要关注两点:

  1. 网络 forward 输出从“直接输出 Q”变成“输出 V 和 A,再合成 Q”
  2. 其它(Replay、target、update)基本不变

Double DQN:缓解过估计#

基础 DQN 的一个老毛病是过估计:因为我们用同一个网络(或同一个估计过程)既“选最大动作”,又“评估这个最大动作的值”,这很容易把噪声也当成真相。

Double DQN 的关键改动是:

  • 用策略网络选动作(argmax)
  • 用目标网络评估该动作的价值

一句话的工程翻译:只改 target 的计算方式,其它不动。

Noisy DQN:用可学习噪声做探索#

Noisy DQN 的重点在“模型定义”:

  • Linear 层里引入 mu/sigma 参数
  • 每次 forward 都注入噪声(训练时),并能 reset

它的好处是:在很多任务里,比手动调 epsilon 更省心。

你可以把它理解为:

不再由外部策略(epsilon-greedy)给动作加随机性,而是让网络自己学会在哪些状态该更“抖”。

PER-DQN:优先经验回放(SumTree)#

PER 的工程要点我通常拆成两大块:

  1. SumTree:用 O(logn)O(\log n) 管理样本优先级,并按优先级采样
  2. Importance Sampling:用权重修正“非均匀采样”带来的偏差

实现上可以拆成三个类:

  • SumTree:维护二叉树和优先级更新
  • ReplayTree:基于 SumTree 的 replay buffer(push、sample、batch_update)
  • PERDQNAgent:update 后把 TD-error 回写到 replay 里更新优先度

如果你想先做一个“最小可用”的 PER,我的建议是:别一上来就写得太花。把 SumTree 写对、把 batch_update 的优先级回写逻辑写对,就能看到明显收益。之后再去补重要性采样权重、再去做 beta 的退火(anneal),会更顺畅。

小结#

这一篇把 DQN 家族的“程序结构”串了起来:

  • DQN 的稳定性来自 Replay + Target
  • Double/Dueling 改的是 target 或结构
  • Noisy 改的是探索的实现方式
  • PER 改的是 replay 的采样分布

下一篇我会切到策略梯度与 Actor-Critic:REINFORCE、PPO、A2C —— 你会发现它们的核心接口仍然可以复用,只是 update() 里优化的对象从 QQ 变成了 πθ\pi_\theta

我常用的 DQN 调参/调试清单(建议按顺序来)#

如果你发现 DQN “不收敛 / 忽好忽坏 / Q 值爆炸”,我一般按下面顺序排查,基本不会走弯路:

  1. 先看 reward 尺度,再定学习率:奖励如果在 [0,1][0,1],学习率可以大胆一点;奖励如果动辄几十上百,先考虑 reward clipping 或把学习率降两档。
  2. 把 Q 值范围打出来:每隔 N 个 episode 打印一次 q_mean/q_max(以 batch 为单位)。如果 Q 值从几十飙到几万,通常是 target 计算/终止状态处理/学习率出问题。
  3. Replay 先保证“足够大 + 随机采样”:buffer 太小就更新,训练非常不稳;我一般会设一个 min_buffer_size,不够就只 push 不 update。
  4. Target network 更新频率别太激进:更新太频繁等于没用 target,更新太慢又会学得很慢。一个常用起点:每 500~2000 个 env step 同步一次(具体随环境而变)。
  5. epsilon 衰减别太快:你看到 reward 前期抖动、很快停滞,优先怀疑探索不足;宁愿衰减慢一点,也不要“自信地随机”。
  6. batch_size、gamma、update 频率要成套:batch 太小梯度噪声大;gamma 太大又没配足够长的 episode/steps,容易学不到。
  7. 先跑基础版再加花活:Double/Dueling/Noisy/PER 都是锦上添花。基础 DQN 如果你连 “target 正确 + done 处理正确 + replay 正常” 都没确认,上改进往往只会更乱。
强化学习算法程序实践(2):DQN 及其改进
https://xiaohei94.github.io/blog/rl-algorithm-2
Author 红鼻子小黑
Published at May 1, 2025
Comment seems to stuck. Try to refresh?✨