当前位置: 首页 > news >正文

强化学习第五课 —— TRPO 深度剖析:在黎曼流形上寻找最优步长的数学艺术

在深度强化学习的发展史上,TRPO (Trust Region Policy Optimization)占据着承前启后的核心地位。它是连接早期 REINFORCE(朴素策略梯度)与现代 PPO(近端策略优化)的桥梁。

很多人认为 TRPO 仅仅是一个“带约束的优化算法”,这严重低估了它的理论深度。TRPO 的本质,是一次从**欧氏空间(Euclidean Space)黎曼流形(Riemannian Manifold)**的思维跨越。它解决了一个根本性的问题:在参数空间Θ\ThetaΘ上移动多远,才等同于在策略空间Π\PiΠ上移动了安全的距离?

本文将从单调提升理论出发,通过拉格朗日对偶性、泰勒级数展开、自然梯度法以及共轭梯度下降,完整推导 TRPO 的每一个数学细节。


第一章:理论基石——性能差异与单调性保证

一切优化的前提是“不退步”。在强化学习中,策略更新往往牵一发而动全身,我们需要一个数学保证:新策略J(πnew)≥J(πold)J(\pi_{new}) \ge J(\pi_{old})J(πnew)J(πold)

1.1 性能差异引理 (The Performance Difference Lemma)

Kakade & Langford (2002) 提出了一个恒等式,量化了两个策略表现的差距。
J(π)J(\pi)J(π)为策略π\piπ的期望累积折扣回报,对于任意两个策略π\piππ~\tilde{\pi}π~,有:

J(π~)=J(π)+Eτ∼π~[∑t=0∞γtAπ(st,at)] J(\tilde{\pi}) = J(\pi) + \mathbb{E}_{\tau \sim \tilde{\pi}} \left[ \sum_{t=0}^{\infty} \gamma^t A_\pi(s_t, a_t) \right]J(π~)=J(π)+Eτπ~[t=0γtAπ(st,at)]

  • Aπ(s,a)=Qπ(s,a)−Vπ(s)A_\pi(s, a) = Q_\pi(s, a) - V_\pi(s)Aπ(s,a)=Qπ(s,a)Vπ(s)是旧策略的优势函数。
  • 核心洞察:如果新策略π~\tilde{\pi}π~在每一个状态sss都能选择出Aπ(s,a)>0A_\pi(s, a) > 0Aπ(s,a)>0的动作,那么J(π~)J(\tilde{\pi})J(π~)必然大于J(π)J(\pi)J(π)

1.2 替代目标函数 (Surrogate Objective)

上述公式中,期望是基于新策略π~\tilde{\pi}π~的轨迹τ\tauτ计算的,这在更新前是未知的。我们利用重要性采样,将状态分布近似为旧策略的分布ρπ\rho_\piρπ

Lπ(π~)=J(π)+∑sρπ(s)∑aπ~(a∣s)Aπ(s,a) L_\pi(\tilde{\pi}) = J(\pi) + \sum_{s} \rho_\pi(s) \sum_{a} \tilde{\pi}(a|s) A_\pi(s, a)Lπ(π~)=J(π)+sρπ(s)aπ~(as)Aπ(s,a)

在 TRPO 中,我们通常优化Lπ(π~)L_\pi(\tilde{\pi})Lπ(π~)的等价形式(忽略常数项J(π)J(\pi)J(π)):

max⁡θEs∼ρθold,a∼πθold[πθ(a∣s)πθold(a∣s)Aθold(s,a)] \max_{\theta} \mathbb{E}_{s \sim \rho_{\theta_{old}}, a \sim \pi_{\theta_{old}}} \left[ \frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)} A_{\theta_{old}}(s,a) \right]θmaxEsρθold,aπθold[πθold(as)πθ(as)Aθold(s,a)]

1.3 误差边界与下界最大化 (MM Algorithm)

由于状态分布的近似ρπ~≈ρπ\rho_{\tilde{\pi}} \approx \rho_\piρπ~ρπ引入了误差,Schulman 证明了如下不等式:

J(π~)≥Lπ(π~)−C⋅DKLmax⁡(π,π~) J(\tilde{\pi}) \ge L_\pi(\tilde{\pi}) - C \cdot D_{KL}^{\max}(\pi, \tilde{\pi})J(π~)Lπ(π~)CDKLmax(π,π~)

其中C=4ϵγ(1−γ)2C = \frac{4\epsilon \gamma}{(1-\gamma)^2}C=(1γ)24ϵγ是常数,DKLmax⁡D_{KL}^{\max}DKLmax是状态空间上的最大 KL 散度。
这构成了一个Minorization-Maximization (MM)算法的基础:

  • Mi(π)=Lπi(π)−C⋅DKLmax⁡(πi,π)M_i(\pi) = L_{\pi_i}(\pi) - C \cdot D_{KL}^{\max}(\pi_i, \pi)Mi(π)=Lπi(π)CDKLmax(πi,π)J(π)J(\pi)J(π)下界函数
  • 最大化这个下界,就能保证真实目标J(π)J(\pi)J(π)的单调提升。

第二章:从理论到实践——信赖域约束的构建

理论上的惩罚系数CCC通常过大,导致步长极小,几乎无法训练。TRPO 将上述无约束的惩罚问题(Lagrangian form)转化为带约束的优化问题(Constrained form)。

2.1 优化问题的形式化

我们需要在满足 KL 散度约束的前提下,最大化替代目标:

max⁡θL(θ)=E[πθ(a∣s)πθold(a∣s)Aθold(s,a)]subject toDˉKL(πθold,πθ)≤δ \begin{aligned} \max_{\theta} \quad & L(\theta) = \mathbb{E} \left[ \frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)} A_{\theta_{old}}(s,a) \right] \\ \text{subject to} \quad & \bar{D}_{KL}(\pi_{\theta_{old}}, \pi_\theta) \le \delta \end{aligned}θmaxsubject toL(θ)=E[πθold(as)πθ(as)Aθold(s,a)]DˉKL(πθold,πθ)δ

这里,DˉKL\bar{D}_{KL}DˉKL是所有状态下的平均 KL 散度,δ\deltaδ是信赖域半径(Trust Region Radius)。

2.2 为什么是 KL 散度?

这是一个极其关键的数学选择。
如果我们使用欧氏距离∥θ−θold∥2≤δ\|\theta - \theta_{old}\|^2 \le \deltaθθold2δ,会发生什么?

  • 参数空间与概率分布空间是不等价的。有些参数变化很小,却导致概率分布剧变;有些参数变化很大,概率分布却几乎不变。
  • KL 散度衡量的是分布之间的统计距离。它定义了一个黎曼流形,使得我们的步长具有协变性(Covariant)——无论我们将参数如何缩放或重参数化,只要分布不变,KL 散度就不变,更新轨迹也就不变。

第三章:数值求解——泰勒展开与自然梯度

上述约束优化问题是非线性的,难以直接求解。我们使用泰勒级数对其进行局部近似。

3.1 一阶与二阶泰勒近似

设更新量Δθ=θ−θold\Delta \theta = \theta - \theta_{old}Δθ=θθold

  1. 目标函数(一阶展开)
    L(θ)≈L(θold)+∇θL(θold)TΔθ L(\theta) \approx L(\theta_{old}) + \nabla_\theta L(\theta_{old})^T \Delta \thetaL(θ)L(θold)+θL(θold)TΔθ
    其中∇θL(θold)\nabla_\theta L(\theta_{old})θL(θold)即为常用的策略梯度,记为ggg

  2. 约束条件(二阶展开)
    由于DKL(θ,θ)=0D_{KL}(\theta, \theta) = 0DKL(θ,θ)=0,且 KL 散度在两分布相等处取得极小值,因此一阶导数为 0。我们展开到二阶:
    DˉKL(θold,θ)≈12ΔθTFΔθ \bar{D}_{KL}(\theta_{old}, \theta) \approx \frac{1}{2} \Delta \theta^T \mathbf{F} \Delta \thetaDˉKL(θold,θ)21ΔθTFΔθ
    其中F\mathbf{F}F费雪信息矩阵(Fisher Information Matrix, FIM),也就是 KL 散度的 Hessian 矩阵:
    F=Es,a[∇θlog⁡πθ(a∣s)∇θlog⁡πθ(a∣s)T] \mathbf{F} = \mathbb{E}_{s, a} \left[ \nabla_\theta \log \pi_\theta(a|s) \nabla_\theta \log \pi_\theta(a|s)^T \right]F=Es,a[θlogπθ(as)θlogπθ(as)T]

3.2 近似问题的解析解

现在问题变成了标准的二次规划(Quadratic Programming):

max⁡ΔθgTΔθs.t.12ΔθTFΔθ≤δ \begin{aligned} \max_{\Delta \theta} \quad & g^T \Delta \theta \\ \text{s.t.} \quad & \frac{1}{2} \Delta \theta^T \mathbf{F} \Delta \theta \le \delta \end{aligned}Δθmaxs.t.gTΔθ21ΔθTFΔθδ

利用拉格朗日乘子法,构造拉格朗日函数:
L(Δθ,λ)=gTΔθ−λ(12ΔθTFΔθ−δ) \mathcal{L}(\Delta \theta, \lambda) = g^T \Delta \theta - \lambda \left( \frac{1}{2} \Delta \theta^T \mathbf{F} \Delta \theta - \delta \right)L(Δθ,λ)=gTΔθλ(21ΔθTFΔθδ)

Δθ\Delta \thetaΔθ求导并令其为 0:
g−λFΔθ=0 ⟹ Δθ=1λF−1g g - \lambda \mathbf{F} \Delta \theta = 0 \implies \Delta \theta = \frac{1}{\lambda} \mathbf{F}^{-1} ggλFΔθ=0Δθ=λ1F1g

这里的F−1g\mathbf{F}^{-1} gF1g就是大名鼎鼎的自然梯度(Natural Gradient)。它根据参数空间的局部曲率(F\mathbf{F}F)校正了梯度方向。

3.3 求解步长系数

我们还需要确定λ\lambdaλ(或者说步长大小)。将Δθ=1λF−1g\Delta \theta = \frac{1}{\lambda} \mathbf{F}^{-1} gΔθ=λ1F1g代入约束条件12ΔθTFΔθ=δ\frac{1}{2} \Delta \theta^T \mathbf{F} \Delta \theta = \delta21ΔθTFΔθ=δ

12(1λF−1g)TF(1λF−1g)=δ \frac{1}{2} \left( \frac{1}{\lambda} \mathbf{F}^{-1} g \right)^T \mathbf{F} \left( \frac{1}{\lambda} \mathbf{F}^{-1} g \right) = \delta21(λ1F1g)TF(λ1F1g)=δ
12λ2gTF−1FF−1g=δ \frac{1}{2\lambda^2} g^T \mathbf{F}^{-1} \mathbf{F} \mathbf{F}^{-1} g = \delta2λ21gTF1FF1g=δ
12λ2gTF−1g=δ \frac{1}{2\lambda^2} g^T \mathbf{F}^{-1} g = \delta2λ21gTF1g=δ

解得:
λ=gTF−1g2δ \lambda = \sqrt{\frac{g^T \mathbf{F}^{-1} g}{2\delta}}λ=2δgTF1g

因此,最终的更新向量为:
Δθ=2δgTF−1gF−1g \Delta \theta = \sqrt{\frac{2\delta}{g^T \mathbf{F}^{-1} g}} \mathbf{F}^{-1} gΔθ=gTF1g2δF1g


第四章:工程实现的艺术——共轭梯度与 HVP

理论推导很完美,但在深度学习中,参数量NNN可能高达数百万。
F\mathbf{F}F是一个N×NN \times NN×N的矩阵。计算并存储它需要O(N2)O(N^2)O(N2)空间,求逆F−1\mathbf{F}^{-1}F1需要O(N3)O(N^3)O(N3)时间。这在计算上是不可行的。

TRPO 使用共轭梯度法 (Conjugate Gradient, CG)来避开这个瓶颈。

4.1 将求逆转化为解方程

我们不需要显式求F−1\mathbf{F}^{-1}F1,我们只需要求向量x=F−1gx = \mathbf{F}^{-1} gx=F1g。这等价于求解线性方程组:
Fx=g \mathbf{F} x = gFx=g

由于F\mathbf{F}F是对称正定矩阵,CG 算法非常适合求解此类方程。

4.2 Hessian-Vector Product (HVP)

在 CG 迭代中,我们需要频繁计算矩阵与向量的乘积Fv\mathbf{F} vFv(其中vvv是 CG 算法中的搜索方向向量)。
Pearlmutter (1994) 提出的技巧告诉我们,计算 Hessian 与向量的乘积,不需要构建 Hessian 矩阵

回顾F\mathbf{F}F的定义,我们可以通过两次反向传播来计算Fv\mathbf{F} vFv

Fv=∇θ((∇θlog⁡πθ(a∣s)⋅v)T∇θlog⁡πθ(a∣s)) \mathbf{F} v = \nabla_\theta \left( (\nabla_\theta \log \pi_\theta(a|s) \cdot v)^T \nabla_\theta \log \pi_\theta(a|s) \right)Fv=θ((θlogπθ(as)v)Tθlogπθ(as))

但在实践中,我们通常使用 KL 散度的梯度形式更方便:
Fv=∇θ(∇θDKL(πθold∥πθ)⋅v) \mathbf{F} v = \nabla_\theta \left( \nabla_\theta D_{KL}(\pi_{\theta_{old}} \| \pi_\theta) \cdot v \right)Fv=θ(θDKL(πθoldπθ)v)

具体操作步骤:

  1. 计算 KL 散度关于θ\thetaθ的梯度(一阶导)。
  2. 计算该梯度与向量vvv的点积(标量)。
  3. 对这个标量再求一次关于θ\thetaθ的梯度(二阶导信息)。

这样,我们就在O(N)O(N)O(N)的时间复杂度内算出了Fv\mathbf{F} vFv,使得 TRPO 在大模型上变得可行。


第五章:最后一道防线——回溯线性搜索 (Backtracking Line Search)

由于我们在第三章使用了泰勒展开,计算出的Δθ\Delta \thetaΔθ只是在局部是准确的。如果步子迈得太大,泰勒近似就会失效,导致 KL 散度越界或者目标函数下降。

TRPO 引入了线性搜索机制来确保单调性。

设搜索方向为d=F−1gd = \mathbf{F}^{-1} gd=F1g,最大步长为β=2δ/(dTFd)\beta = \sqrt{2\delta / (d^T \mathbf{F} d)}β=2δ/(dTFd)
我们尝试更新:
θnew=θold+αjβd \theta_{new} = \theta_{old} + \alpha^j \beta dθnew=θold+αjβd
其中α∈(0,1)\alpha \in (0, 1)α(0,1)是衰减率(例如 0.5),jjj从 0 开始增加。

我们接受θnew\theta_{new}θnew当且仅当满足以下两个条件:

  1. 目标提升条件L(θnew)−L(θold)>0L(\theta_{new}) - L(\theta_{old}) > 0L(θnew)L(θold)>0(或者满足一定的提升比例)。
  2. 信赖域约束条件DˉKL(θold,θnew)≤δ\bar{D}_{KL}(\theta_{old}, \theta_{new}) \le \deltaDˉKL(θold,θnew)δ

通过这种“先计算最佳方向,再小心翼翼试探”的策略,TRPO 实现了极其稳定的更新。


总结:TRPO 的数学美学

TRPO 的推导过程是一场数学盛宴:

  1. 性能差异引理出发,建立了单调提升的目标。
  2. 利用KL 散度构建了黎曼流形上的信赖域约束。
  3. 通过泰勒展开将非线性约束规划转化为二次规划。
  4. 引入费雪信息矩阵得到了自然梯度解。
  5. 使用共轭梯度法HVP 技巧解决了高维矩阵求逆难题。
  6. 最后用线性搜索弥补了近似误差。

每一个环节都环环相扣,逻辑严密。虽然后来的 PPO 通过 Clip 操作极大地简化了这一过程,但 PPO 之所以有效,正是因为它在试图模拟 TRPO 所定义的那个完美的“信赖域”。理解了 TRPO,你才真正触碰到了策略梯度算法的灵魂。

http://icebutterfly214.com/news/110756/

相关文章:

  • ollama下载gpt-oss-20b模型时常见问题及解决方案
  • ScienceDecrypting 完整教程:简单几步实现CAJViewer文档格式转换
  • Dubbo默认通信框架是什么?还有其他选择吗?
  • Transformer解码策略比较:Qwen-Image采用何种采样方法?
  • 58、FreeBSD系统的高级安全特性与远程连接安全
  • 大麦网智能抢票助手:告别黄牛票的终极方案
  • 鸿蒙+Flutter混合工程化:构建、依赖管理与持续集成实战
  • 明日方舟UI定制终极指南:打造专属游戏界面美化方案
  • Syncthing-Android终极教程:简单快速的私密文件同步完全指南
  • ComfyUI自定义节点开发:接入Qwen-Image-Edit-2509编辑功能
  • 微信小程序表格组件实战:从零到精通的数据展示方案
  • C++加速ACE-Step底层计算模块:提升音频生成实时性与稳定性
  • SumatraPDF:重新定义轻量级PDF阅读器的使用体验
  • m3u8-downloader桌面版:流媒体视频下载的终极解决方案
  • 终极NS模拟器管理神器:ns-emu-tools一站式使用指南
  • HuggingFace镜像网站之外的选择:Seed-Coder-8B-Base本地部署教程
  • 如何快速掌握ColorUI选项卡组件提升界面组织效率
  • 3步解锁MTK设备调试:从入门到精通实战指南
  • AI企业级智能体远不止聊天,一张图揭秘AI如何革新软件与业务
  • Dify平台创建音乐智能体:输入歌词即可由ACE-Step谱曲
  • 企业级微服务权限系统终极指南:RuoYi-Cloud-Plus完整解析
  • HunyuanVideo-Foley GitHub镜像加速下载方法(支持国内网络)
  • 9款AI写论文工具大PK:宏智树AI凭“真数据+全流程”杀疯了
  • 先相信,后看见:普通人「逆袭」的底层操作系统
  • Nginx负载均衡部署多个ACE-Step实例:应对大规模访问需求
  • Qwen3-32B深度评测:复杂逻辑推理能力超乎想象
  • 【瑞萨RA × Zephyr评测】SPI 屏 (SSD1306) + 双路 ADC
  • 原生 JavaScript 实战:手搓一个生产级 Toast 通知组件
  • 【详解】Hydra安装Libssh模块
  • 導出知乎收藏夾