<?xml version="1.0" encoding="UTF-8"?><rss version="2.0" xmlns:content="http://purl.org/rss/1.0/modules/content/"><channel><title>CHENHUI · 格物</title><description>No description</description><link>https://xuchenhui.cc/</link><language>zh_CN</language><item><title>推荐与推广</title><link>https://xuchenhui.cc/posts/2026-05-06-recommendations/</link><guid isPermaLink="true">https://xuchenhui.cc/posts/2026-05-06-recommendations/</guid><description>实用工具推荐 — 机场服务与大模型 API 中转站</description><pubDate>Wed, 06 May 2026 00:00:00 GMT</pubDate><content:encoded>&lt;h3&gt;一、机场服务推荐&lt;/h3&gt;
&lt;blockquote&gt;
&lt;p&gt;现在环境越来越紧, 免费节点不仅慢、不稳定, 还容易被 AI 风控封禁。想省心用 ChatGPT / Claude, &lt;strong&gt;付费稳定线路是唯一选择&lt;/strong&gt;。下面这个机场我从建站起就在用, 已经稳定 &lt;strong&gt;4 年+&lt;/strong&gt;, 从未掉线失联, 诚心推荐。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h4&gt;&lt;a href=&quot;https://www.xmrth.lol/auth/register?code=nw2q&quot;&gt;XMRth 机场&lt;/a&gt;&lt;/h4&gt;
&lt;p&gt;| 套餐                 | 月付 | 季付 | 半年付 | 年付 |
| -----------------高--- | :--: | :--: | :----: | :--: |
| V5 中继版（推荐 🟢） | ¥16  | ¥48  |  ¥96   | ¥192 |
| V8 专线版            | ¥25  | ¥75  |  ¥150  | ¥300 |&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;优惠码：&lt;code&gt;XMRTH&lt;/code&gt;&lt;/strong&gt;&lt;/p&gt;
&lt;hr /&gt;
&lt;p&gt;&lt;strong&gt;核心优势：&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;稳定可靠&lt;/strong&gt; — 建站 4 年+, 从未掉线失联, AI 使用从未被风控&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;解锁全系 AI&lt;/strong&gt; — ChatGPT / Claude / Gemini 等, 美港节点 + IPLC 专线均可访问&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;IPLC 专线&lt;/strong&gt; — 动态独立 IP, AI 体验更好、响应更快, 消耗中继流量&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;流媒体通杀&lt;/strong&gt; — Netflix、YouTube、Disney+ 全线解锁&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;300~1000Mbps&lt;/strong&gt; 速率, 最多 8 设备同时在线&lt;/li&gt;
&lt;/ul&gt;
&lt;blockquote&gt;
&lt;p&gt;👉 &lt;a href=&quot;https://www.xmrth.lol/auth/register?code=nw2q&quot;&gt;&lt;strong&gt;点此注册&lt;/strong&gt;&lt;/a&gt; ｜ 稳定四年, 值得信赖。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;hr /&gt;
&lt;h3&gt;二、大模型 API 中转站&lt;/h3&gt;
&lt;blockquote&gt;
&lt;p&gt;国内直接调用 Claude / GPT / Gemini 官方 API, 要么需要翻墙, 要么需要海外信用卡, 要么担心账号被封。PackyAPI 把这些问题全解决了：&lt;strong&gt;无需翻墙、无需海外卡、人民币直接充值&lt;/strong&gt;, 是目前我用过最省心的 API 中转方案。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h4&gt;&lt;a href=&quot;https://www.packyapi.com/register?aff=sIjX&quot;&gt;PackyAPI&lt;/a&gt;&lt;/h4&gt;
&lt;p&gt;&lt;strong&gt;核心优势：&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;费用极低&lt;/strong&gt; — 充值比例 &lt;strong&gt;1 元人民币 = 1 美元额度&lt;/strong&gt;, 直接对标 Claude / OpenAI 官网定价, 省去汇损和手续费；部分分组最低可享 &lt;strong&gt;0.2 折&lt;/strong&gt;优惠&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;稳定可靠&lt;/strong&gt; — 多渠道智能兜底切换, 支持 Claude Code / Codex / Gemini CLI 等主流工具, 社区口碑稳定性第一梯队&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;模型覆盖广&lt;/strong&gt; — 超过 27 种模型分组, 涵盖 Claude（官方 / AWS / AWS-Q）、OpenAI Codex、Gemini、Azure GPT 等&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;文档完善&lt;/strong&gt; — 官方教程覆盖 Claude Code、Codex、Gemini CLI 完整配置步骤, 按教程走基本不会出错&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;新用户福利&lt;/strong&gt; — 注册即送 &lt;strong&gt;$1 余额&lt;/strong&gt;, 首次充值享 &lt;strong&gt;9 折优惠&lt;/strong&gt;, 试错成本几乎为零&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;新用户优惠：&lt;/strong&gt;&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;👉 &lt;a href=&quot;https://www.packyapi.com/register?aff=sIjX&quot;&gt;&lt;strong&gt;点此注册&lt;/strong&gt;&lt;/a&gt; ｜ 推广码：&lt;strong&gt;&lt;code&gt;sIjX&lt;/code&gt;&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;注册即送 $1 体验额度, 首充填入优惠码享 9 折。&lt;/p&gt;
&lt;/blockquote&gt;
</content:encoded></item><item><title>大模型量化入门 — 从&quot;最小化误差&quot;出发，一步步推导量化公式</title><link>https://xuchenhui.cc/posts/2026-05-16-llm-quantization-guide/</link><guid isPermaLink="true">https://xuchenhui.cc/posts/2026-05-16-llm-quantization-guide/</guid><description>不直接给公式，而是从&quot;我想把 16 位浮点数压缩到 4 位整数，同时尽量少损失精度&quot;这个目标出发，一步步推出 scale、zero_point、GPTQ、AWQ 的量化方案。</description><pubDate>Sat, 16 May 2026 00:00:00 GMT</pubDate><content:encoded>&lt;h2&gt;0. 前言&lt;/h2&gt;
&lt;p&gt;你有没有好奇过, 为什么一个 7B 参数的模型要占 14GB 显存?&lt;/p&gt;
&lt;p&gt;但在回答之前, 我想先确认一个最基本的问题: &lt;strong&gt;&quot;7B&quot; 和 &quot;14GB&quot; 背后, 到底是怎么算出来的?&lt;/strong&gt;&lt;/p&gt;
&lt;h3&gt;0.1 先聊点最基础的: bit 和 byte&lt;/h3&gt;
&lt;p&gt;计算机里最小单位是一个 &lt;strong&gt;bit&lt;/strong&gt; (比特), 它只能表示 0 或 1 两种状态.&lt;/p&gt;
&lt;p&gt;但 1 个 bit 能存的数字太少了, 所以通常 8 个 bit 组成一个 &lt;strong&gt;byte&lt;/strong&gt; (字节):&lt;/p&gt;
&lt;p&gt;$$
1 \ \text{byte} = 8 \ \text{bits}
$$&lt;/p&gt;
&lt;p&gt;有了 byte 以后, 我们习惯用 byte 来计量内存/显存:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;1 KB = 1024 bytes&lt;/li&gt;
&lt;li&gt;1 MB = 1024 KB&lt;/li&gt;
&lt;li&gt;1 GB = 1024 MB&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;那一个数字要占多少个 byte 呢? 取决于我们用多少 bit 来表示它.&lt;/p&gt;
&lt;h3&gt;0.2 n 个 bit 能表示多少种状态?&lt;/h3&gt;
&lt;p&gt;对于整数来说: &lt;strong&gt;n 个 bit 可以表示 $2^n$ 种不同的状态&lt;/strong&gt;.&lt;/p&gt;
&lt;p&gt;比如:&lt;/p&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;bit 数&lt;/th&gt;
&lt;th&gt;能表示的状态数&lt;/th&gt;
&lt;th&gt;取值范围 (无符号)&lt;/th&gt;
&lt;th&gt;取值范围 (有符号)&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;1 bit&lt;/td&gt;
&lt;td&gt;2&lt;/td&gt;
&lt;td&gt;0 ~ 1&lt;/td&gt;
&lt;td&gt;-1 ~ 0&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;4 bit&lt;/td&gt;
&lt;td&gt;$2^4 = 16$&lt;/td&gt;
&lt;td&gt;0 ~ 15&lt;/td&gt;
&lt;td&gt;-8 ~ 7&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;8 bit&lt;/td&gt;
&lt;td&gt;$2^8 = 256$&lt;/td&gt;
&lt;td&gt;0 ~ 255&lt;/td&gt;
&lt;td&gt;-128 ~ 127&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;16 bit&lt;/td&gt;
&lt;td&gt;$2^{16} = 65536$&lt;/td&gt;
&lt;td&gt;0 ~ 65535&lt;/td&gt;
&lt;td&gt;-32768 ~ 32767&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;所以当我说&quot;一个 INT4 数&quot;时, 意思是它用 4 个 bit 来存, 只能取 16 种不同的值. 有符号时范围是 $-8$ 到 $7$.&lt;/p&gt;
&lt;h3&gt;0.3 二进制怎么转成十进制?&lt;/h3&gt;
&lt;p&gt;比如 4-bit 二进制数 &lt;code&gt;1101&lt;/code&gt;:&lt;/p&gt;
&lt;p&gt;$$
1 \times 2^3 + 1 \times 2^2 + 0 \times 2^1 + 1 \times 2^0 = 8 + 4 + 0 + 1 = 13
$$&lt;/p&gt;
&lt;p&gt;如果是&lt;strong&gt;有符号&lt;/strong&gt;数 (用&lt;strong&gt;补码&lt;/strong&gt;表示), 最高位是符号位:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;最高位 = 0: 正数, 直接算&lt;/li&gt;
&lt;li&gt;最高位 = 1: 负数, 先取反再加 1, 再算&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;这就是为什么 INT4 的范围是 $-8$ (二进制 &lt;code&gt;1000&lt;/code&gt;) 到 $7$ (二进制 &lt;code&gt;0111&lt;/code&gt;).&lt;/p&gt;
&lt;h3&gt;0.4 那浮点数呢?&lt;/h3&gt;
&lt;p&gt;深度学习里用的不是整数, 而是&lt;strong&gt;浮点数&lt;/strong&gt;. 浮点数用科学计数法来存:&lt;/p&gt;
&lt;p&gt;$$
\text{数值} = (-1)^{\text{符号}} \times \text{尾数} \times 2^{\text{指数}}
$$&lt;/p&gt;
&lt;p&gt;以 FP32 (32 位浮点数) 为例, 它由 3 部分组成:&lt;/p&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;部分&lt;/th&gt;
&lt;th&gt;位数&lt;/th&gt;
&lt;th&gt;含义&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;符号位 (sign)&lt;/td&gt;
&lt;td&gt;1 bit&lt;/td&gt;
&lt;td&gt;0 为正, 1 为负&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;指数位 (exponent)&lt;/td&gt;
&lt;td&gt;8 bits&lt;/td&gt;
&lt;td&gt;决定数值范围, 偏置 127&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;尾数位 (mantissa)&lt;/td&gt;
&lt;td&gt;23 bits&lt;/td&gt;
&lt;td&gt;决定数值精度&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;所以 FP32 能表示大约 $\pm 3.4 \times 10^{38}$, 精度约 7 位有效数字. BF16 保留了 FP32 的 8 位指数 (因此范围一样大), 但尾数砍到 7 位 (精度更低但范围不变).&lt;/p&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;格式&lt;/th&gt;
&lt;th&gt;总 bit&lt;/th&gt;
&lt;th&gt;指数 bit&lt;/th&gt;
&lt;th&gt;尾数 bit&lt;/th&gt;
&lt;th&gt;字节数&lt;/th&gt;
&lt;th&gt;数值范围&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;FP32&lt;/td&gt;
&lt;td&gt;32&lt;/td&gt;
&lt;td&gt;8&lt;/td&gt;
&lt;td&gt;23&lt;/td&gt;
&lt;td&gt;4 bytes&lt;/td&gt;
&lt;td&gt;$\pm 3.4 \times 10^{38}$&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;FP16&lt;/td&gt;
&lt;td&gt;16&lt;/td&gt;
&lt;td&gt;5&lt;/td&gt;
&lt;td&gt;10&lt;/td&gt;
&lt;td&gt;2 bytes&lt;/td&gt;
&lt;td&gt;$\pm 6.5 \times 10^4$&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;BF16&lt;/td&gt;
&lt;td&gt;16&lt;/td&gt;
&lt;td&gt;8&lt;/td&gt;
&lt;td&gt;7&lt;/td&gt;
&lt;td&gt;2 bytes&lt;/td&gt;
&lt;td&gt;$\pm 3.4 \times 10^{38}$&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;INT8&lt;/td&gt;
&lt;td&gt;8&lt;/td&gt;
&lt;td&gt;—&lt;/td&gt;
&lt;td&gt;—&lt;/td&gt;
&lt;td&gt;1 byte&lt;/td&gt;
&lt;td&gt;-128 ~ 127&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;INT4&lt;/td&gt;
&lt;td&gt;4&lt;/td&gt;
&lt;td&gt;—&lt;/td&gt;
&lt;td&gt;—&lt;/td&gt;
&lt;td&gt;0.5 byte&lt;/td&gt;
&lt;td&gt;-8 ~ 7&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;blockquote&gt;
&lt;p&gt;注意: 这里 1 byte = 8 bits, 所以一个 INT4 数是 0.5 个 byte.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h3&gt;0.5 回到最初的问题&lt;/h3&gt;
&lt;p&gt;$7 \times 10^9$ 个参数, 每个参数用 BF16 存 (2 bytes/参数):&lt;/p&gt;
&lt;p&gt;$$
7 \times 10^9 \times 2 \ \text{bytes} = 14 \times 10^9 \ \text{bytes} \approx 14 \ \text{GB}
$$&lt;/p&gt;
&lt;p&gt;如果量化到 INT4 (0.5 bytes/参数):&lt;/p&gt;
&lt;p&gt;$$
7 \times 10^9 \times 0.5 \ \text{bytes} = 3.5 \times 10^9 \ \text{bytes} \approx 3.5 \ \text{GB}
$$&lt;/p&gt;
&lt;p&gt;一下子省了 4 倍的空间!&lt;/p&gt;
&lt;p&gt;于是大家想: &lt;strong&gt;能不能用更少的 bit 来存这些数字, 同时尽量不损失模型的质量?&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;这听起来像是一个数据压缩问题——用较少的信息表示同样的内容, 尽量保留原始信息. 但和一般的压缩（比如 zip）不同, 量化有一个重要特点: 我们不需要精确还原每个参数, 我们只需要&lt;strong&gt;最终模型输出的质量尽量不变&lt;/strong&gt;.&lt;/p&gt;
&lt;p&gt;这就给了我们操作空间.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;1. 设定目标&lt;/h2&gt;
&lt;p&gt;先想清楚我们在做什么.&lt;/p&gt;
&lt;p&gt;假设有一个权重矩阵 $W \in \mathbb{R}^{m \times n}$, 原始精度是 BF16. 我们想把它转成 INT4 来存, 即每个数只占 4 个 bit.&lt;/p&gt;
&lt;p&gt;我们的操作流程是:&lt;/p&gt;
&lt;p&gt;$$
W \xrightarrow{\text{量化}} W_q \ (\text{INT4}) \xrightarrow{\text{加载时反量化}} \hat{W}
$$&lt;/p&gt;
&lt;p&gt;推理时计算的是 $\hat{W}x$, 而不是 $Wx$. 所以误差就是:&lt;/p&gt;
&lt;p&gt;$$
\text{error} = |Wx - \hat{W}x|
$$&lt;/p&gt;
&lt;p&gt;我们希望这个误差尽量小.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;这就是量化问题的核心: 找到一个映射函数 $Q: \mathbb{R} \to \mathbb{Z}_4$ (实数到 4-bit 整数), 使得反量化后的 $\hat{W}$ 与 $W$ 的误差最小.&lt;/strong&gt;&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;2. 最简单的尝试: 四舍五入&lt;/h2&gt;
&lt;p&gt;最直观的想法: 把每个浮点数四舍五入到最近的整数.&lt;/p&gt;
&lt;p&gt;比如 $0.3 \to 0$, $1.7 \to 2$, $-0.8 \to -1$.&lt;/p&gt;
&lt;p&gt;但马上发现问题了: 模型参数的取值范围一般是 $[-2, 2]$ 左右, 而 INT4 只能表示 $-8$ 到 $7$. 这就像用一把 16 米长的尺子去量 2 米的东西——大部分刻度都浪费了. 四舍五入后, 几乎所有数都变成 0 或 ±1, 信息全丢了.&lt;/p&gt;
&lt;p&gt;所以我们需要&lt;strong&gt;缩放&lt;/strong&gt;. 先除以一个 scale, 让数据的范围匹配整数的范围, 再四舍五入. 这就是量化的核心思想.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;3. 对称量化: 引入 scale&lt;/h2&gt;
&lt;p&gt;先考虑最简单的情况: 数据分布在 0 两侧, 大致对称.&lt;/p&gt;
&lt;h3&gt;3.1 推导 scale&lt;/h3&gt;
&lt;p&gt;我们希望把 $[-a, a]$ 范围内的浮点数映射到 $[-Q_n, Q_p]$ 范围内的整数. 对于对称 INT8, $Q_n = -127$, $Q_p = 127$.&lt;/p&gt;
&lt;p&gt;目标: 最小化量化误差. 假设有一组数值 $x_1, x_2, ..., x_N$, 我们选择的 scale 是 $s$, 量化函数是:&lt;/p&gt;
&lt;p&gt;$$
x_q = \text{round}\left(\frac{x}{s}\right)
$$&lt;/p&gt;
&lt;p&gt;反量化:&lt;/p&gt;
&lt;p&gt;$$
\hat{x} = x_q \times s
$$&lt;/p&gt;
&lt;p&gt;量化误差:&lt;/p&gt;
&lt;p&gt;$$
\text{MSE}(s) = \frac{1}{N}\sum_{i=1}^N (x_i - \hat{x}&lt;em&gt;i)^2
= \frac{1}{N}\sum&lt;/em&gt;{i=1}^N \left(x_i - s \cdot \text{round}\left(\frac{x_i}{s}\right)\right)^2
$$&lt;/p&gt;
&lt;p&gt;这个 MSE 的最小化没有完美的闭式解, 但有一个在实际中效果很好的启发式: &lt;strong&gt;让 scale 刚好覆盖数据的最大范围&lt;/strong&gt;.&lt;/p&gt;
&lt;p&gt;$$
s = \frac{\max(|x|)}{Q_p}
$$&lt;/p&gt;
&lt;p&gt;为什么? 如果 $s$ 太大, 那么 $\frac{x}{s}$ 就很小, 四舍五入后很多不同的 $x$ 会映射到同一个整数, 量化颗粒度太粗. 如果 $s$ 太小, 会有一部分数据超出 $[-Q_n s, Q_p s]$ 的范围, 发生&lt;strong&gt;截断&lt;/strong&gt; (clipping), 引入截断误差.&lt;/p&gt;
&lt;p&gt;最佳 scale 就是在&lt;strong&gt;颗粒度误差&lt;/strong&gt;和&lt;strong&gt;截断误差&lt;/strong&gt;之间取平衡. 对于接近均匀分布的数据, 覆盖最大范围的 scale 是近似最优的. 但对于长尾分布 (模型权重的分布通常中间密两边疏), 选择稍微小一点的 scale (允许少量截断) 反而能降低整体 MSE——因为截断的是极少数离群点, 而保留的精度给大部分值带来了好处.&lt;/p&gt;
&lt;p&gt;我们把这个重要的权衡点记下来, 后面 AWQ 会用到.&lt;/p&gt;
&lt;h3&gt;3.2 直观例子&lt;/h3&gt;
&lt;p&gt;假设权重值: $[0.1, -0.3, 0.7, -0.9, 1.2]$, 量化到 INT8 ($Q_p = 127$).&lt;/p&gt;
&lt;p&gt;$$
s = \frac{1.2}{127} \approx 0.00945
$$&lt;/p&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;原始值&lt;/th&gt;
&lt;th&gt;$x/s$&lt;/th&gt;
&lt;th&gt;round&lt;/th&gt;
&lt;th&gt;反量化&lt;/th&gt;
&lt;th&gt;误差&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;0.1&lt;/td&gt;
&lt;td&gt;10.58&lt;/td&gt;
&lt;td&gt;11&lt;/td&gt;
&lt;td&gt;0.104&lt;/td&gt;
&lt;td&gt;+0.004&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;-0.3&lt;/td&gt;
&lt;td&gt;-31.75&lt;/td&gt;
&lt;td&gt;-32&lt;/td&gt;
&lt;td&gt;-0.302&lt;/td&gt;
&lt;td&gt;-0.002&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;0.7&lt;/td&gt;
&lt;td&gt;74.07&lt;/td&gt;
&lt;td&gt;74&lt;/td&gt;
&lt;td&gt;0.699&lt;/td&gt;
&lt;td&gt;-0.001&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;-0.9&lt;/td&gt;
&lt;td&gt;-95.24&lt;/td&gt;
&lt;td&gt;-95&lt;/td&gt;
&lt;td&gt;-0.898&lt;/td&gt;
&lt;td&gt;+0.002&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;1.2&lt;/td&gt;
&lt;td&gt;126.98&lt;/td&gt;
&lt;td&gt;127&lt;/td&gt;
&lt;td&gt;1.200&lt;/td&gt;
&lt;td&gt;0.000&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;可以看到误差都在 $s/2 \approx 0.005$ 以内, 相对误差不到 1%. 在小数值上的相对误差会大一些 (0.1 → 0.104, 误差 4%), 但绝对值很小.&lt;/p&gt;
&lt;h3&gt;3.3 对称量化的局限性&lt;/h3&gt;
&lt;p&gt;如果数据分布不对称怎么办? 比如 ReLU 之后的激活值全是正数 $[0, 5.0]$.&lt;/p&gt;
&lt;p&gt;用对称量化, 正半轴用 $[0, 127]$, 负半轴 $[-127, 0)$ 完全浪费了, 相当于只用了一半的精度预算. 可以用非对称量化来补救.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;4. 非对称量化: 引入 zero_point&lt;/h2&gt;
&lt;h3&gt;4.1 推导&lt;/h3&gt;
&lt;p&gt;非对称量化不再要求映射关于 0 对称, 而是允许平移. 引入两个参数: scale $s$ 和 zero_point $z$.&lt;/p&gt;
&lt;p&gt;量化:&lt;/p&gt;
&lt;p&gt;$$
x_q = \text{round}\left(\frac{x}{s}\right) + z
$$&lt;/p&gt;
&lt;p&gt;反量化:&lt;/p&gt;
&lt;p&gt;$$
\hat{x} = (x_q - z) \times s
$$&lt;/p&gt;
&lt;p&gt;其中 $z$ 是一个整数, 使得 $x = 0$ 映射到 $x_q = z$. 参数怎么确定?&lt;/p&gt;
&lt;p&gt;假设数据范围是 $[x_{\min}, x_{\max}]$, 量化范围是 $[0, 2^n - 1]$ (对 INT8 就是 $[0, 255]$).&lt;/p&gt;
&lt;p&gt;$$
s = \frac{x_{\max} - x_{\min}}{2^n - 1}
$$&lt;/p&gt;
&lt;p&gt;$$
z = \text{round}\left(-\frac{x_{\min}}{s}\right) = \text{round}\left(-\frac{x_{\min} \cdot (2^n - 1)}{x_{\max} - x_{\min}}\right)
$$&lt;/p&gt;
&lt;p&gt;推导: 把数据范围映射到整数范围:&lt;/p&gt;
&lt;p&gt;$$
x_{\min} \to 0, \quad x_{\max} \to 2^n - 1
$$&lt;/p&gt;
&lt;p&gt;线性映射就是:&lt;/p&gt;
&lt;p&gt;$$
x_q = \frac{x - x_{\min}}{x_{\max} - x_{\min}} \cdot (2^n - 1)
$$&lt;/p&gt;
&lt;p&gt;等价于上面 $s$ 和 $z$ 的形式 (验证一下: 代入 $x = 0$, 得到 $x_q = \frac{-x_{\min}}{x_{\max} - x_{\min}} \cdot (2^n - 1) \approx z$).&lt;/p&gt;
&lt;p&gt;把公式整理成更常用的形式:&lt;/p&gt;
&lt;p&gt;$$
x_q = \text{round}\left(\frac{x}{s}\right) + z
$$&lt;/p&gt;
&lt;p&gt;其中 $s = \frac{x_{\max} - x_{\min}}{2^n - 1}$, $z = \text{round}\left(-\frac{x_{\min}}{s}\right)$.&lt;/p&gt;
&lt;h3&gt;4.2 对称 vs 非对称, 选哪个?&lt;/h3&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;&lt;/th&gt;
&lt;th&gt;对称&lt;/th&gt;
&lt;th&gt;非对称&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;参数数量&lt;/td&gt;
&lt;td&gt;1 个 scale&lt;/td&gt;
&lt;td&gt;2 个 (scale + zero_point)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;计算效率&lt;/td&gt;
&lt;td&gt;高 (无额外减法)&lt;/td&gt;
&lt;td&gt;低 (多一步减法)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;适合数据&lt;/td&gt;
&lt;td&gt;对称分布 (如权重)&lt;/td&gt;
&lt;td&gt;不对称分布 (如激活值)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;硬件支持&lt;/td&gt;
&lt;td&gt;几乎所有硬件&lt;/td&gt;
&lt;td&gt;部分硬件优化不足&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;实践中: &lt;strong&gt;权重用对称量化, 激活值用非对称量化&lt;/strong&gt;.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;5. 量化的误差到底从哪里来?&lt;/h2&gt;
&lt;p&gt;搞清楚了公式, 现在深入看看误差来源.&lt;/p&gt;
&lt;p&gt;把反量化后的值写成:&lt;/p&gt;
&lt;p&gt;$$
\hat{x} = s \cdot \text{round}\left(\frac{x}{s}\right) = x + s \cdot \left(\text{round}\left(\frac{x}{s}\right) - \frac{x}{s}\right)
$$&lt;/p&gt;
&lt;p&gt;括号里的项是四舍五入的误差 $e$, 范围是 $[-0.5, 0.5]$. 所以:&lt;/p&gt;
&lt;p&gt;$$
\hat{x} = x + s \cdot e, \quad e \in [-0.5, 0.5]
$$&lt;/p&gt;
&lt;p&gt;量化误差就是 $s \cdot e$, 最大值 $0.5s$. 这告诉我们两件事:&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;scale 越大, 误差越大&lt;/strong&gt;. 数据范围大的张量, 量化损失也更严重.&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;量化误差的上限是固定的&lt;/strong&gt; ($0.5s$), 不管原始值大小.&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;第二条特别重要: 如果一个权重值本来就很小 (比如 $0.001$), 量化误差 $0.5s$ 可能跟它本身一样大! 而一个大的权重值 (比如 $1.5$), 同样 $0.5s$ 的误差占比不到 1%.&lt;/p&gt;
&lt;p&gt;这就引出了一个关键洞察: &lt;strong&gt;不是所有权重对最终结果的影响都一样大&lt;/strong&gt;.&lt;/p&gt;
&lt;p&gt;计算 $Wx$ 时, 某个权重 $W_{ij}$ 对输出的贡献是 $W_{ij} \cdot x_j$. 如果 $|W_{ij}|$ 很大, 或者它对应的激活值 $|x_j|$ 很大, 那这个权重的量化误差就会被放大. 反过来, 如果一个权重很小, 即使它的相对误差很大, 对最终结果的绝对影响也很小.&lt;/p&gt;
&lt;p&gt;这个洞察是后面 AWQ 的核心.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;6. 从数学到工程: 量化粒度&lt;/h2&gt;
&lt;p&gt;上面讲的都是针对一个张量的整体量化. 但一个模型有几十上百层, 每层的参数分布可能差别很大.&lt;/p&gt;
&lt;p&gt;量化粒度就是&lt;strong&gt;每多少个参数共用一个 scale&lt;/strong&gt;:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;Per-tensor&lt;/strong&gt;: 整个权重矩阵 1 个 scale, 最简单但精度最差&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Per-channel&lt;/strong&gt;: 每个输出通道 1 个 scale (对线性层来说, 就是权重矩阵的每一行), 实际中最常用&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Per-group&lt;/strong&gt;: 每 $g$ 个参数 1 个 scale (比如 $g=32$), 精度最高但存储开销也大&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;为什么? 从信息论角度理解: 每组独立算 scale, 相当于给每组分配了独立的&quot;精度预算&quot;, 可以针对该组的实际分布来优化. 缺点是需要额外存储 scale 值 (一般用 FP16 或 FP32), group size 越小, 额外存储占比越大. 例如 group size = 32, 每 32 个值存 1 个 16-bit scale, 额外开销占 $16/32 = 0.5$ bit/参数——对于 4-bit 量化来说, 相当于增加了 12.5% 的存储.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;7. GPTQ: 有脑子的量化&lt;/h2&gt;
&lt;p&gt;上面讲的都是&quot;独立量化每个参数&quot;. GPTQ 做了一个更聪明的选择:&lt;strong&gt;量化参数时考虑参数之间的相互作用&lt;/strong&gt;.&lt;/p&gt;
&lt;h3&gt;7.1 从 OBS 说起&lt;/h3&gt;
&lt;p&gt;GPTQ 的前身是 &lt;strong&gt;Optimal Brain Quantizer (OBQ)&lt;/strong&gt;, 它的核心思路来自最优脑损伤 (Optimal Brain Damage) 的思想: 衡量&lt;strong&gt;移除某个参数对损失函数的影响&lt;/strong&gt;, 优先量化影响小的参数.&lt;/p&gt;
&lt;p&gt;对于一个训练好的网络, 损失函数 $\mathcal{L}$ 在最优参数 $\mathbf{w}$ 附近的二阶泰勒展开是:&lt;/p&gt;
&lt;p&gt;$$
\mathcal{L}(\mathbf{w} + \delta) \approx \mathcal{L}(\mathbf{w}) + \underbrace{\nabla \mathcal{L}(\mathbf{w})^{\mathsf{T}} \delta}_{=0 \ \text{在最优解处}} + \frac{1}{2} \delta^{\mathsf{T}} H \delta
$$&lt;/p&gt;
&lt;p&gt;其中 $H$ 是 &lt;strong&gt;Hessian 矩阵&lt;/strong&gt; (二阶偏导数矩阵, 衡量损失函数在各方向上的曲率). 对角线元素 $H_{qq}$ 越大, 说明损失函数在参数 $w_q$ 方向上越&quot;陡峭&quot;, 这个参数越重要.&lt;/p&gt;
&lt;p&gt;$$
\Delta \mathcal{L} \approx \frac{1}{2} \frac{w_q^2}{[H^{-1}]_{qq}}
$$&lt;/p&gt;
&lt;p&gt;推导: 如果我们量化 (移除) 参数 $w_q$, 最优的补偿是调整其他参数来最小化损失. OBQ 证明, 最小损失的调整量是 $\delta = -\frac{w_q}{[H^{-1}]&lt;em&gt;{qq}} H&lt;/em&gt;{:,q}^{-1}$, 对应的损失增量就是 $\frac{1}{2} \frac{w_q^2}{[H^{-1}]_{qq}}$.&lt;/p&gt;
&lt;p&gt;注意这里 $[H^{-1}]&lt;em&gt;{qq}$ 在&lt;strong&gt;分母&lt;/strong&gt;. $[H^{-1}]&lt;/em&gt;{qq}$ 的含义是&quot;移除参数 $w_q$ 后, 其他参数能补偿的程度&quot;:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;如果 $[H^{-1}]_{qq}$ 很大&lt;/strong&gt;: 说明损失函数在这个方向上很&quot;平缓&quot;, 移除 $w_q$ 后其他参数很容易补偿 → 损失增量小 → 这个参数&lt;strong&gt;不重要&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;如果 $[H^{-1}]_{qq}$ 很小&lt;/strong&gt;: 说明损失函数在这个方向上很&quot;陡峭&quot;, $w_q$ 的位置很关键 → 损失增量大 → 这个参数&lt;strong&gt;很重要&lt;/strong&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;h3&gt;7.2 GPTQ 的工程优化&lt;/h3&gt;
&lt;p&gt;OBQ 每次量化一个参数后要更新 Hessian 逆矩阵, 复杂度 $O(d_{\text{col}}^3)$, 对大规模模型不可行.&lt;/p&gt;
&lt;p&gt;GPTQ 做了三个关键工程优化:&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;固定顺序&lt;/strong&gt;: 不再贪心选择&quot;影响最小的参数&quot;, 而是直接按列从左到右量化&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;懒惰更新&lt;/strong&gt;: 批量量化多列后再更新 Hessian, 利用矩阵乘法的 GPU 加速&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Cholesky 预分解&lt;/strong&gt;: 预先对 Hessian 做 Cholesky 分解 ($H = LL^{\mathsf{T}}$), 避免逐次求逆. Cholesky 分解是一种将对称正定矩阵分解为下三角矩阵乘以其转置的方法, 比直接求逆更高效稳定.&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;最终 GPTQ 可以在几个小时内量化 175B 的模型, 且 INT4 精度损失极小.&lt;/p&gt;
&lt;p&gt;算法流程 (简化的伪代码):&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;输入: 权重矩阵 W, 校准数据 X, 量化精度 b
1. 计算 Hessian: H = 2 X^T X
2. 对 H 做 Cholesky 分解
3. for 每一列 j in W:
   a. 量化第 j 列: W_q[:, j]
   b. 计算量化误差: err = W[:, j] - W_q[:, j]
   c. 把误差按 Hessian 信息&quot;补偿&quot;到未量化的列上
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;步骤 3c 是关键: 量化第 1 列造成的误差, 会被分摊到后面的列上, 后面的参数会&quot;吸收&quot;前面的量化误差. 这就好比: 你前面做错了事, 后面的人帮你兜着.&lt;/p&gt;
&lt;h3&gt;7.3 GPTQ 的精度&lt;/h3&gt;
&lt;p&gt;实际经验: GPTQ INT4 的精度损失一般在 1% 以内 (以 MMLU 为基准). 对于大模型 (70B+), 损失甚至更小, 因为参数越多, 量化误差越容易被&quot;稀释&quot;.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;8. AWQ: 重要通道保护&lt;/h2&gt;
&lt;p&gt;AWQ 的出发点很简单: &lt;strong&gt;不是所有权重都值得平等对待&lt;/strong&gt;.&lt;/p&gt;
&lt;h3&gt;8.1 关键观察&lt;/h3&gt;
&lt;p&gt;AWQ 的作者发现了一个有趣的现象: 权重中约 1% 的通道 (channel) 对模型质量影响巨大. 这些通道的特点是——它们对应的&lt;strong&gt;激活值 (activation) 幅度特别大&lt;/strong&gt;.&lt;/p&gt;
&lt;p&gt;回忆前面说的: 计算 $Wx$ 时, 某个权重 $W_{ij}$ 对输出的贡献是 $W_{ij} \cdot x_j$. 如果激活值 $x_j$ 很大, 那么这个通道的微小量化误差都会被放大.&lt;/p&gt;
&lt;h3&gt;8.2 AWQ 的做法&lt;/h3&gt;
&lt;p&gt;AWQ 的做法和直觉相反——它不直接保留这些重要通道的高精度, 而是用一个巧妙的技巧来&quot;保护&quot;它们:&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;用少量校准数据跑一遍, 找到激活值幅度大的通道&lt;/li&gt;
&lt;li&gt;对这些通道的权重&lt;strong&gt;乘以一个大于 1 的缩放因子 $s$&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;量化缩放后的权重&lt;/li&gt;
&lt;li&gt;在推理时, 对相应的激活值&lt;strong&gt;除以 $s$&lt;/strong&gt;, 保证计算结果不变&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;数学上:&lt;/p&gt;
&lt;p&gt;$$ \text{量化前: } Wx \quad \longrightarrow \quad \text{量化后: } \hat{W}\left(\frac{x}{s}\right) \cdot s $$&lt;/p&gt;
&lt;p&gt;其中 $\hat{W}$ 是 $W \cdot s$ 的量化版本. 由于 $|W \cdot s|$ 变大了, 在同样的量化范围下, scale 相应增大, 但更重要的是: &lt;strong&gt;重要通道的权重被放大后, 量化相对误差变小了&lt;/strong&gt; (因为四舍五入的绝对误差 $0.5 \cdot s_{\text{quant}}$ 相对于放大后的值变小).&lt;/p&gt;
&lt;p&gt;这就像在教育预算中, 给更需要支持的学校多分配资源——总量不变, 但分配更合理.&lt;/p&gt;
&lt;h3&gt;8.3 AWQ vs GPTQ&lt;/h3&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;&lt;/th&gt;
&lt;th&gt;GPTQ&lt;/th&gt;
&lt;th&gt;AWQ&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;核心思路&lt;/td&gt;
&lt;td&gt;量化误差补偿&lt;/td&gt;
&lt;td&gt;保护重要通道&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;需要校准数据&lt;/td&gt;
&lt;td&gt;是&lt;/td&gt;
&lt;td&gt;是&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;计算开销&lt;/td&gt;
&lt;td&gt;高 (需要 Hessian)&lt;/td&gt;
&lt;td&gt;低 (只需要激活值统计)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;精度 (INT4)&lt;/td&gt;
&lt;td&gt;略好&lt;/td&gt;
&lt;td&gt;相当&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;主要贡献&lt;/td&gt;
&lt;td&gt;大规模高效量化&lt;/td&gt;
&lt;td&gt;简单有效的重要性感知&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;两种方法可以结合使用, 不少实际部署方案会先用 AWQ 的思路找重要通道, 再用 GPTQ 的方法做量化.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;9. 总结&lt;/h2&gt;
&lt;p&gt;回头看看量化到底在做什么:&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;核心思想&lt;/strong&gt;: 用更少的 bit 存数字, 最小化对输出的影响.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;数学本质&lt;/strong&gt;: 找到缩放 $s$ 和偏移 $z$, 使得 $\hat{x} = s \cdot \left(\text{round}\left(\frac{x}{s}\right) + z\right)$ 的误差最小.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;关键权衡&lt;/strong&gt;:&lt;/p&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;权衡&lt;/th&gt;
&lt;th&gt;选择&lt;/th&gt;
&lt;th&gt;效果&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;scale 大小&lt;/td&gt;
&lt;td&gt;大 → 覆盖范围大但精度粗&lt;/td&gt;
&lt;td&gt;小 → 精度细但可能截断&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;量化粒度&lt;/td&gt;
&lt;td&gt;粗 → 速度飞快精度低&lt;/td&gt;
&lt;td&gt;细 → 精度高但开销大&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;量化方法&lt;/td&gt;
&lt;td&gt;独立量化 → 简单&lt;/td&gt;
&lt;td&gt;关联量化(GPTQ) → 精度好但慢&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;通道保护&lt;/td&gt;
&lt;td&gt;一视同仁 → 简单&lt;/td&gt;
&lt;td&gt;区别对待(AWQ) → 精度更好&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;&lt;strong&gt;一句话&lt;/strong&gt;: 量化就是用精度换效率, 关键是怎么在尽可能少丢精度的情况下, 做到极致的效率.&lt;/p&gt;
&lt;hr /&gt;
&lt;h3&gt;参考资料&lt;/h3&gt;
&lt;ol&gt;
&lt;li&gt;Gray &amp;amp; Neuhoff, Quantization. IEEE Transactions on Information Theory, 1998. — &lt;strong&gt;量化信息论的经典综述&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers. Frantar et al., ICLR 2023. &lt;a href=&quot;https://arxiv.org/abs/2210.17323&quot;&gt;arXiv:2210.17323&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration. Lin et al., MLSys 2024. &lt;a href=&quot;https://arxiv.org/abs/2306.00978&quot;&gt;arXiv:2306.00978&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;Nagel et al., A White Paper on Neural Network Quantization. &lt;a href=&quot;https://arxiv.org/abs/2106.08295&quot;&gt;arXiv:2106.08295&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;llama.cpp GGUF format. &lt;a href=&quot;https://github.com/ggganov/llama.cpp&quot;&gt;GitHub&lt;/a&gt;&lt;/li&gt;
&lt;/ol&gt;
</content:encoded></item><item><title>RoPE 旋转位置编码 — 从目标出发，一步步推出旋转矩阵</title><link>https://xuchenhui.cc/posts/2026-05-16-llm-rope-rotary-position-embedding/</link><guid isPermaLink="true">https://xuchenhui.cc/posts/2026-05-16-llm-rope-rotary-position-embedding/</guid><description>不直接给公式，而是从&quot;我希望内积只依赖相对位置&quot;这个目标出发，一步步反推出旋转矩阵形式的推导过程。</description><pubDate>Sat, 16 May 2026 00:00:00 GMT</pubDate><content:encoded>&lt;h2&gt;0. 前言&lt;/h2&gt;
&lt;p&gt;在正式推导之前, 我们先想清楚一件事: &lt;strong&gt;我们到底希望位置编码做到什么?&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;回忆一下, Transformer 的 Self-Attention 计算的是 query 和 key 的内积:&lt;/p&gt;
&lt;p&gt;$$
\text{score} = \mathbf{q}^{\mathsf{T}} \mathbf{k}
$$&lt;/p&gt;
&lt;p&gt;问题是这个分数跟位置完全无关。把词序打乱, attention 输出一样。&lt;/p&gt;
&lt;p&gt;所以我们需要给每个位置上的 query 和 key 加上位置信息, 让 attention 能感知到&quot;谁在哪儿&quot;。&lt;/p&gt;
&lt;p&gt;但&quot;加上位置信息&quot;这个说法太笼统了。更精确地说, 我们希望:&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;编码位置后, 两个 token 的 attention score &lt;strong&gt;只依赖于它们的相对位置差&lt;/strong&gt;, 而不依赖于它们各自的绝对位置。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;这是整篇文章推导的起点。搞清楚了这一点, 后面的所有公式就有了方向。&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;1. 设定目标&lt;/h2&gt;
&lt;p&gt;把位置编码表示成两个函数:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;$f_q(\mathbf{q}, m)$: 对位置 $m$ 处的 query 向量 $\mathbf{q}$ 进行位置编码&lt;/li&gt;
&lt;li&gt;$f_k(\mathbf{k}, n)$: 对位置 $n$ 处的 key 向量 $\mathbf{k}$ 进行位置编码&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;我们期望:&lt;/p&gt;
&lt;p&gt;$$
\langle f_q(\mathbf{q}, m), f_k(\mathbf{k}, n) \rangle = g(\mathbf{q}, \mathbf{k}, m - n)
$$&lt;/p&gt;
&lt;p&gt;即: 编码后的内积, &lt;strong&gt;只跟词本身($\mathbf{q}$, $\mathbf{k}$)和相对位置($m-n$)有关, 跟绝对位置 $m$, $n$ 无关&lt;/strong&gt;.&lt;/p&gt;
&lt;p&gt;这就是 RoPE 的推导起点. 接下来我们要做的事是: &lt;strong&gt;什么样的 $f_q$, $f_k$ 才能满足这个条件?&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;/images/llm-series/rope-derivation-flow.png&quot; alt=&quot;推导路线图&quot; /&gt;&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;上图为 RoPE 的完整推导路线图. 从目标出发, 通过复数/三角函数, 最终推出旋转矩阵.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;hr /&gt;
&lt;h2&gt;2. 从最简单的 2 维情况开始&lt;/h2&gt;
&lt;p&gt;为了找到 $f_q$ 和 $f_k$ 的形式, 我们先考虑最简单的情况: $\mathbf{q}$ 和 $\mathbf{k}$ 是 2 维向量.&lt;/p&gt;
&lt;h3&gt;2.1 用复数表示 2D 向量&lt;/h3&gt;
&lt;p&gt;2 维向量可以用复数表示:&lt;/p&gt;
&lt;p&gt;$$
\mathbf{q} = (q_1, q_2) \longrightarrow \tilde{q} = q_1 + i q_2
$$&lt;/p&gt;
&lt;p&gt;$$
\mathbf{k} = (k_1, k_2) \longrightarrow \tilde{k} = k_1 + i k_2
$$&lt;/p&gt;
&lt;p&gt;复数的好处是: &lt;strong&gt;旋转和缩放可以简洁地用乘法表示&lt;/strong&gt;.&lt;/p&gt;
&lt;h3&gt;2.2 一个关键的观察&lt;/h3&gt;
&lt;p&gt;在复数域中, 两个复数的&quot;内积&quot;(取实部)可以写成:&lt;/p&gt;
&lt;p&gt;$$
\langle \mathbf{q}, \mathbf{k} \rangle = \text{Re}[\tilde{q} \cdot \overline{\tilde{k}}]
$$&lt;/p&gt;
&lt;p&gt;其中 $\overline{\tilde{k}} = k_1 - i k_2$ 是共轭复数.&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;验证: $\tilde{q} \cdot \overline{\tilde{k}} = (q_1 + i q_2)(k_1 - i k_2) = (q_1 k_1 + q_2 k_2) + i(q_2 k_1 - q_1 k_2)$, 取实部正好是 $\mathbf{q}^{\mathsf{T}} \mathbf{k}$.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h3&gt;2.3 假设一个形式, 然后验证&lt;/h3&gt;
&lt;p&gt;有了这个工具, 我们来&lt;strong&gt;假设&lt;/strong&gt;一种编码方式:&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;假如我们在复数域上给向量乘以一个&lt;strong&gt;单位模长&lt;/strong&gt;的复数因子来编码位置.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;也就是说, 对位置 $m$ 处的向量 $\tilde{q}$, 我们给它在复数域上转一个角度:&lt;/p&gt;
&lt;p&gt;$$
\tilde{q}_m = \tilde{q} \cdot e^{i m \theta}
$$&lt;/p&gt;
&lt;p&gt;同理:&lt;/p&gt;
&lt;p&gt;$$
\tilde{k}_n = \tilde{k} \cdot e^{i n \theta}
$$&lt;/p&gt;
&lt;p&gt;其中 $\theta$ 是一个预置的角度参数.&lt;/p&gt;
&lt;p&gt;现在来验证这个假设是否满足我们的&lt;strong&gt;目标&lt;/strong&gt;(内积只依赖相对位置):&lt;/p&gt;
&lt;p&gt;$$
\begin{aligned}
\langle \tilde{q}_m, \tilde{k}_n \rangle &amp;amp;= \text{Re}[\tilde{q}_m \cdot \overline{\tilde{k}_n}] \
&amp;amp;= \text{Re}[\tilde{q} e^{i m \theta} \cdot \overline{\tilde{k}} e^{-i n \theta}] \
&amp;amp;= \text{Re}[\tilde{q} \overline{\tilde{k}} \cdot e^{i (m-n) \theta}]
\end{aligned}
$$&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;结果只依赖于 $(m-n)$!&lt;/strong&gt; 完美满足目标.&lt;/p&gt;
&lt;h3&gt;2.4 从复数回到实数矩阵&lt;/h3&gt;
&lt;p&gt;现在把复数形式翻译回实数向量和矩阵.&lt;/p&gt;
&lt;p&gt;利用欧拉公式 $e^{i m\theta} = \cos m\theta + i \sin m\theta$, 展开复数乘法:&lt;/p&gt;
&lt;p&gt;$$
\begin{aligned}
q_1&apos; + i q_2&apos; &amp;amp;= (q_1 \cos m\theta - q_2 \sin m\theta) + i(q_1 \sin m\theta + q_2 \cos m\theta)
\end{aligned}
$$&lt;/p&gt;
&lt;p&gt;写成矩阵形式就是:&lt;/p&gt;
&lt;p&gt;$$
\begin{pmatrix} q_1&apos; \ q_2&apos; \end{pmatrix} =
\begin{pmatrix} \cos m\theta &amp;amp; -\sin m\theta \ \sin m\theta &amp;amp; \cos m\theta \end{pmatrix}
\begin{pmatrix} q_1 \ q_2 \end{pmatrix}
$$&lt;/p&gt;
&lt;p&gt;中间这个矩阵——就是&lt;strong&gt;旋转矩阵&lt;/strong&gt; $R(m\theta)$.&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;/images/llm-series/rope-complex-to-matrix.png&quot; alt=&quot;复数乘法 → 旋转矩阵&quot; /&gt;&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;图中展示了复数平面上的旋转如何对应到实数平面的旋转矩阵. 向量旋转 $m\theta$ 等价于乘以旋转矩阵 $R(m\theta)$.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;所以:&lt;/p&gt;
&lt;p&gt;$$
f_q(\mathbf{q}, m) = R(m\theta) \cdot \mathbf{q}, \quad
f_k(\mathbf{k}, n) = R(n\theta) \cdot \mathbf{k}
$$&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;推导的关键&lt;/strong&gt;: 不是凭空定义了旋转矩阵, 而是从&quot;内积只依赖相对位置&quot;这个目标出发, 通过复数域的自然假设, 反推出了旋转矩阵的形式.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;3. 验证: 旋转矩阵形式的内积&lt;/h2&gt;
&lt;p&gt;有了旋转矩阵, 我们反过来验证一下内积.&lt;/p&gt;
&lt;p&gt;在验证之前, 先确认旋转矩阵的两个重要性质:&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;性质 1: $R(\theta)$ 是正交矩阵, 且 $R(\theta)^T = R(-\theta)$&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;$$
R(\theta) = \begin{pmatrix} \cos\theta &amp;amp; -\sin\theta \ \sin\theta &amp;amp; \cos\theta \end{pmatrix}, \quad
R(\theta)^T = \begin{pmatrix} \cos\theta &amp;amp; \sin\theta \ -\sin\theta &amp;amp; \cos\theta \end{pmatrix}
$$&lt;/p&gt;
&lt;p&gt;验算 $R(\theta)^T R(\theta)$:&lt;/p&gt;
&lt;p&gt;$$
R(\theta)^T R(\theta) = \begin{pmatrix} \cos^2\theta + \sin^2\theta &amp;amp; -\cos\theta\sin\theta + \sin\theta\cos\theta \ -\sin\theta\cos\theta + \cos\theta\sin\theta &amp;amp; \sin^2\theta + \cos^2\theta \end{pmatrix} = \begin{pmatrix} 1 &amp;amp; 0 \ 0 &amp;amp; 1 \end{pmatrix} = I
$$&lt;/p&gt;
&lt;p&gt;同时 $R(-\theta) = \begin{pmatrix} \cos(-\theta) &amp;amp; -\sin(-\theta) \ \sin(-\theta) &amp;amp; \cos(-\theta) \end{pmatrix} = \begin{pmatrix} \cos\theta &amp;amp; \sin\theta \ -\sin\theta &amp;amp; \cos\theta \end{pmatrix} = R(\theta)^T$, 所以 $R(\theta)^T = R(-\theta) = R(\theta)^{-1}$.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;性质 2: $R(\alpha)R(\beta) = R(\alpha+\beta)$&lt;/strong&gt; (旋转可加性)&lt;/p&gt;
&lt;p&gt;把两个旋转矩阵相乘:&lt;/p&gt;
&lt;p&gt;$$
\begin{aligned}
R(\alpha)R(\beta) &amp;amp;= \begin{pmatrix} \cos\alpha &amp;amp; -\sin\alpha \ \sin\alpha &amp;amp; \cos\alpha \end{pmatrix}
\begin{pmatrix} \cos\beta &amp;amp; -\sin\beta \ \sin\beta &amp;amp; \cos\beta \end{pmatrix} \
&amp;amp;= \begin{pmatrix} \cos\alpha\cos\beta - \sin\alpha\sin\beta &amp;amp; -\cos\alpha\sin\beta - \sin\alpha\cos\beta \
\sin\alpha\cos\beta + \cos\alpha\sin\beta &amp;amp; -\sin\alpha\sin\beta + \cos\alpha\cos\beta \end{pmatrix} \
&amp;amp;= \begin{pmatrix} \cos(\alpha+\beta) &amp;amp; -\sin(\alpha+\beta) \ \sin(\alpha+\beta) &amp;amp; \cos(\alpha+\beta) \end{pmatrix} = R(\alpha+\beta)
\end{aligned}
$$&lt;/p&gt;
&lt;p&gt;第二步用了三角函数的和角公式: $\cos(\alpha+\beta) = \cos\alpha\cos\beta - \sin\alpha\sin\beta$, $\sin(\alpha+\beta) = \sin\alpha\cos\beta + \cos\alpha\sin\beta$.&lt;/p&gt;
&lt;p&gt;好, 有了这两条性质, 验证内积就很简单了. 不过先直观感受一下 RoPE 是怎么工作的:&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;/images/llm-series/rope-2d-rotation.png&quot; alt=&quot;2D 旋转编码位置&quot; /&gt;&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;左图: 原始向量 q 和 k. 中图: q 旋转 45°(编码位置 m=1), k 旋转 15°(编码位置 n=1). 右图: 同样两个向量, 但绝对位置更大(m=100, n=70), 因为相对差相同(都是 30°), 所以内积不变. 这就是&quot;只依赖相对位置&quot;的直观体现.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;现在用数学验证:&lt;/p&gt;
&lt;p&gt;$$
\begin{aligned}
\langle f_q(\mathbf{q}, m), f_k(\mathbf{k}, n) \rangle
&amp;amp;= (R(m\theta)\mathbf{q})^{\mathsf{T}} (R(n\theta)\mathbf{k}) \
&amp;amp;= \mathbf{q}^{\mathsf{T}} R(m\theta)^{\mathsf{T}} R(n\theta) \mathbf{k} \
&amp;amp;= \mathbf{q}^{\mathsf{T}} R(-m\theta) R(n\theta) \mathbf{k} \quad (\text{旋转矩阵正交: } R^{\mathsf{T}} = R^{-1} = R(-\theta)) \
&amp;amp;= \mathbf{q}^{\mathsf{T}} R((n-m)\theta) \mathbf{k} \quad (\text{旋转矩阵可加: } R(\alpha)R(\beta) = R(\alpha+\beta)) \
&amp;amp;= \mathbf{q}^{\mathsf{T}} R(-(m-n)\theta) \mathbf{k}
\end{aligned}
$$&lt;/p&gt;
&lt;p&gt;展开 $R(-(m-n)\theta)$:&lt;/p&gt;
&lt;p&gt;$$
R(-(m-n)\theta) =
\begin{pmatrix}
\cos(m-n)\theta &amp;amp; \sin(m-n)\theta \
-\sin(m-n)\theta &amp;amp; \cos(m-n)\theta
\end{pmatrix}
$$&lt;/p&gt;
&lt;p&gt;最后的结果只跟 $m-n$ 有关, 验证通过.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;4. 为什么需要不同频率&lt;/h2&gt;
&lt;p&gt;假设我们只有一个固定频率 $\theta$.&lt;/p&gt;
&lt;p&gt;对于位置 $m$ 和 $m+1$, 旋转角度分别是 $m\theta$ 和 $(m+1)\theta$, 差值是 $\theta$. 但问题在于: 当 $m$ 很大时会发生什么?&lt;/p&gt;
&lt;p&gt;来看一个具体例子. 假设 $\theta = 0.1$, 那么:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;位置 $0$: 旋转 $0^\circ$, 向量不变&lt;/li&gt;
&lt;li&gt;位置 $1$: 旋转 $5.7^\circ$&lt;/li&gt;
&lt;li&gt;...&lt;/li&gt;
&lt;li&gt;位置 $30$: 旋转 $171.9^\circ$&lt;/li&gt;
&lt;li&gt;位置 $31$: 旋转 $177.6^\circ$&lt;/li&gt;
&lt;li&gt;位置 $63$: 旋转 $361.0^\circ$ — 转了一整圈多!&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;$R(63 \times 0.1) = R(6.3) \approx R(6.3-2\pi) = R(0.02)$, 所以 $m=63$ 和 $m=0$ 几乎有&lt;strong&gt;相同的旋转矩阵&lt;/strong&gt;. 位置 $63$ 和位置 $0$ 在 attention 计算中无法区分!&lt;/p&gt;
&lt;p&gt;也就是说, 由于 $\sin$ 和 $\cos$ 是周期函数, &lt;strong&gt;旋转超过一周后, 位置信息就混叠了&lt;/strong&gt;. 如果词序列很长, 后面的位置会周期性地&quot;穿越&quot;回前面.&lt;/p&gt;
&lt;p&gt;所以我们需要&lt;strong&gt;多个频率&lt;/strong&gt;, 让不同的维度对以不同的速度旋转:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;低维度: 旋转速度快 ($\theta_i$ 大), 能区分&lt;strong&gt;精细位置&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;高维度: 旋转速度慢 ($\theta_i$ 小), 能感知&lt;strong&gt;大范围距离&lt;/strong&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;频率设置和 Sinusoidal PE 一样:&lt;/p&gt;
&lt;p&gt;$$
\theta_i = 10000^{-2i/d}, \quad i = 0, 1, ..., d/2 - 1
$$&lt;/p&gt;
&lt;p&gt;这个公式的效果是: $i$ 越小, $\theta_i$ 越大(旋转越快), 反之亦然.&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;/images/llm-series/rope-multi-freq-rotation.png&quot; alt=&quot;不同频率下的旋转速度对比&quot; /&gt;&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;图中展示了三个不同频率的旋转速度. 红色转得最快(低维度), 蓝色次之, 绿色最慢(高维度).&lt;/p&gt;
&lt;/blockquote&gt;
&lt;hr /&gt;
&lt;h2&gt;5. 从 2D 扩展到高维&lt;/h2&gt;
&lt;p&gt;既然 2D 情况下的位置编码是旋转, 那高维呢?&lt;/p&gt;
&lt;p&gt;答案很自然: &lt;strong&gt;把向量切成 $d/2$ 对, 每对独立旋转&lt;/strong&gt;.&lt;/p&gt;
&lt;p&gt;对于 8 维向量, 它的旋转矩阵是一个&lt;strong&gt;分块对角矩阵&lt;/strong&gt;:&lt;/p&gt;
&lt;p&gt;$$
R_m =
\begin{pmatrix}
R(m\theta_0) &amp;amp; 0 &amp;amp; 0 &amp;amp; 0 \
0 &amp;amp; R(m\theta_1) &amp;amp; 0 &amp;amp; 0 \
0 &amp;amp; 0 &amp;amp; R(m\theta_2) &amp;amp; 0 \
0 &amp;amp; 0 &amp;amp; 0 &amp;amp; R(m\theta_3)
\end{pmatrix}
$$&lt;/p&gt;
&lt;p&gt;其中每个 $R(m\theta_i)$ 是 2×2 的旋转矩阵.&lt;/p&gt;
&lt;p&gt;这个分块矩阵作用于 8 维向量时, &lt;strong&gt;第 1-2 维用 $\theta_0$, 第 3-4 维用 $\theta_1$, 以此类推&lt;/strong&gt;. 每对维度独立旋转, 互不干扰.&lt;/p&gt;
&lt;p&gt;写成公式就是:&lt;/p&gt;
&lt;p&gt;$$
f_q(\mathbf{q}, m)&lt;em&gt;{(2i, 2i+1)} =
\begin{pmatrix}
q&lt;/em&gt;{2i} \cos m\theta_i - q_{2i+1} \sin m\theta_i \
q_{2i} \sin m\theta_i + q_{2i+1} \cos m\theta_i
\end{pmatrix}
$$&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;为什么可以这样做?&lt;/strong&gt; 这里需要一点简单的推导。&lt;/p&gt;
&lt;p&gt;假设我们把 $d$ 维向量 $\mathbf{q}$ 切分成 $d/2$ 个 2 维子向量 $\mathbf{q}^{(0)}, \mathbf{q}^{(1)}, ..., \mathbf{q}^{(d/2-1)}$, 其中 $\mathbf{q}^{(i)} = (q_{2i}, q_{2i+1})$. 那么内积可以写成:&lt;/p&gt;
&lt;p&gt;$$
\begin{aligned}
\langle \mathbf{q}, \mathbf{k} \rangle
&amp;amp;= \sum_{j=1}^{d} q_j k_j \
&amp;amp;= \sum_{i=0}^{d/2-1} (q_{2i} k_{2i} + q_{2i+1} k_{2i+1}) \
&amp;amp;= \sum_{i=0}^{d/2-1} \langle \mathbf{q}^{(i)}, \mathbf{k}^{(i)} \rangle
\end{aligned}
$$&lt;/p&gt;
&lt;p&gt;这个式子成立, 仅仅是因为&lt;strong&gt;内积的定义就是对应位置相乘再求和&lt;/strong&gt;, 我们可以自由地按任何顺序分组求和——交换律和结合律而已. 这个证明不需要任何额外的数学知识, 就是最基础的向量内积定义.&lt;/p&gt;
&lt;p&gt;现在, 如果我们对每个子向量 $\mathbf{q}^{(i)}$ 独立施加旋转 $R(m\theta_i)$, 那么编码后的内积就是:&lt;/p&gt;
&lt;p&gt;$$
\langle f_q(\mathbf{q}, m), f_k(\mathbf{k}, n) \rangle
= \sum_{i=0}^{d/2-1} \langle R(m\theta_i) \mathbf{q}^{(i)}, R(n\theta_i) \mathbf{k}^{(i)} \rangle
$$&lt;/p&gt;
&lt;p&gt;我们在 2D 情况下已经证明过, 每个子空间的内积都只依赖于 $m-n$:&lt;/p&gt;
&lt;p&gt;$$
\langle R(m\theta_i) \mathbf{q}^{(i)}, R(n\theta_i) \mathbf{k}^{(i)} \rangle
= \langle \mathbf{q}^{(i)}, R((n-m)\theta_i) \mathbf{k}^{(i)} \rangle
$$&lt;/p&gt;
&lt;p&gt;所以整个内积也只依赖 $m-n$. &lt;strong&gt;证毕&lt;/strong&gt;.&lt;/p&gt;
&lt;p&gt;这样我们就实现了 $d$ 维向量到 $d/2$ 个独立 2D 旋转的分解.&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;/images/llm-series/rope-block-diagonal.png&quot; alt=&quot;分块旋转矩阵可视化&quot; /&gt;&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;图中展示了 8 维向量被分成 4 对, 每对用不同的频率独立旋转, 整体构成块对角矩阵.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;hr /&gt;
&lt;h2&gt;6. 高维验证小结&lt;/h2&gt;
&lt;p&gt;上面的推导已经证明: 高维内积可以分解为：&lt;/p&gt;
&lt;p&gt;$$ \langle f_q(\mathbf{q}, m), f_k(\mathbf{k}, n) \rangle = \sum_{i=0}^{d/2-1} \langle \mathbf{q}^{(i)}, R((n-m)\theta_i) \mathbf{k}^{(i)} \rangle $$&lt;/p&gt;
&lt;p&gt;其中 $\mathbf{q}^{(i)} = (q_{2i}, q_{2i+1})$ 是第 $i$ 个子空间的 2 维向量. 每个子空间的结果都只依赖于 $m-n$, 所以总和也只依赖 $m-n$.&lt;/p&gt;
&lt;p&gt;这就是 RoPE 在高维下的完整形式.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;7. 长距离衰减&lt;/h2&gt;
&lt;p&gt;前面我们从目标出发推导了 RoPE 的形式. 现在来看 RoPE 自带的一个优雅性质: &lt;strong&gt;长距离衰减&lt;/strong&gt;.&lt;/p&gt;
&lt;p&gt;假设 $\mathbf{q}$ 和 $\mathbf{k}$ 是来自同一分布的随机向量, 各分量均值为 0, 方差为 1, 且不同分量之间相互独立. 这意味着我们希望&lt;strong&gt;相关系数&lt;/strong&gt;: 当 $j = l$ 时 $\mathbb{E}[q_j k_l] = 1$, 当 $j \neq l$ 时 $\mathbb{E}[q_j k_l] = 0$.  用 Kronecker delta 符号 $\delta_{jl}$ 统一表示就是 $\mathbb{E}[q_j k_l] = \delta_{jl}$ (即 $j=l$ 时为 1, 否则为 0).&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;$\delta_{jl}$ 是 Kronecker delta: $\delta_{jl}=1$ 当 $j=l$, 否则 $\delta_{jl}=0$.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;现在计算 RoPE 编码后内积的期望:&lt;/p&gt;
&lt;p&gt;$$
\begin{aligned}
\mathbb{E}[\langle f_q(\mathbf{q}, m), f_k(\mathbf{k}, n) \rangle]
&amp;amp;= \mathbb{E}\left[ \sum_{i=0}^{d/2-1} \langle R(m\theta_i) \mathbf{q}^{(i)}, R(n\theta_i) \mathbf{k}^{(i)} \rangle \right] \
&amp;amp;= \sum_{i=0}^{d/2-1} \mathbb{E}\left[ \langle R(m\theta_i) \mathbf{q}^{(i)}, R(n\theta_i) \mathbf{k}^{(i)} \rangle \right]
\end{aligned}
$$&lt;/p&gt;
&lt;p&gt;展开每个子空间的内积:&lt;/p&gt;
&lt;p&gt;$$
\begin{aligned}
\langle R(m\theta_i) \mathbf{q}^{(i)}, R(n\theta_i) \mathbf{k}^{(i)} \rangle
&amp;amp;= (q_{2i}\cos m\theta_i - q_{2i+1}\sin m\theta_i)(k_{2i}\cos n\theta_i - k_{2i+1}\sin n\theta_i) \
&amp;amp;\quad + (q_{2i}\sin m\theta_i + q_{2i+1}\cos m\theta_i)(k_{2i}\sin n\theta_i + k_{2i+1}\cos n\theta_i)
\end{aligned}
$$&lt;/p&gt;
&lt;p&gt;展开后有四项含 $q_{2i}k_{2i}$, 四项含 $q_{2i+1}k_{2i+1}$, 以及交叉项 $q_{2i}k_{2i+1}$ 和 $q_{2i+1}k_{2i}$.&lt;/p&gt;
&lt;p&gt;由于 $q$ 和 $k$ 的不同分量独立且均值为 0, 交叉项的期望为 0. 只有 $q_{2i}k_{2i}$ 和 $q_{2i+1}k_{2i+1}$ 的期望为 1:&lt;/p&gt;
&lt;p&gt;$$
\begin{aligned}
\mathbb{E}[q_{2i}k_{2i}] &amp;amp;\times (\cos m\theta_i\cos n\theta_i + \sin m\theta_i\sin n\theta_i) \
&amp;amp;+ \mathbb{E}[q_{2i+1}k_{2i+1}] \times (\sin m\theta_i\sin n\theta_i + \cos m\theta_i\cos n\theta_i) \
&amp;amp;= 2\cos((m-n)\theta_i)
\end{aligned}
$$&lt;/p&gt;
&lt;p&gt;其中用了三角恒等式 $\cos m\theta_i\cos n\theta_i + \sin m\theta_i\sin n\theta_i = \cos((m-n)\theta_i)$.&lt;/p&gt;
&lt;p&gt;所以每个子空间贡献 $2\cos((m-n)\theta_i)$. 把所有 $d/2$ 个子空间加起来:&lt;/p&gt;
&lt;p&gt;$$
\mathbb{E}[\langle f_q(\mathbf{q}, m), f_k(\mathbf{k}, n) \rangle] = 2\sum_{i=0}^{d/2-1} \cos((m-n)\theta_i)
$$&lt;/p&gt;
&lt;p&gt;这个求和函数在 $m-n$ 增大时会呈现&lt;strong&gt;震荡衰减&lt;/strong&gt;的趋势——距离越远, 预期的注意力分数越低. 这符合我们在自然语言中的直觉:&lt;strong&gt;相邻词通常比远距离的词关联更紧密&lt;/strong&gt;, 而且&quot;带震荡&quot;的性质意味着某些特定距离的 token 也能获得较强注意力, 这与现实中周期性短语(如每第 n 个词)的匹配模式一致.&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;/images/llm-series/rope-decay.png&quot; alt=&quot;RoPE 长距离衰减&quot; /&gt;&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;图中展示了不同维度下 RoPE 内积随距离的衰减趋势. 注意衰减不是单调的, 而是带有震荡的.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;hr /&gt;
&lt;h2&gt;8. 30 行代码实现&lt;/h2&gt;
&lt;pre&gt;&lt;code&gt;import torch

def precompute_rope_frequencies(dim: int, max_len: int, base: int = 10000):
    &quot;&quot;&quot;预计算所有位置的 sin/cos 值&quot;&quot;&quot;
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(max_len).float()
    angles = positions[:, None] * inv_freq[None, :]         # (max_len, dim/2)
    angles = torch.cat([angles, angles], dim=-1)            # (max_len, dim)
    return angles.cos(), angles.sin()

def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    &quot;&quot;&quot;对 x 施加 RoPE 旋转&quot;&quot;&quot;
    # x: (batch, seq_len, head, dim)
    # cos, sin: (seq_len, dim)
    cos = cos[None, :, None, :]    # (1, seq_len, 1, dim)
    sin = sin[None, :, None, :]    # (1, seq_len, 1, dim)
    # 每对 (x_{2i}, x_{2i+1}) 交换并取反 = 旋转
    x_rot = torch.stack([-x[..., 1::2], x[..., ::2]], dim=-1).reshape(x.shape)
    return x * cos + x_rot * sin
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;使用:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;cos, sin = precompute_rope_frequencies(dim=128, max_len=4096)
q_rotated = apply_rope(q, cos, sin)
k_rotated = apply_rope(k, cos, sin)
attn = torch.matmul(q_rotated, k_rotated.transpose(-2, -1))
&lt;/code&gt;&lt;/pre&gt;
&lt;hr /&gt;
&lt;h2&gt;9. 总结&lt;/h2&gt;
&lt;p&gt;回头看整个推导过程, 最优雅的地方在于: &lt;strong&gt;我们没有&quot;发明&quot;旋转矩阵, 而是从目标出发&quot;发现&quot;了它&lt;/strong&gt;.&lt;/p&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;步骤&lt;/th&gt;
&lt;th&gt;思路&lt;/th&gt;
&lt;th&gt;数学形式&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;① 设定目标&lt;/td&gt;
&lt;td&gt;内积只依赖相对位置&lt;/td&gt;
&lt;td&gt;$\langle f_q(m), f_k(n) \rangle = g(m-n)$&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;② 尝试复数&lt;/td&gt;
&lt;td&gt;2D 向量用复数表示&lt;/td&gt;
&lt;td&gt;$\tilde{q} = q_1 + i q_2$&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;③ 假设旋转&lt;/td&gt;
&lt;td&gt;乘以单位复数编码位置&lt;/td&gt;
&lt;td&gt;$\tilde{q}_m = \tilde{q} \cdot e^{im\theta}$&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;④ 验证目标&lt;/td&gt;
&lt;td&gt;内积只含相对项&lt;/td&gt;
&lt;td&gt;$\text{Re}[\tilde{q}\bar{\tilde{k}} e^{i(m-n)\theta}]$&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;⑤ 回到实数&lt;/td&gt;
&lt;td&gt;发现旋转矩阵&lt;/td&gt;
&lt;td&gt;$R(m\theta) = \begin{pmatrix} \cos &amp;amp; -\sin \ \sin &amp;amp; \cos \end{pmatrix}$&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;⑥ 多频率&lt;/td&gt;
&lt;td&gt;不同维度不同速度&lt;/td&gt;
&lt;td&gt;$\theta_i = 10000^{-2i/d}$&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;⑦ 高维扩展&lt;/td&gt;
&lt;td&gt;块对角旋转矩阵&lt;/td&gt;
&lt;td&gt;每对独立旋转&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;RoPE 现在已经是 LLaMA、Mistral、Gemma 等主流大模型的标配位置编码. 理解它的推导过程, 对理解后面长上下文扩展(PI、NTK、YaRN)也大有帮助.&lt;/p&gt;
&lt;hr /&gt;
&lt;h3&gt;参考资料&lt;/h3&gt;
&lt;ol&gt;
&lt;li&gt;苏剑林. (2021). &quot;Transformer升级之路：2、博采众长的旋转式位置编码&quot;. &lt;a href=&quot;https://kexue.fm/archives/8265&quot;&gt;科学空间&lt;/a&gt; — &lt;strong&gt;本文推导思路完全参考该文&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;Su et al., RoFormer: Enhanced Transformer with Rotary Position Embedding. Neurocomputing 2022. &lt;a href=&quot;https://arxiv.org/abs/2104.09864&quot;&gt;arXiv:2104.09864&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;Vaswani et al., Attention Is All You Need. NeurIPS 2017. &lt;a href=&quot;https://arxiv.org/abs/1706.03762&quot;&gt;arXiv:1706.03762&lt;/a&gt;&lt;/li&gt;
&lt;/ol&gt;
</content:encoded></item><item><title>大模型 Decoding 策略 — 从&quot;我该怎么选&quot;出发，一步步推出一套选词方案</title><link>https://xuchenhui.cc/posts/2026-05-16-llm-decoding-strategies/</link><guid isPermaLink="true">https://xuchenhui.cc/posts/2026-05-16-llm-decoding-strategies/</guid><description>不直接罗列方法，而是从&quot;模型给出概率后，怎么从中选一个词&quot;这个最朴素的问题出发，一步步推出 temperature、top-k、top-p、beam search 的原理。</description><pubDate>Sat, 16 May 2026 00:00:00 GMT</pubDate><content:encoded>&lt;h2&gt;0. 前言&lt;/h2&gt;
&lt;p&gt;语言模型在生成文本时, 每一步都会输出一个&lt;strong&gt;概率分布&lt;/strong&gt;——一个长度为词表大小 $V$ 的向量, 每个元素表示下一个词是第 $i$ 个 token 的概率:&lt;/p&gt;
&lt;p&gt;$$
P(w_t | w_{&amp;lt;t}) = \text{softmax}(\mathbf{z}&lt;em&gt;t) = \frac{e^{z&lt;/em&gt;{t,i}}}{\sum_{j=1}^{V} e^{z_{t,j}}}
$$&lt;/p&gt;
&lt;p&gt;但问题是: &lt;strong&gt;有了这个分布之后, 该怎么选一个词出来?&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;这个看似简单的问题, 答案并不唯一. 不同的&quot;选法&quot;会产生完全不同的效果——有的输出死板重复, 有的天马行空, 有的稳定可靠. 整个 Decoding Strategy 领域, 本质就是在问一个问题:&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;给定一个概率分布, 怎么选词才能让整体输出既&lt;strong&gt;准确&lt;/strong&gt;又&lt;strong&gt;多样&lt;/strong&gt;?&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;这两个目标其实是矛盾的: 越准确(选概率最高的), 就越缺乏多样性; 越多样(从分布里随机抽), 就越容易跑偏. 不同的策略就是在权衡这两个目标.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;1. 最自然的想法: 选概率最大的&lt;/h2&gt;
&lt;p&gt;让我们从一个最朴素的想法开始: 每次选概率最大的那个词.&lt;/p&gt;
&lt;p&gt;$$
\hat{w}&lt;em&gt;t = \arg\max_w P(w | w&lt;/em&gt;{&amp;lt;t})
$$&lt;/p&gt;
&lt;p&gt;这叫做 &lt;strong&gt;Greedy Search&lt;/strong&gt; — 贪心搜索.&lt;/p&gt;
&lt;p&gt;这是最理性的选择吗? 单步来看, 是的——$P(w_t|w_{&amp;lt;t})$ 最大意味着在已知上下文下, 这个词是最&quot;合理&quot;的. 但问题出在长远.&lt;/p&gt;
&lt;p&gt;来看一个简单的例子. 假设模型要补全 &quot;我喜欢的食物是____&quot;, 各方案的概率:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&quot;吃&quot; → 0.35 (这个词在这里确实很合理)&lt;/li&gt;
&lt;li&gt;&quot;披萨&quot; → 0.28&lt;/li&gt;
&lt;li&gt;&quot;尝&quot; → 0.12&lt;/li&gt;
&lt;li&gt;...&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;Greedy 会选 &quot;吃&quot;. 但接下来呢? &quot;我喜欢的食物是吃...&quot; — 这句话感觉不对了. 而如果第一步选了 &quot;披萨&quot;, 后面可以接 &quot;很好吃&quot;, &quot;我最常点的&quot; 等等, 整句话的质量更高.&lt;/p&gt;
&lt;p&gt;所以 Greedy 的问题本质上是&lt;strong&gt;短视&lt;/strong&gt;: 它只考虑了局部最优, 没有考虑今天的决策对未来的影响.&lt;/p&gt;
&lt;p&gt;更严重的是, Greedy 很容易陷入&lt;strong&gt;重复循环&lt;/strong&gt;: &quot;很好 → 很好 → 很好...&quot; 或者 &quot;这是一本有趣的书 → 这是一本有趣的书 → ...&quot;. 因为一旦进入某个&quot;安全&quot;的局部模式, 每一步概率最高的词就是继续重复自己.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;2. 另一个极端的想法: 按概率随机抽&lt;/h2&gt;
&lt;p&gt;既然&quot;每次都选最合理的&quot;会导致死板重复, 那反过来: &lt;strong&gt;按概率来随机抽&lt;/strong&gt;.&lt;/p&gt;
&lt;p&gt;概率 0.28 的词有 28% 的机会被选到, 概率 0.01 的词有 1% 的机会. 这就是&lt;strong&gt;纯采样 (Pure Sampling)&lt;/strong&gt;.&lt;/p&gt;
&lt;p&gt;$$
\hat{w}&lt;em&gt;t \sim P(w | w&lt;/em&gt;{&amp;lt;t})
$$&lt;/p&gt;
&lt;p&gt;采样的好处是显而易见的: 每次生成的文本都不一样, 充满了多样性. 但它的问题是: &lt;strong&gt;低概率的词有时候真的太离谱了&lt;/strong&gt;.&lt;/p&gt;
&lt;p&gt;语言模型的概率分布往往有很长的尾巴——成千上万个概率接近于 0 的词, 累加起来可能占 10-20% 的总概率. 纯采样有一定的概率选中这些词, 生成&quot;我喜欢的食物是斑马&quot;这种荒谬的结果.&lt;/p&gt;
&lt;p&gt;所以我们需要在&quot;确定性&quot;和&quot;多样性&quot;之间找一个折中——能够逐步调节的折中.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;3. Temperature: 一个可以&quot;拧&quot;的旋钮&lt;/h2&gt;
&lt;p&gt;在 softmax 之前加一个缩放参数 $T$, 让概率分布可以动态调整:&lt;/p&gt;
&lt;p&gt;$$
P_T(w_i) = \frac{e^{z_i / T}}{\sum_{j} e^{z_j / T}}
$$&lt;/p&gt;
&lt;p&gt;这里 $\mathbf{z} = (z_1, ..., z_V)$ 是模型输出的 logits (未归一化的分数), 不是最终概率. 这个 $T$ 就是 &lt;strong&gt;Temperature&lt;/strong&gt;.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;当 $T$ 变化时, 分布会发生什么?&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;我们来推导一下. 对于任意两个词 $i$ 和 $j$, 它们概率的比值是:&lt;/p&gt;
&lt;p&gt;$$
\frac{P_T(w_i)}{P_T(w_j)} = \frac{e^{z_i / T}}{e^{z_j / T}} = e^{(z_i - z_j) / T}
$$&lt;/p&gt;
&lt;p&gt;当 $T \to 0^+$:&lt;/p&gt;
&lt;p&gt;$$
\lim_{T \to 0^+} e^{(z_i - z_j) / T} = \begin{cases}
\infty &amp;amp; \text{if } z_i &amp;gt; z_j \
0 &amp;amp; \text{if } z_i &amp;lt; z_j \
1 &amp;amp; \text{if } z_i = z_j
\end{cases}
$$&lt;/p&gt;
&lt;p&gt;这意味着: 概率最高的词会被无限放大到 1, 其他词的概率都被压缩到 0——退化为 Greedy. 数学上, 这就是 &lt;strong&gt;argmax 的 soft 近似&lt;/strong&gt;.&lt;/p&gt;
&lt;p&gt;当 $T = 1$: 保持原始分布不变.&lt;/p&gt;
&lt;p&gt;当 $T \to \infty$:&lt;/p&gt;
&lt;p&gt;$$
\lim_{T \to \infty} e^{(z_i - z_j) / T} = 1
$$&lt;/p&gt;
&lt;p&gt;所有词的概率趋于相等, 变成均匀分布——完全随机选择.&lt;/p&gt;
&lt;p&gt;所以 $T$ 就是一个从 &lt;strong&gt;完全确定&lt;/strong&gt; ($T \to 0$) 到 &lt;strong&gt;完全随机&lt;/strong&gt; ($T \to \infty$) 的连续旋钮.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;实际操作中:&lt;/strong&gt;&lt;/p&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;$T$&lt;/th&gt;
&lt;th&gt;效果&lt;/th&gt;
&lt;th&gt;解释&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;$T \to 0$&lt;/td&gt;
&lt;td&gt;退化为 Greedy&lt;/td&gt;
&lt;td&gt;概率最高的词被无限放大&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;$0.5$&lt;/td&gt;
&lt;td&gt;非常保守&lt;/td&gt;
&lt;td&gt;让高概率词更突出, 减少多样性&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;$0.7$&lt;/td&gt;
&lt;td&gt;略保守&lt;/td&gt;
&lt;td&gt;对话常用, 保留一定多样性&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;$1.0$&lt;/td&gt;
&lt;td&gt;原始分布&lt;/td&gt;
&lt;td&gt;不做任何调整&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;$1.5$&lt;/td&gt;
&lt;td&gt;有创意&lt;/td&gt;
&lt;td&gt;低概率词有更多机会&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;$T \gg 1$&lt;/td&gt;
&lt;td&gt;接近均匀分布&lt;/td&gt;
&lt;td&gt;几乎完全随机&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;但注意, Temperature 改变的是&lt;strong&gt;分布的形状&lt;/strong&gt;, 它没有解决&quot;低概率尾巴&quot;的问题——即使 $T=0.7$, 那些概率极低的词仍然有可能被采样到. 这就是接下来 Top-K 和 Top-P 要解决的问题.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;4. Top-K: 裁剪&quot;尾巴&quot;&lt;/h2&gt;
&lt;p&gt;Top-K 的思路直截了当: 只保留概率最高的 $K$ 个词, 其他的概率置零, 然后重新归一化.&lt;/p&gt;
&lt;p&gt;$$
\text{candidates} = {w_{(1)}, w_{(2)}, ..., w_{(K)}}
$$&lt;/p&gt;
&lt;p&gt;其中 $w_{(i)}$ 是按概率降序排列后的第 $i$ 个词.&lt;/p&gt;
&lt;p&gt;重新归一化:&lt;/p&gt;
&lt;p&gt;$$
P&apos;(w | w_{&amp;lt;t}) = \frac{P(w | w_{&amp;lt;t})}{\sum_{w&apos; \in \text{candidates}} P(w&apos; | w_{&amp;lt;t})}
$$&lt;/p&gt;
&lt;p&gt;这相当于把候选集之外的所有词的概率重新分配给候选集内的词. 过程等价于: 先截断, 再缩放.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;K 怎么选?&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;$K=1$: 退化为 Greedy&lt;/li&gt;
&lt;li&gt;$K=50$: GPT-2 默认值&lt;/li&gt;
&lt;li&gt;$K=200+$: 接近完整采样&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;Top-K 的问题: 固定 $K$ 不够灵活.&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;考虑两种极端的概率分布:&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;情况 A (尖峰分布)&lt;/strong&gt;: 一个词概率 0.85, 另一个 0.10, 其余 5% 概率均匀分布在 1000 个词上. Top-K=50 会把一堆概率 0.00005 的词捞进来——基本上等于在随机词里选, 容易出问题.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;情况 B (平坦分布)&lt;/strong&gt;: 前 10 个词的概率分别是 0.15, 0.14, 0.13, ..., 到第 50 个词概率还有 0.02. Top-K=50 的候选集很合理. 但如果 Top-K=5, 就切掉了一大半合理的选项.&lt;/p&gt;
&lt;p&gt;这就是 Top-K 的&quot;硬伤&quot;: 用一个固定值去应对动态变化的分布, 难免有时过松、有时过紧.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;5. Top-P (Nucleus): 动态裁剪&lt;/h2&gt;
&lt;p&gt;Top-P 的改进: 不限定候选数量, 而限&lt;strong&gt;累积概率&lt;/strong&gt;.&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;把词按概率从高到低排序&lt;/li&gt;
&lt;li&gt;累加概率, 直到累加和 $\ge P$&lt;/li&gt;
&lt;li&gt;只保留这部分词&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;$$
\text{candidates} = {w_{(1)}, ..., w_{(k)}} \ \text{s.t.} \ \sum_{i=1}^{k} P(w_{(i)} | w_{&amp;lt;t}) \ge P
$$&lt;/p&gt;
&lt;p&gt;其中 $k$ 是满足条件的最小值.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;Top-P 的好处是自适应:&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;分布集中时&lt;/strong&gt;: 候选数量少 (前几个词就占了 $P$ 的概率), 只保留高质量的候选&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;分布分散时&lt;/strong&gt;: 候选数量多, 保留更多多样性&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;典型的 $P=0.9$ 或 $0.95$. $P=1.0$ 等价于完整采样.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;Top-K 和 Top-P 的关系:&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;可以用一个类比来理解: Top-K 是&quot;一刀切&quot;, 身高低于 1.8m 的人不准上场; Top-P 是&quot;按比例挑&quot;, 先选最高的, 直到总人数达到要求. 后者更灵活.&lt;/p&gt;
&lt;p&gt;实际应用中, &lt;strong&gt;两者经常组合使用&lt;/strong&gt;: 先用 Top-K 砍掉极低概率的尾巴 (比如 $K=50$), 再用 Top-P 做动态调整 (比如 $P=0.9$). 这样做的好处是: 即使分布极端, Top-K 也能确保不会选到太离谱的词.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;6. Temperature + Top-K + Top-P: 统一框架&lt;/h2&gt;
&lt;p&gt;把三个方法组合起来, 形成一个完整的解码流程:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;输入: logits z (长度为 V 的向量)

1. Temperature 缩放: z&apos; = z / T
2. Softmax: p = softmax(z&apos;)
3. Top-K 截断: 只保留概率最高的 K 个词, 其余置 0
4. Top-P 截断: 从最高概率开始累加, 直到累积概率 ≥ P
5. 重新归一化: 对保留下来的词重算概率
6. 从最终分布中采样一个词

输出: 下一个 token
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;每一步都有明确的数学意义:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;Step 1: 控制&quot;确定性 vs 多样性&quot;的程度&lt;/li&gt;
&lt;li&gt;Step 3: 防止低概率词的&quot;尾部风险&quot;&lt;/li&gt;
&lt;li&gt;Step 4: 自适应调整候选集大小&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;HuggingFace 中对应的参数:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(&quot;model-name&quot;)

output = model.generate(
    input_ids,
    do_sample=True,          # 启用采样模式
    temperature=0.7,         # Temperature 缩放
    top_k=50,                # Top-K 截断
    top_p=0.9,               # Top-P 截断
    max_new_tokens=512
)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;注意: &lt;code&gt;do_sample=False&lt;/code&gt; (默认) 时直接走 Greedy, 上面的 Temperature/Top-K/Top-P 参数不会生效.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;7. Beam Search: 另一种思路&lt;/h2&gt;
&lt;p&gt;上面讲的所有方法, 都是&lt;strong&gt;单步决策&lt;/strong&gt;: 每步选一个词, 选了就定了. Beam Search 走的是另一条路: &lt;strong&gt;同时维护多条候选路径&lt;/strong&gt;.&lt;/p&gt;
&lt;h3&gt;7.1 算法流程&lt;/h3&gt;
&lt;p&gt;假设 Beam width $B = 2$:&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;第一步&lt;/strong&gt;: 模型输出所有词的概率, 保留概率最高的 $B=2$ 个词作为两条路径的起点&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;第二步&lt;/strong&gt;: 对每条路径分别计算下一步的概率, 得到 $B \times V = 2V$ 个候选项&lt;/li&gt;
&lt;li&gt;从 $2V$ 个候选中保留全局概率最高的 $B=2$ 条完整路径&lt;/li&gt;
&lt;li&gt;重复直到满足结束条件&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;路径的&quot;分数&quot;是&lt;strong&gt;对数概率之和&lt;/strong&gt;:&lt;/p&gt;
&lt;p&gt;$$
\text{score}(\text{path}) = \sum_{t=1}^{T} \log P(w_t | w_{&amp;lt;t})
$$&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;为什么要用对数概率, 而不是直接乘概率?&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;假设一个路径长 $T=100$, 每个词的概率平均 0.1, 路径的总概率是:&lt;/p&gt;
&lt;p&gt;$$
\prod_{t=1}^{100} 0.1 = 10^{-100}
$$&lt;/p&gt;
&lt;p&gt;这是一个小到浮点数都表示不了的数字——&lt;strong&gt;数值下溢&lt;/strong&gt; (underflow).&lt;/p&gt;
&lt;p&gt;用对数就解决了:&lt;/p&gt;
&lt;p&gt;$$
\sum_{t=1}^{100} \log(0.1) = -230
$$&lt;/p&gt;
&lt;p&gt;或者说 $\log P$ 和 $P$ 是单调关系 ($P_1 &amp;lt; P_2 \iff \log P_1 &amp;lt; \log P_2$), 所以最大化对数概率等价于最大化原始概率, 同时数值稳定.&lt;/p&gt;
&lt;h3&gt;7.2 Beam width 的选取&lt;/h3&gt;
&lt;ul&gt;
&lt;li&gt;$B=1$: 退化为 Greedy&lt;/li&gt;
&lt;li&gt;$B=4$: 常用值, 质量明显提升&lt;/li&gt;
&lt;li&gt;$B=10+$: 边际收益递减, 计算量线性增长 ($B$ 倍)&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;为什么边际收益递减? 因为前几个候选项之间的分数差距通常很大 (好的路径分数远高于差的路径), 多保留几个基本不会被选到. 只有当前几条路径分数接近时, 增大 $B$ 才有意义.&lt;/p&gt;
&lt;h3&gt;7.3 什么时候用 Beam Search?&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;适合&lt;/strong&gt;: 翻译、摘要、代码生成等追求&lt;strong&gt;确定性最优解&lt;/strong&gt;的任务. 注意: Beam Search 是确定性算法, 同样的输入总是得到同样的输出.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;不适合&lt;/strong&gt;: 对话、故事创作等需要&lt;strong&gt;多样性&lt;/strong&gt;的场景. Beam Search 倾向于选&quot;最安全的&quot;路径, 生成的文本往往比较平淡、模板化.&lt;/p&gt;
&lt;p&gt;原因: Beam Search 每一步都在最大化全局概率, 而&quot;最安全&quot;的路径往往是&quot;概率最高的常见词&quot;的排列组合, 缺乏新意. 在这种任务中, 带采样的 Top-K + Top-P 效果更好.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;8. 从 Softmax 再看一遍 Temperature&lt;/h2&gt;
&lt;p&gt;前面我们直接给了 Temperature 的公式. 现在从 softmax 的梯度视角再看一次, 可能会对 $T$ 的作用有更深的理解.&lt;/p&gt;
&lt;p&gt;Softmax 函数:&lt;/p&gt;
&lt;p&gt;$$
p_i = \frac{e^{z_i}}{\sum_j e^{z_j}}
$$&lt;/p&gt;
&lt;p&gt;加入 Temperature 后:&lt;/p&gt;
&lt;p&gt;$$
p_i(T) = \frac{e^{z_i / T}}{\sum_j e^{z_j / T}}
$$&lt;/p&gt;
&lt;p&gt;求 $p_i$ 对 $z_i$ 的梯度:&lt;/p&gt;
&lt;p&gt;$$
\frac{\partial p_i}{\partial z_i} = \frac{1}{T} \cdot p_i(1 - p_i)
$$&lt;/p&gt;
&lt;p&gt;当 $T$ 很小时, 梯度 $\frac{1}{T}$ 很大 —— 概率分布对 logits 的微小变化极其敏感, 分布趋于&quot;one-hot&quot; (一个 1, 其余 0). 当 $T$ 很大时, 梯度很小 —— 分布对 logits 不敏感, 趋于均匀.&lt;/p&gt;
&lt;p&gt;所以 Temperature 本质上是在控制&lt;strong&gt;概率分布对模型输出的敏感度&lt;/strong&gt;. 低 $T$ → 模型&amp;lt;u&amp;gt;高度自信&amp;lt;/u&amp;gt; → 输出确定; 高 $T$ → 模型&amp;lt;u&amp;gt;犹豫不决&amp;lt;/u&amp;gt; → 输出随机.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;9. 实践指南&lt;/h2&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;场景&lt;/th&gt;
&lt;th&gt;策略&lt;/th&gt;
&lt;th&gt;参数&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;机器翻译&lt;/td&gt;
&lt;td&gt;Beam Search&lt;/td&gt;
&lt;td&gt;$B=4$&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;代码生成&lt;/td&gt;
&lt;td&gt;Low T + Top-P&lt;/td&gt;
&lt;td&gt;$T=0.1, P=0.9$&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;对话&lt;/td&gt;
&lt;td&gt;Medium T + Top-K + Top-P&lt;/td&gt;
&lt;td&gt;$T=0.7, K=50, P=0.9$&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;故事创作&lt;/td&gt;
&lt;td&gt;High T + Top-P&lt;/td&gt;
&lt;td&gt;$T=1.2, P=0.95$&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;事实验证/科学&lt;/td&gt;
&lt;td&gt;Low T + Top-P&lt;/td&gt;
&lt;td&gt;$T=0.3, P=0.85$&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;理解这些策略的关键, 不是记住参数值, 而是理解每一条: &lt;strong&gt;我在确定性和多样性之间, 选择了哪一边?&lt;/strong&gt;&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;10. 总结&lt;/h2&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;策略&lt;/th&gt;
&lt;th&gt;核心思想&lt;/th&gt;
&lt;th&gt;确定 vs 多样&lt;/th&gt;
&lt;th&gt;主要问题&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;Greedy&lt;/td&gt;
&lt;td&gt;选概率最高的&lt;/td&gt;
&lt;td&gt;确定&lt;/td&gt;
&lt;td&gt;短视、重复&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Temperature&lt;/td&gt;
&lt;td&gt;缩放 logits 调节锐度&lt;/td&gt;
&lt;td&gt;可调节&lt;/td&gt;
&lt;td&gt;不解决尾部风险&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Top-K&lt;/td&gt;
&lt;td&gt;保留前 $K$ 个&lt;/td&gt;
&lt;td&gt;偏向确定&lt;/td&gt;
&lt;td&gt;固定值不够灵活&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Top-P&lt;/td&gt;
&lt;td&gt;保留累积概率 $P$&lt;/td&gt;
&lt;td&gt;自适应&lt;/td&gt;
&lt;td&gt;候选数量波动大&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Beam Search&lt;/td&gt;
&lt;td&gt;维护 $B$ 条路径&lt;/td&gt;
&lt;td&gt;确定&lt;/td&gt;
&lt;td&gt;计算量大, 平淡&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;&lt;strong&gt;终极建议&lt;/strong&gt;: 大部分场景下, 先用 Temperature 调到一个你想要的&quot;创意程度&quot;, 然后用 Top-K 砍掉最离谱的尾巴, 再用 Top-P 做自适应精调. 这三个组合起来, 可以覆盖大部分需求.&lt;/p&gt;
&lt;hr /&gt;
&lt;h3&gt;参考资料&lt;/h3&gt;
&lt;ol&gt;
&lt;li&gt;Holtzman et al., The Curious Case of Neural Text Degeneration. ICLR 2020. &lt;a href=&quot;https://arxiv.org/abs/1904.09751&quot;&gt;arXiv:1904.09751&lt;/a&gt; — &lt;strong&gt;提出 Top-P (Nucleus) Sampling&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;Fan et al., Hierarchical Neural Story Generation. ACL 2018. &lt;a href=&quot;https://arxiv.org/abs/1805.04833&quot;&gt;arXiv:1805.04833&lt;/a&gt; — &lt;strong&gt;提出 Top-K Sampling&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;Graves, Sequence Transduction with Recurrent Neural Networks. 2012. &lt;a href=&quot;https://arxiv.org/abs/1211.3711&quot;&gt;arXiv:1211.3711&lt;/a&gt; — &lt;strong&gt;Beam Search 的经典应用&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;Hinton, G. (2015). Distilling the Knowledge in a Neural Network — &lt;strong&gt;Temperature 在蒸馏中的应用, 和本文用到的是同一个概念&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;Vaswani et al., Attention Is All You Need. NeurIPS 2017. &lt;a href=&quot;https://arxiv.org/abs/1706.03762&quot;&gt;arXiv:1706.03762&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;HuggingFace Docs: &lt;a href=&quot;https://huggingface.co/docs/transformers/en/generation_strategies&quot;&gt;Generation Strategies&lt;/a&gt;&lt;/li&gt;
&lt;/ol&gt;
</content:encoded></item><item><title>大模型推理显存拆解 — 一步步算清你的显存去哪了</title><link>https://xuchenhui.cc/posts/2026-05-16-llm-memory-breakdown/</link><guid isPermaLink="true">https://xuchenhui.cc/posts/2026-05-16-llm-memory-breakdown/</guid><description>以 Llama-3-8B 为例，从参数怎么算、KV Cache 公式怎么来的、激活值有多大，到每项怎么优化，一步步推导而不是直接扔给你一个数字。</description><pubDate>Sat, 16 May 2026 00:00:00 GMT</pubDate><content:encoded>&lt;h2&gt;0. 前言&lt;/h2&gt;
&lt;p&gt;你有一张 RTX 4090, 24GB 显存. 你想跑 Llama-3-8B.&lt;/p&gt;
&lt;p&gt;问题来了: 8B 参数的模型, 在 BF16 精度下, 光是加载权重就要 16GB. 你还有 8GB 的余量. &quot;够了! 跑吧!&quot;&lt;/p&gt;
&lt;p&gt;然后跑起来发现 OOM (Out Of Memory).&lt;/p&gt;
&lt;p&gt;为什么? 因为权重只是冰山一角. 推理过程中还有&lt;strong&gt;KV Cache&lt;/strong&gt;和&lt;strong&gt;激活值&lt;/strong&gt;在悄悄吃显存. 这篇文章我们就来一笔笔算清楚, 每个公式都从原理出发推导出来, 而不是直接扔给你一个数字.&lt;/p&gt;
&lt;p&gt;最终你会发现: 显存的去向其实完全可以精确计算, 而且每一笔都有优化的办法.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;1. 模型参数: 第一笔账&lt;/h2&gt;
&lt;p&gt;这比账最好算.&lt;/p&gt;
&lt;h3&gt;1.1 基本公式&lt;/h3&gt;
&lt;p&gt;模型参数占用的显存 = 参数数量 × 每个参数的字节数:&lt;/p&gt;
&lt;p&gt;$$
M_{\text{params}} = N_{\text{params}} \times b_{\text{param}}
$$&lt;/p&gt;
&lt;p&gt;其中 $b_{\text{param}}$ 取决于精度:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;FP32: 4 字节&lt;/li&gt;
&lt;li&gt;BF16 / FP16: 2 字节&lt;/li&gt;
&lt;li&gt;INT8: 1 字节&lt;/li&gt;
&lt;li&gt;INT4: 0.5 字节&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;所以对于 8B 参数的模型:&lt;/p&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;精度&lt;/th&gt;
&lt;th&gt;$M_{\text{params}}$&lt;/th&gt;
&lt;th&gt;计算过程&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;BF16&lt;/td&gt;
&lt;td&gt;16 GB&lt;/td&gt;
&lt;td&gt;$8 \times 10^9 \times 2 = 16 \times 10^9$ 字节&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;INT8&lt;/td&gt;
&lt;td&gt;8 GB&lt;/td&gt;
&lt;td&gt;$8 \times 10^9 \times 1 = 8 \times 10^9$ 字节&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;INT4&lt;/td&gt;
&lt;td&gt;4 GB&lt;/td&gt;
&lt;td&gt;$8 \times 10^9 \times 0.5 = 4 \times 10^9$ 字节&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;h3&gt;1.2 更精确地算: 参数从哪来?&lt;/h3&gt;
&lt;p&gt;&quot;8B 参数&quot;这个数字到底是怎么组成的? 我们以 Llama-3-8B 为例拆一下.&lt;/p&gt;
&lt;p&gt;Llama-3-8B 的结构参数:&lt;/p&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;参数&lt;/th&gt;
&lt;th&gt;符号&lt;/th&gt;
&lt;th&gt;值&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;层数&lt;/td&gt;
&lt;td&gt;$L$&lt;/td&gt;
&lt;td&gt;32&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;隐藏维度&lt;/td&gt;
&lt;td&gt;$d_{\text{model}}$&lt;/td&gt;
&lt;td&gt;4096&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;FFN 中间维度&lt;/td&gt;
&lt;td&gt;$d_{\text{ff}}$&lt;/td&gt;
&lt;td&gt;14336&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;注意力头数&lt;/td&gt;
&lt;td&gt;$n_{\text{heads}}$&lt;/td&gt;
&lt;td&gt;32&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;KV 头数&lt;/td&gt;
&lt;td&gt;$n_{\text{kv}}$&lt;/td&gt;
&lt;td&gt;8 (GQA)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;词表大小&lt;/td&gt;
&lt;td&gt;$V$&lt;/td&gt;
&lt;td&gt;128000&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;逐层细分:&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;1. Embedding 层&lt;/strong&gt; (词嵌入):&lt;/p&gt;
&lt;p&gt;$$
M_{\text{embed}} = V \times d_{\text{model}} = 128000 \times 4096 \approx 524\text{M}
$$&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;2. 每层 Transformer&lt;/strong&gt; (共 32 层):&lt;/p&gt;
&lt;p&gt;注意力部分:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;Q 投影: $d_{\text{model}} \times (n_{\text{heads}} \times d_{\text{head}}) = 4096 \times 4096 = 16.8\text{M}$&lt;/li&gt;
&lt;li&gt;K 投影: $d_{\text{model}} \times (n_{\text{kv}} \times d_{\text{head}}) = 4096 \times 1024 = 4.2\text{M}$&lt;/li&gt;
&lt;li&gt;V 投影: 同 K, $4.2\text{M}$&lt;/li&gt;
&lt;li&gt;O 投影: $(n_{\text{heads}} \times d_{\text{head}}) \times d_{\text{model}} = 4096 \times 4096 = 16.8\text{M}$&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;每层注意力合计: $16.8 + 4.2 + 4.2 + 16.8 = 42\text{M}$&lt;/p&gt;
&lt;p&gt;FFN 部分 (SwiGLU, 3个矩阵):&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;gate_proj: $d_{\text{model}} \times d_{\text{ff}} = 4096 \times 14336 = 58.7\text{M}$&lt;/li&gt;
&lt;li&gt;up_proj: 同上, $58.7\text{M}$&lt;/li&gt;
&lt;li&gt;down_proj: $d_{\text{ff}} \times d_{\text{model}} = 14336 \times 4096 = 58.7\text{M}$&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;每层 FFN 合计: $58.7 \times 3 = 176.1\text{M}$&lt;/p&gt;
&lt;p&gt;每层 Transformer 合计: $42 + 176.1 = 218.1\text{M}$
32 层: $218.1 \times 32 \approx 7.0\text{B}$&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;3. RMS Norm&lt;/strong&gt; (每层有 2 个, 加上最后的):&lt;/p&gt;
&lt;p&gt;每个 RMS Norm 只有 $d_{\text{model}}$ 个可训练参数 (= 4096), 可以忽略.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;4. LM Head&lt;/strong&gt; (输出层):&lt;/p&gt;
&lt;p&gt;$d_{\text{model}} \times V = 4096 \times 128000 = 524\text{M}$&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;总计&lt;/strong&gt;: $524\text{M} (\text{embed}) + 7.0\text{B} (\text{32层}) + 524\text{M} (\text{head}) \approx 8.0\text{B}$ ✓&lt;/p&gt;
&lt;p&gt;这个计算验证了: 8B 参数不是凭空说的, 每一层、每个矩阵的贡献都可以精确计算. 当有人告诉你&quot;这是一个 8B 模型&quot;时, 你可以快速心算: 大概要占 16GB (BF16) / 8GB (INT8) / 4GB (INT4).&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;2. KV Cache: 被严重低估的显存杀手&lt;/h2&gt;
&lt;blockquote&gt;
&lt;p&gt;KV Cache 的原理我在之前的博客 &lt;a href=&quot;/posts/2025-02-06-A-Series-on-LLM-Inference-II/&quot;&gt;A Series on LLMs (II)&lt;/a&gt; 中已经详细介绍过了, 这里简单回顾一下核心思路, 重点放在&quot;占多少显存&quot;的计算上.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h3&gt;2.1 从注意力公式出发&lt;/h3&gt;
&lt;p&gt;先回顾一下 Transformer Decoder 的注意力计算. 在第 $t$ 步, 模型需要计算:&lt;/p&gt;
&lt;p&gt;$$
\text{Attention}(Q_t, K_{\le t}, V_{\le t}) = \text{softmax}\left(\frac{Q_t K_{\le t}^T}{\sqrt{d_k}}\right) V_{\le t}
$$&lt;/p&gt;
&lt;p&gt;这里 $Q_t$ 是&lt;strong&gt;当前 token&lt;/strong&gt;的 query (大小 $1 \times d_k$), 而 $K_{\le t}$ 和 $V_{\le t}$ 是&lt;strong&gt;所有历史位置&lt;/strong&gt;的 key 和 value (大小 $t \times d_k$).&lt;/p&gt;
&lt;p&gt;你可以选择:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;方案 A&lt;/strong&gt;: 每次重新算 $K_{\le t}$ 和 $V_{\le t}$ — 第 $t$ 步的计算量是 $O(t \times d_k)$, 累计 $O(T^2 \times d_k)$, 序列长了完全不可接受.&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;方案 B&lt;/strong&gt;: 把之前每一步算好的 $K_i$, $V_i$ 存起来, 每次只要算当前 token 的 $K_t$, $V_t$, 然后拼到缓存里.&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;方案 B 就是 &lt;strong&gt;KV Cache&lt;/strong&gt;.&lt;/p&gt;
&lt;h3&gt;2.2 KV Cache 的精确公式&lt;/h3&gt;
&lt;p&gt;每层需要缓存 K 和 V 两份. 每个 token 每层每头需要的空间是 $d_{\text{head}} \times b_{\text{param}}$.&lt;/p&gt;
&lt;p&gt;所以 KV Cache 的总大小:&lt;/p&gt;
&lt;p&gt;$$
M_{\text{kv}} = 2 \times L \times n_{\text{kv}} \times d_{\text{head}} \times b_{\text{param}} \times T
$$&lt;/p&gt;
&lt;p&gt;其中 $T$ 是序列长度.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;为什么要乘以 $n_{\text{kv}}$ 而不是 $n_{\text{heads}}$?&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;这里取决于注意力机制:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;MHA&lt;/strong&gt; (Multi-Head Attention): 每个 query head 有独立的 K, V head → $n_{\text{kv}} = n_{\text{heads}}$&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;GQA&lt;/strong&gt; (Grouped-Query Attention): 多个 query head 共享一组 K, V → $n_{\text{kv}} &amp;lt; n_{\text{heads}}$&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;MQA&lt;/strong&gt; (Multi-Query Attention): 所有 query head 共享同一组 K, V → $n_{\text{kv}} = 1$&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;Llama-3-8B 用了 GQA, $n_{\text{kv}} = 8$. 如果它用 MHA ($n_{\text{kv}} = 32$), KV Cache 会大 4 倍!&lt;/p&gt;
&lt;h3&gt;2.3 具体数字&lt;/h3&gt;
&lt;p&gt;以 Llama-3-8B 为例 ($L=32$, $n_{\text{kv}}=8$, $d_{\text{head}}=128$, BF16, $b_{\text{param}}=2$):&lt;/p&gt;
&lt;p&gt;$$
M_{\text{kv}} = 2 \times 32 \times 8 \times 128 \times 2 \times T
$$&lt;/p&gt;
&lt;p&gt;简化: $M_{\text{kv}} = 131072 \times T \ \text{字节} = 128 \times T \ \text{KB}$&lt;/p&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;$T$&lt;/th&gt;
&lt;th&gt;$M_{\text{kv}}$&lt;/th&gt;
&lt;th&gt;占权重的比例&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;512&lt;/td&gt;
&lt;td&gt;64 MB&lt;/td&gt;
&lt;td&gt;0.4%&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;2,048&lt;/td&gt;
&lt;td&gt;256 MB&lt;/td&gt;
&lt;td&gt;1.6%&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;4,096&lt;/td&gt;
&lt;td&gt;512 MB&lt;/td&gt;
&lt;td&gt;3.1%&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;8,192&lt;/td&gt;
&lt;td&gt;1 GB&lt;/td&gt;
&lt;td&gt;6.3%&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;32,768&lt;/td&gt;
&lt;td&gt;4 GB&lt;/td&gt;
&lt;td&gt;25%&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;128,000&lt;/td&gt;
&lt;td&gt;16 GB&lt;/td&gt;
&lt;td&gt;100%&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;可以看到: &lt;strong&gt;当序列长度达到 128K 时, KV Cache 的显存开销已经和模型权重本身一样大!&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;这就是为什么长上下文推理如此吃显存. 跑 128K 上下文意味着你需要&lt;strong&gt;双倍&lt;/strong&gt;的显存——一份装权重, 一份装 KV Cache.&lt;/p&gt;
&lt;h3&gt;2.4 与 batch size 的关系&lt;/h3&gt;
&lt;p&gt;上面的计算假设 batch size = 1. 如果同时处理 $B$ 个请求:&lt;/p&gt;
&lt;p&gt;$$
M_{\text{kv}}(\text{total}) = M_{\text{kv}}(T) \times B
$$&lt;/p&gt;
&lt;p&gt;KV Cache 随着 batch size &lt;strong&gt;线性增长&lt;/strong&gt;. 如果有 8 个并发请求, 每个 32K 上下文, KV Cache 就要 32GB——已经超过了大多数消费级显卡.&lt;/p&gt;
&lt;p&gt;这就是为什么 vLLM 等推理框架如此重要——它们通过 PagedAttention 让多个请求共用显存, 消除了内部碎片.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;3. 激活值: 临时工&lt;/h2&gt;
&lt;p&gt;与模型参数和 KV Cache 不同, 激活值是&lt;strong&gt;临时&lt;/strong&gt;占用的——每步前向传播后就会被释放.&lt;/p&gt;
&lt;h3&gt;3.1 激活值从哪来?&lt;/h3&gt;
&lt;p&gt;在推理的每一步, 数据流过每一层 Transformer:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;输入 (hidden_states)
  → RMS Norm → QKV 投影 → Attention 计算 → 残差连接
  → RMS Norm → FFN (gate/up/down) → 残差连接
  → 输出到下一层
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;每一层的&lt;strong&gt;中间结果&lt;/strong&gt;都需要占用显存. 具体来说:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;Attention 部分: Q, K, V 投影后的矩阵, attention score ($T \times T$), attention output&lt;/li&gt;
&lt;li&gt;FFN 部分: gate 输出, up 输出, 中间激活, down 输出&lt;/li&gt;
&lt;li&gt;残差连接: 需要保留输入向量用于加法&lt;/li&gt;
&lt;/ul&gt;
&lt;h3&gt;3.2 估算公式&lt;/h3&gt;
&lt;p&gt;有个经验公式可以快速估算:&lt;/p&gt;
&lt;p&gt;$$
M_{\text{act}} \approx ( 34 \times d_{\text{model}} + 5 \times d_{\text{ff}} ) \times T \times B \times b_{\text{param}}
$$&lt;/p&gt;
&lt;p&gt;这个 34 和 5 是怎么来的? 来自每层中各种中间矩阵的大小之和. 对于 Llama-3-8B ($d_{\text{model}}=4096$, $d_{\text{ff}}=14336$):&lt;/p&gt;
&lt;p&gt;$$
M_{\text{act}} \approx (34 \times 4096 + 5 \times 14336) \times T \times B \times 2
$$&lt;/p&gt;
&lt;p&gt;当 $B=1$ 时:&lt;/p&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;$T$&lt;/th&gt;
&lt;th&gt;$M_{\text{act}}$&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;512&lt;/td&gt;
&lt;td&gt;$\approx 206\text{MB}$&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;4,096&lt;/td&gt;
&lt;td&gt;$\approx 1.6\text{GB}$&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;32,768&lt;/td&gt;
&lt;td&gt;$\approx 13\text{GB}$&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;注意: 长序列时激活值的占用也接近模型权重了! 这是因为 $T \times d_{\text{model}}$ 的乘积在变大.&lt;/p&gt;
&lt;h3&gt;3.3 为什么激活值容易被忽略&lt;/h3&gt;
&lt;p&gt;KV Cache 和模型参数是&lt;strong&gt;常驻&lt;/strong&gt;显存的——加载后直到推理结束才释放. 激活值是&lt;strong&gt;临时&lt;/strong&gt;的——每算完一层就释放一部分.&lt;/p&gt;
&lt;p&gt;所以很多人只关注常驻部分. 但问题在于&lt;strong&gt;峰值&lt;/strong&gt;时刻: 当长序列且没有 Flash Attention 时, 完整的 $T \times T$ 注意力矩阵 (对 32K 序列就是 $32K \times 32K \times 2\text{bytes} \approx 2\text{GB}$) 可能瞬间撑爆显存.&lt;/p&gt;
&lt;p&gt;这就是为什么 Flash Attention 如此重要——它通过分块计算避免了一次性创建完整的注意力矩阵.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;4. 总账本&lt;/h2&gt;
&lt;p&gt;对 Llama-3-8B (BF16, batch=1) 做个总账:&lt;/p&gt;
&lt;h3&gt;4.1 短序列 (512 tokens)&lt;/h3&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;项目&lt;/th&gt;
&lt;th&gt;大小&lt;/th&gt;
&lt;th&gt;占比&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;模型参数&lt;/td&gt;
&lt;td&gt;16.0 GB&lt;/td&gt;
&lt;td&gt;96.5%&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;KV Cache&lt;/td&gt;
&lt;td&gt;0.064 GB&lt;/td&gt;
&lt;td&gt;0.4%&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;激活值 (峰值)&lt;/td&gt;
&lt;td&gt;0.2 GB&lt;/td&gt;
&lt;td&gt;1.2%&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;strong&gt;总计&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;&lt;strong&gt;~16.3 GB&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;→ 24GB 显卡轻松跑. 主要瓶颈是模型权重.&lt;/p&gt;
&lt;h3&gt;4.2 中等序列 (8K tokens)&lt;/h3&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;项目&lt;/th&gt;
&lt;th&gt;大小&lt;/th&gt;
&lt;th&gt;占比&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;模型参数&lt;/td&gt;
&lt;td&gt;16.0 GB&lt;/td&gt;
&lt;td&gt;74%&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;KV Cache&lt;/td&gt;
&lt;td&gt;1.0 GB&lt;/td&gt;
&lt;td&gt;5%&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;激活值 (峰值)&lt;/td&gt;
&lt;td&gt;3.2 GB&lt;/td&gt;
&lt;td&gt;15%&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;其他开销&lt;/td&gt;
&lt;td&gt;~1.3 GB&lt;/td&gt;
&lt;td&gt;6%&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;strong&gt;总计&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;&lt;strong&gt;~21.5 GB&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;→ 24GB 显卡刚好够用, 但快满了.&lt;/p&gt;
&lt;h3&gt;4.3 长序列 (32K tokens)&lt;/h3&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;项目&lt;/th&gt;
&lt;th&gt;大小&lt;/th&gt;
&lt;th&gt;占比&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;模型参数&lt;/td&gt;
&lt;td&gt;16.0 GB&lt;/td&gt;
&lt;td&gt;44%&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;KV Cache&lt;/td&gt;
&lt;td&gt;4.0 GB&lt;/td&gt;
&lt;td&gt;11%&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;激活值 (峰值)&lt;/td&gt;
&lt;td&gt;13 GB&lt;/td&gt;
&lt;td&gt;36%&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;其他开销&lt;/td&gt;
&lt;td&gt;~3 GB&lt;/td&gt;
&lt;td&gt;9%&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;strong&gt;总计&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;&lt;strong&gt;~36 GB&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;→ 24GB 显卡完全不够! 必须优化.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;5. 每项都能优化&lt;/h2&gt;
&lt;p&gt;既然知道了每一笔账, 就可以针对性地&quot;省钱&quot;.&lt;/p&gt;
&lt;h3&gt;5.1 模型参数: 量化&lt;/h3&gt;
&lt;p&gt;量化就是把 BF16 降到更低位宽:&lt;/p&gt;
&lt;p&gt;$$
M_{\text{params}}(\text{INT4}) = \frac{1}{4} M_{\text{params}}(\text{BF16})
$$&lt;/p&gt;
&lt;p&gt;8B 模型: 16 GB → 4 GB, 省 12GB.&lt;/p&gt;
&lt;p&gt;代价? 理论上少量精度损失, 实践中 INT4 的 MMLU 损失通常在 1% 以内. 值不值? 对于部署来说, 太值了.&lt;/p&gt;
&lt;h3&gt;5.2 KV Cache: 三个方向&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;方向 1: GQA / MQA&lt;/strong&gt; (架构层面)&lt;/p&gt;
&lt;p&gt;从 MHA 换成 GQA 或 MQA, 直接减少 $n_{\text{kv}}$:&lt;/p&gt;
&lt;p&gt;$$
\frac{M_{\text{kv}}^{\text{GQA}}}{M_{\text{kv}}^{\text{MHA}}} = \frac{n_{\text{kv}}^{\text{GQA}}}{n_{\text{kv}}^{\text{MHA}}}
$$&lt;/p&gt;
&lt;p&gt;Llama-3-8B 用 GQA ($n_{\text{kv}}=8$) 而非 MHA ($n_{\text{kv}}=32$), KV Cache 直接省 4 倍.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;方向 2: KV Cache 量化&lt;/strong&gt; (数值层面)&lt;/p&gt;
&lt;p&gt;把 KV Cache 从 FP16 (2 字节) 存成 INT8 (1 字节):&lt;/p&gt;
&lt;p&gt;$$
M_{\text{kv}}(\text{INT8}) = \frac{1}{2} M_{\text{kv}}(\text{FP16})
$$&lt;/p&gt;
&lt;p&gt;32K 上下文: 4 GB → 2 GB.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;方向 3: PagedAttention&lt;/strong&gt; (系统层面)&lt;/p&gt;
&lt;p&gt;KV Cache 按固定大小的 page 分配, 类似操作系统虚拟内存分页. 主要收益:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;消除内部碎片 (不同序列长度导致的不连续分配)&lt;/li&gt;
&lt;li&gt;方便内存共享 (如 beam search 的多个候选共用前缀)&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;vLLM 用 PagedAttention 宣称能省 60-80% 的 KV Cache 显存——这个数字来自&lt;strong&gt;碎片消除&lt;/strong&gt; + &lt;strong&gt;共享前缀&lt;/strong&gt; + &lt;strong&gt;按需分配&lt;/strong&gt;的综合效果.&lt;/p&gt;
&lt;h3&gt;5.3 激活值: Flash Attention&lt;/h3&gt;
&lt;p&gt;标准的注意力实现需要构建 $T \times T$ 的注意力矩阵:&lt;/p&gt;
&lt;p&gt;$$
M_{\text{attn}} = T^2 \times b_{\text{param}}
$$&lt;/p&gt;
&lt;p&gt;32K 序列: $32768^2 \times 2 \approx 2\text{GB}$&lt;/p&gt;
&lt;p&gt;Flash Attention 把计算分块, 让注意力矩阵的子块在 SRAM 中处理, 然后累加结果进 HBM. 这样在 HBM 层面只需要 $O(T \times d)$ 的显存, 而不是 $O(T^2)$.&lt;/p&gt;
&lt;p&gt;收益: 长序列时激活值显存从 $O(T^2)$ 变成 $O(T)$——对于 32K 序列, 可以省数十 GB.&lt;/p&gt;
&lt;h3&gt;5.4 组合优化的效果&lt;/h3&gt;
&lt;p&gt;对 Llama-3-8B, 32K 上下文, BF16→INT4, 加上各种优化:&lt;/p&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;优化&lt;/th&gt;
&lt;th&gt;模型参数&lt;/th&gt;
&lt;th&gt;KV Cache&lt;/th&gt;
&lt;th&gt;激活值&lt;/th&gt;
&lt;th&gt;总计&lt;/th&gt;
&lt;th&gt;备注&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;原始 (BF16)&lt;/td&gt;
&lt;td&gt;16 GB&lt;/td&gt;
&lt;td&gt;4 GB&lt;/td&gt;
&lt;td&gt;13 GB&lt;/td&gt;
&lt;td&gt;~36 GB&lt;/td&gt;
&lt;td&gt;不行&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;+ INT4 量化&lt;/td&gt;
&lt;td&gt;4 GB&lt;/td&gt;
&lt;td&gt;4 GB&lt;/td&gt;
&lt;td&gt;13 GB&lt;/td&gt;
&lt;td&gt;~24 GB&lt;/td&gt;
&lt;td&gt;勉强&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;+ KV Cache INT8&lt;/td&gt;
&lt;td&gt;4 GB&lt;/td&gt;
&lt;td&gt;2 GB&lt;/td&gt;
&lt;td&gt;13 GB&lt;/td&gt;
&lt;td&gt;~22 GB&lt;/td&gt;
&lt;td&gt;OK&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;+ Flash Attention&lt;/td&gt;
&lt;td&gt;4 GB&lt;/td&gt;
&lt;td&gt;2 GB&lt;/td&gt;
&lt;td&gt;~1 GB&lt;/td&gt;
&lt;td&gt;~7 GB&lt;/td&gt;
&lt;td&gt;轻松&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;这就是 LLM 推理优化的魔力——&lt;strong&gt;通过理解每一笔开销的数学原理, 你可以有针对性地节省几十 GB 的显存&lt;/strong&gt;.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;6. 实际部署经验&lt;/h2&gt;
&lt;p&gt;公式都推导清楚了, 实际操作就简单了:&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;短上下文 (&amp;lt;2K)&lt;/strong&gt;: 主要瓶颈是模型参数 → 先量化
&lt;strong&gt;长上下文 (&amp;gt;8K)&lt;/strong&gt;: 主要瓶颈是 KV Cache → GQA + PagedAttention
&lt;strong&gt;大批量推理&lt;/strong&gt;: 激活值和 KV Cache 都线性增长 → Flash Attention + PagedAttention&lt;/p&gt;
&lt;p&gt;一个具体的配置:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;# vLLM 自动处理大部分优化
from vllm import LLM, SamplingParams

llm = LLM(
    model=&quot;meta-llama/Meta-Llama-3-8B&quot;,
    max_model_len=8192,         # 限制最大序列长度
    gpu_memory_utilization=0.9, # 使用 90% 显存
    kv_cache_dtype=&quot;fp8&quot;,       # KV Cache 量化
)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;HuggingFace 原生:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;model = AutoModelForCausalLM.from_pretrained(
    &quot;meta-llama/Meta-Llama-3-8B&quot;,
    torch_dtype=torch.bfloat16,
    device_map=&quot;auto&quot;,
    attn_implementation=&quot;flash_attention_2&quot;,  # 省激活值
)
&lt;/code&gt;&lt;/pre&gt;
&lt;hr /&gt;
&lt;h2&gt;7. 总结&lt;/h2&gt;
&lt;p&gt;推理显存的数学很简单, 就是加乘:&lt;/p&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;项目&lt;/th&gt;
&lt;th&gt;公式&lt;/th&gt;
&lt;th&gt;关键参数&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;模型参数&lt;/td&gt;
&lt;td&gt;$N_{\text{params}} \times b_{\text{param}}$&lt;/td&gt;
&lt;td&gt;参数量和精度&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;KV Cache&lt;/td&gt;
&lt;td&gt;$2 \times L \times n_{\text{kv}} \times d_{\text{head}} \times b_{\text{param}} \times T$&lt;/td&gt;
&lt;td&gt;层数、头数、序列长度&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;激活值&lt;/td&gt;
&lt;td&gt;$\approx 34 \times d_{\text{model}} \times T \times B \times b_{\text{param}}$&lt;/td&gt;
&lt;td&gt;模型宽度、序列长度&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;关键在于: &lt;strong&gt;每一项你都能精确算出, 每算出来一项, 就知道应该从哪下手优化&lt;/strong&gt;.&lt;/p&gt;
&lt;p&gt;下次别人说&quot;这个 8B 模型跑不起来&quot;, 你可以问: 上下文多长? 精度用什么? 用 Flash Attention 了吗? ——而且每一问你都知道他差在哪里.&lt;/p&gt;
&lt;hr /&gt;
&lt;h3&gt;参考资料&lt;/h3&gt;
&lt;ol&gt;
&lt;li&gt;Kwon et al., Efficient Memory Management for Large Language Model Serving with PagedAttention. SOSP 2023. &lt;a href=&quot;https://arxiv.org/abs/2309.06180&quot;&gt;arXiv:2309.06180&lt;/a&gt; — &lt;strong&gt;vLLM 核心论文&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;Dao et al., FlashAttention: Fast and Memory-Efficient Exact Attention. NeurIPS 2022. &lt;a href=&quot;https://arxiv.org/abs/2205.14135&quot;&gt;arXiv:2205.14135&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;Shazeer, Fast Transformer Decoding: One Write-Head is All You Need. 2019. &lt;a href=&quot;https://arxiv.org/abs/1911.02150&quot;&gt;arXiv:1911.02150&lt;/a&gt; — &lt;strong&gt;MQA&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;Ainslie et al., GQA: Training Generalized Multi-Query Transformer Models. 2023. &lt;a href=&quot;https://arxiv.org/abs/2305.13245&quot;&gt;arXiv:2305.13245&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;Meta, The Llama 3 Herd of Models. 2024. &lt;a href=&quot;https://arxiv.org/abs/2407.21783&quot;&gt;arXiv:2407.21783&lt;/a&gt; — &lt;strong&gt;Llama-3 架构细节&lt;/strong&gt;&lt;/li&gt;
&lt;/ol&gt;
</content:encoded></item><item><title>长上下文扩展 — 从 RoPE 出发，一步步推导 PI、NTK 到 YaRN</title><link>https://xuchenhui.cc/posts/2026-05-16-llm-long-context-yarn/</link><guid isPermaLink="true">https://xuchenhui.cc/posts/2026-05-16-llm-long-context-yarn/</guid><description>从 RoPE 的 θ 公式出发，先想清楚&quot;为什么 RoPE 在训练长度外效果差&quot;，再一步步推出 PI、NTK-aware、YaRN 的改进思路和数学原理。</description><pubDate>Sat, 16 May 2026 00:00:00 GMT</pubDate><content:encoded>&lt;h2&gt;0. 前言&lt;/h2&gt;
&lt;p&gt;上一篇文章我们推导了 &lt;a href=&quot;/posts/2026-05-16-llm-rope-rotary-position-embedding/&quot;&gt;RoPE: 旋转位置编码&lt;/a&gt;: 用旋转矩阵给每个位置编码, 让 attention 的内积只依赖相对位置.&lt;/p&gt;
&lt;p&gt;但 RoPE 有一个棘手的问题: 模型在训练时只见过 $[0, L_{\text{train}})$ 范围内的位置. 推理时突然要处理 $m \gg L_{\text{train}}$ 的位置——即使 RoPE 的公式在数学上可以计算任意 $m$, &lt;strong&gt;模型&quot;没见过&quot;这么大位置上的频率组合&lt;/strong&gt;, 效果会断崖式下跌.&lt;/p&gt;
&lt;p&gt;这篇文章就从 RoPE 的频率公式出发, 一步步推导各个改进方案: 从最朴素的 &lt;strong&gt;Position Interpolation&lt;/strong&gt;, 到 &lt;strong&gt;NTK-aware scaling&lt;/strong&gt;, 再到集大成者的 &lt;strong&gt;YaRN&lt;/strong&gt;.&lt;/p&gt;
&lt;p&gt;最终你会理解: 这些方法不是在&quot;发明&quot;新东西, 而是在回答一个问题——&lt;strong&gt;当模型需要处理从未见过的长位置时, 怎么把已有的 RoPE 频率知识迁移过去?&lt;/strong&gt;&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;1. 先定位问题: RoPE 在长位置为什么不行&lt;/h2&gt;
&lt;p&gt;RoPE 中, 第 $i$ 个维度对的旋转频率是:&lt;/p&gt;
&lt;p&gt;$$
\theta_i = 10000^{-2i/d}, \quad i = 0, 1, ..., d/2 - 1
$$&lt;/p&gt;
&lt;p&gt;位置 $m$ 的旋转角度是 $m\theta_i$.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;训练时&lt;/strong&gt;: 模型只见过 $m \in [0, L_{\text{train}})$ 范围内的 $m\theta_i$ 值. 这些值覆盖了所有 $\theta_i$ 的某个范围.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;推理时&lt;/strong&gt;: 当 $m &amp;gt; L_{\text{train}}$, $m\theta_i$ 超出了训练时见过的范围. 尤其是对于高频维度 ($i$ 小, $\theta_i$ 大), 在 $L_{\text{train}}$ 内可能已经转了很多圈; 而对于低频维度 ($i$ 接近 $d/2$, $\theta_i$ 很小), 在 $L_{\text{train}}$ 内可能才转了不到半圈.&lt;/p&gt;
&lt;p&gt;模型在训练时学到的是一种&lt;strong&gt;频率组合的&quot;分布&quot;&lt;/strong&gt;——当输入第 $m$ 个 token 时, 各维度对以不同的旋转角度协同工作. 超出训练范围后, 这些角度组合不再符合训练时的分布, 模型就&quot;懵&quot;了.&lt;/p&gt;
&lt;p&gt;这个认识很重要——问题不在于 RoPE 的数学, 而在于&lt;strong&gt;分布外泛化&lt;/strong&gt;. 所以所有改进方案的核心都是: 如何把长位置的旋转角度&quot;拉回&quot;训练时的分布内, 同时尽量保留位置信息.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;2. 方案一: Position Interpolation (PI) — 简单但粗暴&lt;/h2&gt;
&lt;h3&gt;2.1 核心思路&lt;/h3&gt;
&lt;p&gt;PI 的想法非常直接: &lt;strong&gt;既然长位置的 $m\theta_i$ 没见过, 那就把它缩回训练时的范围内&lt;/strong&gt;.&lt;/p&gt;
&lt;p&gt;做法: 把位置 $m$ 映射到 $m&apos; = m \times \frac{L_{\text{train}}}{L_{\text{infer}}}$.&lt;/p&gt;
&lt;p&gt;也就是说, 旋转角度从 $m\theta_i$ 变成:&lt;/p&gt;
&lt;p&gt;$$
m&apos;\theta_i = \frac{L_{\text{train}}}{L_{\text{infer}}} \cdot m \cdot \theta_i
$$&lt;/p&gt;
&lt;p&gt;例如训练 4K, 推理 32K: 位置 16,000 被当作位置 2,000 来计算旋转. 这意味着 16,000 位置的向量和训练时 2,000 位置的向量&lt;strong&gt;经历完全相同的旋转&lt;/strong&gt;.&lt;/p&gt;
&lt;p&gt;好处: 所有 $m\theta_i$ 值都落在训练范围内, 模型不会&quot;没见过&quot;.&lt;/p&gt;
&lt;h3&gt;2.2 数学上分析 PI 的问题&lt;/h3&gt;
&lt;p&gt;用 $s = L_{\text{infer}} / L_{\text{train}}$ 表示扩展比 (scale). PI 的等效频率是:&lt;/p&gt;
&lt;p&gt;$$
\theta_i^{\text{PI}} = \frac{\theta_i}{s}
$$&lt;/p&gt;
&lt;p&gt;相邻位置的角度差从 $\theta_i$ 变成了 $\theta_i / s$.&lt;/p&gt;
&lt;p&gt;来算一下这会造成什么后果. 对于高频维度 ($i=0$), $\theta_0 = 1.0$, $d=128$:&lt;/p&gt;
&lt;p&gt;原始相邻位置差: $\theta_0 = 1.0$ 弧度, 约 $57.3^\circ$
PI 后相邻位置差 ($s=8$): $\theta_0 / 8 = 0.125$ 弧度, 约 $7.2^\circ$&lt;/p&gt;
&lt;p&gt;原来位置 $m$ 和 $m+1$ 的向量方向相差 $57^\circ$, 很容易区分. PI 后只差 $7^\circ$, 几乎重叠——&lt;strong&gt;高频分辨率严重下降&lt;/strong&gt;.&lt;/p&gt;
&lt;p&gt;对于低频维度 ($i=63$), $\theta_{63} = 10000^{-126/128} \approx 0.0001$:&lt;/p&gt;
&lt;p&gt;原始相邻位置差: $\approx 0.0057^\circ$ — 本来相邻位置就很难区分
PI 后: $\approx 0.0007^\circ$ — 更分不清了&lt;/p&gt;
&lt;p&gt;但低频维度的作用本来就不是区分相邻位置, 而是感知&lt;strong&gt;大范围距离&lt;/strong&gt;. 所以低频损失一些分辨率问题不大. 真正致命的是高频分辨率丢失——它破坏了模型对精细位置关系的建模能力.&lt;/p&gt;
&lt;h3&gt;2.3 PI 的结论&lt;/h3&gt;
&lt;p&gt;PI 的效果其实还不错——经过几千步微调 (fine-tuning), 可以很好地扩展到 8 倍长度. 但它的问题也明显: &lt;strong&gt;高频信息被均匀压缩, 短距离的区分度下降&lt;/strong&gt;. 如果你既想在短序列上保持原有精度, 又想扩展到长序列, PI 不是最优选择.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;3. 方案二: NTK-aware — 保留高频分辨率&lt;/h2&gt;
&lt;h3&gt;3.1 直觉&lt;/h3&gt;
&lt;p&gt;NTK-aware 的直觉和 PI 相反: &lt;strong&gt;高频维度携带精细位置信息, 应该保持分辨率; 低频维度负责大范围感知, 可以拉伸&lt;/strong&gt;.&lt;/p&gt;
&lt;p&gt;怎么实现? 不缩放位置 $m$, 而是调整频率 $\theta_i$ 本身. 核心是修改 RoPE 的 base 值.&lt;/p&gt;
&lt;h3&gt;3.2 从 base 修改出发&lt;/h3&gt;
&lt;p&gt;回顾 RoPE 的频率公式:&lt;/p&gt;
&lt;p&gt;$$
\theta_i = \text{base}^{-2i/d}
$$&lt;/p&gt;
&lt;p&gt;如果我们把 base 从 10000 换成 $\text{base}&apos; &amp;gt; \text{base}$, 会发生什么?&lt;/p&gt;
&lt;p&gt;对于高频 ($i$ 小): $\theta_i \approx \text{base}&apos;^{-2i/d}$ — 大的 base 使得 $\theta_i$ 变小(因为指数是负的), 高频频率降低.&lt;/p&gt;
&lt;p&gt;不对, 仔细算. $\theta_i = \text{base}^{-2i/d}$. 当 $i$ 很小时($2i/d \approx 0$), $\text{base}^{-2i/d} \approx 1$ 对 base 的变化不敏感. 当 $i$ 接近 $d/2$ 时 ($2i/d \approx 1$), $\theta_{d/2} = \text{base}^{-1}$, 增大 base 会显著降低低频频率.&lt;/p&gt;
&lt;p&gt;所以: &lt;strong&gt;增大 base, 高频几乎不变, 低频被压低&lt;/strong&gt;. 这正是我们想要的!&lt;/p&gt;
&lt;p&gt;NTK-aware 选择:&lt;/p&gt;
&lt;p&gt;$$
\text{base}&apos; = \text{base} \times \alpha
$$&lt;/p&gt;
&lt;p&gt;其中 $\alpha$ 是跟扩展比相关的值. 推荐的 $\alpha$ 选择:&lt;/p&gt;
&lt;p&gt;$$
\alpha = \left(\frac{L_{\text{infer}}}{L_{\text{train}}}\right)^{d/(d-2)}
$$&lt;/p&gt;
&lt;p&gt;这个公式的推导思路是: 让&lt;strong&gt;最低频维度&lt;/strong&gt;的波长远大于训练长度, 从而把长位置的旋转角度&quot;拉伸&quot;回训练时的范围. 推导如下:&lt;/p&gt;
&lt;p&gt;最低频维度 ($i = d/2 - 1$) 的原始波长:&lt;/p&gt;
&lt;p&gt;$$
\lambda_{\min} = \frac{2\pi}{\theta_{d/2-1}} = 2\pi \times \text{base}^{(d-2)/d}
$$&lt;/p&gt;
&lt;p&gt;我们希望 $\lambda_{\min}$ 拉伸到 $L_{\text{infer}}$ 量级, 所以选择 $\text{base}&apos;$ 使新波长:&lt;/p&gt;
&lt;p&gt;$$
\lambda_{\min}&apos; = L_{\text{infer}} \implies \text{base}&apos;^{(d-2)/d} \propto L_{\text{infer}}
$$&lt;/p&gt;
&lt;p&gt;解出 $\text{base}&apos; / \text{base} \propto (L_{\text{infer}} / L_{\text{train}})^{d/(d-2)}$.&lt;/p&gt;
&lt;p&gt;实际使用中, 可以简化为 $\text{base}&apos; = \text{base} \times s$ (其中 $s = L_{\text{infer}}/L_{\text{train}}$), 或者用经验值 $\text{base}&apos; = \text{base} \times s^{1.2}$ 之类的. 具体哪个最好需要实验验证.&lt;/p&gt;
&lt;h3&gt;3.3 NTK-aware 的逐频率视角&lt;/h3&gt;
&lt;p&gt;另一种等价的理解方式: NTK-aware 等价于逐频率使用不同的缩放因子.&lt;/p&gt;
&lt;p&gt;原始频率 $\theta_i$, NTK 后的频率 $\theta_i&apos;$:&lt;/p&gt;
&lt;p&gt;$$
\theta_i&apos; = \text{base}&apos;^{-2i/d} = (\text{base} \cdot \alpha)^{-2i/d} = \text{base}^{-2i/d} \cdot \alpha^{-2i/d}
$$&lt;/p&gt;
&lt;p&gt;所以 $\theta_i&apos; = \theta_i \cdot \alpha^{-2i/d}$.&lt;/p&gt;
&lt;p&gt;相比 PI 对所有频率乘 $1/s$, NTK-aware 的缩放因子 $\alpha^{-2i/d}$ 是&lt;strong&gt;频率相关的&lt;/strong&gt;: 高频 ($i$ 小) 缩放因子接近 1 (几乎不变), 低频 ($i$ 大) 缩放因子显著小于 1 (频率降低, 波长拉长).&lt;/p&gt;
&lt;p&gt;这就实现了&quot;高频保留, 低频拉伸&quot;的目标.&lt;/p&gt;
&lt;h3&gt;3.4 NTK-aware 的效果&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;不微调也能用!&lt;/strong&gt; 这是 NTK-aware 的最大优点——把 base 改大后, 模型在短上下文上的表现几乎不受影响 (因为高频没动), 在长上下文上的表现有显著提升.&lt;/p&gt;
&lt;p&gt;原因: 高频维度决定了模型对&lt;strong&gt;相邻位置&lt;/strong&gt;的区分能力——只要这个能力保住了, 模型在短序列上的输出就不会大变. 低频维度被拉伸后, 模型虽然可能在长距离依赖上&quot;感觉&quot;不太准, 但至少不会输出乱码.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;4. 方案三: YaRN — 精细化处理&lt;/h2&gt;
&lt;p&gt;YaRN (Yet another RoPE extensioN) 在 NTK-aware 的基础上做了两个关键改进.&lt;/p&gt;
&lt;h3&gt;4.1 改进一: NTK-by-parts (逐维度差异化处理)&lt;/h3&gt;
&lt;p&gt;NTK-aware 对所有频率使用了统一的 base 缩放, 这仍然不够精细. YaRN 提出: &lt;strong&gt;应该根据每个维度的波长, 来决定它的缩放方式&lt;/strong&gt;.&lt;/p&gt;
&lt;p&gt;一个维度的波长:&lt;/p&gt;
&lt;p&gt;$$
\lambda_i = \frac{2\pi}{\theta_i} = 2\pi \cdot \text{base}^{2i/d}
$$&lt;/p&gt;
&lt;p&gt;对于 $d=128$, base=10000, 各维度的波长范围:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;$i=0$: $\lambda_0 \approx 2\pi \cdot 1 = 6.28$ — 每约 6 个 token 旋转一圈&lt;/li&gt;
&lt;li&gt;$i=32$: $\lambda_{32} \approx 2\pi \cdot 10000^{64/128} = 2\pi \cdot 100 = 628$&lt;/li&gt;
&lt;li&gt;$i=63$: $\lambda_{63} \approx 2\pi \cdot 10000^{126/128} \approx 2\pi \cdot 9341 \approx 58680$&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;现在来看跟训练长度的关系. 假设 $L_{\text{train}} = 4096$:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;如果 $\lambda_i \ll L_{\text{train}}$: 维度在训练范围内旋转了很多圈, 携带精细位置信息 → &lt;strong&gt;不缩放&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;如果 $\lambda_i \gg L_{\text{train}}$: 维度在训练范围内才转了不到半圈, 携带大范围信息 → &lt;strong&gt;用 PI 方式缩放&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;如果 $\lambda_i \approx L_{\text{train}}$: 介于两者之间 → &lt;strong&gt;平滑过渡&lt;/strong&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;YaRN 的决策边界:&lt;/p&gt;
&lt;p&gt;$$
r_i = \begin{cases}
1, &amp;amp; \lambda_i \leq \frac{L_{\text{train}}}{2} \
\frac{1}{s}, &amp;amp; \lambda_i \geq L_{\text{train}} \
1 - (1 - \frac{1}{s}) \cdot \frac{\lambda_i - L_{\text{train}}/2}{L_{\text{train}}/2}, &amp;amp; \text{otherwise}
\end{cases}
$$&lt;/p&gt;
&lt;p&gt;其中 $r_i$ 是对 $\theta_i$ 的缩放因子: $\theta_i&apos; = \theta_i \cdot r_i$.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;解释:&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;当 $\lambda_i \leq L_{\text{train}}/2$: 波长短, 旋转快, 高频 → $r_i = 1$, 完全不缩放&lt;/li&gt;
&lt;li&gt;当 $\lambda_i \geq L_{\text{train}}$: 波长长, 旋转慢, 低频 → $r_i = 1/s$, 完全用 PI 方式 (等效于位置压缩)&lt;/li&gt;
&lt;li&gt;中间: $r_i$ 从 1 线性下降到 $1/s$, 平滑过渡&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;这个&quot;波长 vs 训练长度&quot;的判断标准非常巧妙——它从&lt;strong&gt;物理意义&lt;/strong&gt;(旋转一圈需要多少 token)出发, 而不是从&lt;strong&gt;编号&lt;/strong&gt;(维度 $i$ 的序号)出发. 同样的维度索引在不同的模型维度 $d$ 下可能需要不同的处理, 但波长是绝对的.&lt;/p&gt;
&lt;h3&gt;4.2 改进二: Attention Temperature 调整&lt;/h3&gt;
&lt;p&gt;这是 YaRN 一个容易被忽略但非常重要的改进.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;问题&lt;/strong&gt;: 当改变频率 (不管是用 PI 还是 NTK) 后, query 和 key 的内积分布会发生变化.&lt;/p&gt;
&lt;p&gt;回顾 RoPE 文章中的推导, 对于随机向量 $\mathbf{q}, \mathbf{k}$, RoPE 编码后内积的期望是:&lt;/p&gt;
&lt;p&gt;$$
\mathbb{E}[\langle f_q(\mathbf{q}, m), f_k(\mathbf{k}, n) \rangle] = \sum_{i=0}^{d/2-1} \cos((m-n)\theta_i)
$$&lt;/p&gt;
&lt;p&gt;当频率 $\theta_i$ 被缩放后, 这个求和的值会发生变化. 具体来说:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;原始&lt;/strong&gt;: $\sum \cos((m-n)\theta_i)$&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;PI 后&lt;/strong&gt;: $\sum \cos((m-n)\theta_i / s)$ — 因为位置被压缩了, 所以相对差对应的角度也变了&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;NTK 后&lt;/strong&gt;: $\sum \cos((m-n)\theta_i \cdot \alpha^{-2i/d})$ — 每个频率的缩放不同&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;这个变化导致: &lt;strong&gt;即使对相同的相对距离 $(m-n)$, 内积的绝对大小也变了&lt;/strong&gt;. 如果内积整体变小了, softmax 后的 attention 分布就会更&quot;平坦&quot; (温度变高); 如果内积变大了, attention 分布就更&quot;尖锐&quot;.&lt;/p&gt;
&lt;p&gt;YaRN 的解决方案: 在 attention softmax 中引入一个温度系数 $t$:&lt;/p&gt;
&lt;p&gt;$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d} \cdot t}\right)
$$&lt;/p&gt;
&lt;p&gt;$t$ 的选择: YaRN 论文通过分析内积分布的方差变化, 给出 $t \approx \sqrt{1 + \frac{\ln s}{\ln (d/2)}}$ 的参考值. 在实践中, $t$ 通常在 $1.0$ 到 $2.0$ 之间, 需要根据具体模型和扩展比例来调.&lt;/p&gt;
&lt;h3&gt;4.3 YaRN 的完整算法&lt;/h3&gt;
&lt;p&gt;总结一下 YaRN 的完整流程:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;输入: 原始 RoPE 频率 θ_i, 训练长度 L_train, 扩展比 s = L_infer / L_train

1. 计算每个维度的波长: λ_i = 2π / θ_i

2. 计算逐维度缩放因子 r_i:
   if λ_i ≤ L_train/2:     r_i = 1
   elif λ_i ≥ L_train:     r_i = 1/s
   else:                   r_i = 1 - (1 - 1/s) * (λ_i - L_train/2) / (L_train/2)

3. 应用缩放: θ_i&apos; = θ_i · r_i

4. 计算 attention 温度系数 t (经验值, 可调)

5. 使用 θ_i&apos; 做 RoPE, 用 t 调节 attention softmax
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;4.4 YaRN 的效果为什么更好&lt;/h3&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;方案&lt;/th&gt;
&lt;th&gt;高频 (i=0)&lt;/th&gt;
&lt;th&gt;中频 (i=16)&lt;/th&gt;
&lt;th&gt;低频 (i=63)&lt;/th&gt;
&lt;th&gt;温度调整&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;直接外推&lt;/td&gt;
&lt;td&gt;原样&lt;/td&gt;
&lt;td&gt;原样&lt;/td&gt;
&lt;td&gt;原样&lt;/td&gt;
&lt;td&gt;无&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;PI&lt;/td&gt;
&lt;td&gt;全部 $/s$&lt;/td&gt;
&lt;td&gt;全部 $/s$&lt;/td&gt;
&lt;td&gt;全部 $/s$&lt;/td&gt;
&lt;td&gt;无&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;NTK&lt;/td&gt;
&lt;td&gt;几乎不变&lt;/td&gt;
&lt;td&gt;略微降低&lt;/td&gt;
&lt;td&gt;大幅降低&lt;/td&gt;
&lt;td&gt;无&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;YaRN&lt;/td&gt;
&lt;td&gt;完全不变&lt;/td&gt;
&lt;td&gt;平滑过渡&lt;/td&gt;
&lt;td&gt;PI 缩放&lt;/td&gt;
&lt;td&gt;✅&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;YaRN 的&quot;精细&quot;之处在于: 它让每个频率的缩放决策有了&lt;strong&gt;物理依据&lt;/strong&gt;(波长), 而不是统一的数学公式.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;5. 实验对比&lt;/h2&gt;
&lt;p&gt;在 LongBench 上的典型结果 (来自 YaRN 论文, 各方法经过微调):&lt;/p&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;方法&lt;/th&gt;
&lt;th&gt;扩展 8x 后 LongBench 得分&lt;/th&gt;
&lt;th&gt;短序列质量是否受影响&lt;/th&gt;
&lt;th&gt;是否需要微调&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;直接外推 (无处理)&lt;/td&gt;
&lt;td&gt;~20&lt;/td&gt;
&lt;td&gt;是 (很差)&lt;/td&gt;
&lt;td&gt;否 (但效果差)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;PI&lt;/td&gt;
&lt;td&gt;~37&lt;/td&gt;
&lt;td&gt;轻微下降&lt;/td&gt;
&lt;td&gt;需微调&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;NTK-aware&lt;/td&gt;
&lt;td&gt;~34&lt;/td&gt;
&lt;td&gt;几乎不变&lt;/td&gt;
&lt;td&gt;可不用微调&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;YaRN&lt;/td&gt;
&lt;td&gt;~41&lt;/td&gt;
&lt;td&gt;几乎不变&lt;/td&gt;
&lt;td&gt;少量微调即可&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;YaRN 是综合效果最好的方案. 如今主流的大模型 (LLaMA-3.1 128K, Mistral 32K, Qwen2.5 128K) 背后的长上下文扩展方案, 基本都基于类似 YaRN 的思路.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;6. HuggingFace 使用示例&lt;/h2&gt;
&lt;pre&gt;&lt;code&gt;from transformers import AutoModelForCausalLM, AutoConfig

model_name = &quot;meta-llama/Llama-2-7b-hf&quot;
config = AutoConfig.from_pretrained(model_name)

# 启用 YaRN
config.rope_scaling = {
    &quot;type&quot;: &quot;yarn&quot;,
    &quot;factor&quot;: 8.0,                              # 4K → 32K
    &quot;original_max_position_embeddings&quot;: 4096,
}

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    torch_dtype=torch.bfloat16,
    device_map=&quot;auto&quot;,
)

# 现在可以处理 32K 序列
outputs = model.generate(inputs, max_new_tokens=256)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;手动实现 (核心部分):&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;def yarn_frequencies(dim, seq_len, base=10000, scale=8.0, L_train=4096):
    &quot;&quot;&quot;计算 YaRN 的 RoPE 频率&quot;&quot;&quot;
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    
    # 波长
    wavelengths = 2 * torch.pi / inv_freq
    
    # 缩放因子: 按波长分配
    ramp = torch.clamp(
        (wavelengths - L_train/2) / (L_train - L_train/2),
        min=0.0, max=1.0
    )
    r = 1 - ramp * (1 - 1/scale)
    
    return inv_freq / r  # 频率 = 1/scale → 缩放
&lt;/code&gt;&lt;/pre&gt;
&lt;hr /&gt;
&lt;h2&gt;7. 总结&lt;/h2&gt;
&lt;p&gt;回头看这个演进过程, 每一步都在解决上一步的问题:&lt;/p&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;步骤&lt;/th&gt;
&lt;th&gt;方法&lt;/th&gt;
&lt;th&gt;核心洞察&lt;/th&gt;
&lt;th&gt;解决的问题&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;1&lt;/td&gt;
&lt;td&gt;直接外推&lt;/td&gt;
&lt;td&gt;—&lt;/td&gt;
&lt;td&gt;—&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;2&lt;/td&gt;
&lt;td&gt;PI&lt;/td&gt;
&lt;td&gt;把长位置缩回训练范围&lt;/td&gt;
&lt;td&gt;解决了分布外问题&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;3&lt;/td&gt;
&lt;td&gt;NTK-aware&lt;/td&gt;
&lt;td&gt;高频分辨率和低频范围不同&lt;/td&gt;
&lt;td&gt;PI 的高频分辨率损失&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;4&lt;/td&gt;
&lt;td&gt;YaRN&lt;/td&gt;
&lt;td&gt;波长决定缩放策略 + 温度修正&lt;/td&gt;
&lt;td&gt;NTK 的粗粒度问题 + 分布偏移&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;&lt;strong&gt;核心思想&lt;/strong&gt;: 不要把 RoPE 的频率当成固定的, 而是根据目标任务(上下文长度)来调整. 调整的粒度越细(逐维度 vs 全局), 效果越好. 调整后别忘了修正 attention 的热度——因为频率变了, attention 的分布也会变.&lt;/p&gt;
&lt;hr /&gt;
&lt;h3&gt;参考资料&lt;/h3&gt;
&lt;ol&gt;
&lt;li&gt;Chen et al., Extending Context Window of Large Language Models via Positional Interpolation. 2023. &lt;a href=&quot;https://arxiv.org/abs/2306.15595&quot;&gt;arXiv:2306.15595&lt;/a&gt; — &lt;strong&gt;PI&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;Peng et al., YaRN: Efficient Context Window Extension of Large Language Models. 2023. &lt;a href=&quot;https://arxiv.org/abs/2309.00071&quot;&gt;arXiv:2309.00071&lt;/a&gt; — &lt;strong&gt;YaRN&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;NTK-aware RoPE scaling. Reddit r/LocalLLaMA by u/emozilla (Jeffrey Quesnelle). &lt;a href=&quot;https://www.reddit.com/r/LocalLLaMA/comments/14lz7j9/ntkaware_scaled_rope_allows_llama_models_to_have/&quot;&gt;Link&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;Su et al., RoFormer: Enhanced Transformer with Rotary Position Embedding. Neurocomputing 2022. &lt;a href=&quot;https://arxiv.org/abs/2104.09864&quot;&gt;arXiv:2104.09864&lt;/a&gt; — &lt;strong&gt;RoPE 原文&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;Bai et al., LongBench: A Bilingual, Multitask Benchmark for Long Context Understanding. 2023. &lt;a href=&quot;https://arxiv.org/abs/2308.14508&quot;&gt;arXiv:2308.14508&lt;/a&gt;&lt;/li&gt;
&lt;/ol&gt;
</content:encoded></item><item><title>Tokenization 完全指南 — 从字符到子词，模型到底是怎么&quot;看懂&quot;文字的</title><link>https://xuchenhui.cc/posts/2026-05-16-llm-tokenization-guide/</link><guid isPermaLink="true">https://xuchenhui.cc/posts/2026-05-16-llm-tokenization-guide/</guid><description>从&quot;模型只能读数字&quot;这个最朴素的问题出发，一步步推出字符级、词级、BPE、WordPiece、Unigram、SentencePiece 的原理和代码实现。</description><pubDate>Sat, 16 May 2026 00:00:00 GMT</pubDate><content:encoded>&lt;h2&gt;0. 前言&lt;/h2&gt;
&lt;p&gt;你有没有好奇过一个问题: &lt;strong&gt;语言模型读的是文字, 但它存的都是数字——中间这一步是怎么转过来的?&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;当你输入 &quot;ChatGPT 真厉害&quot; 时, 模型实际上看到的是这样的东西:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;[1212, 8463, 129, 93127, 4273]
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这一串数字是怎么来的? 这就是 &lt;strong&gt;Tokenizer&lt;/strong&gt; (分词器) 做的事.&lt;/p&gt;
&lt;p&gt;但 Tokenizer 远不止&quot;把文字转成数字&quot;这么简单. 选什么样的 Tokenizer, 直接决定了:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;模型能处理的词表有多大 (几万 vs 几十万)&lt;/li&gt;
&lt;li&gt;生僻词能不能表示 (会不会出现 OOV: Out-of-Vocabulary)&lt;/li&gt;
&lt;li&gt;不同语言的表现 (英语好不代表中文也好)&lt;/li&gt;
&lt;li&gt;序列有多长 (token 越多, 计算越贵)&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;这篇文章就来完整梳理 Tokenization 的技术演进: 从最朴素的字符级/词级方案, 到 BPE, 到 WordPiece, 再到 Unigram 和 SentencePiece.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;1. 先想清楚: 难点在哪&lt;/h2&gt;
&lt;p&gt;假设你有一句话: &lt;code&gt;I love machine learning&lt;/code&gt;.&lt;/p&gt;
&lt;p&gt;模型是数学机器, 它只能处理数字. 所以你需要一个映射函数 $f$:&lt;/p&gt;
&lt;p&gt;$$
f(\text{&quot;I love machine learning&quot;}) \to [x_1, x_2, ..., x_n]
$$&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;方案一: 字符级别 (Character-level)&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;把每个字符映射到一个数字:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;# &apos; &apos;→1, &apos;I&apos;→2, &apos;a&apos;→3, &apos;c&apos;→4, &apos;e&apos;→5, &apos;g&apos;→6, &apos;h&apos;→7, &apos;i&apos;→8, &apos;l&apos;→9, &apos;m&apos;→10, &apos;n&apos;→11, &apos;o&apos;→12, &apos;r&apos;→13, &apos;v&apos;→14
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;结果: &lt;code&gt;[2, 1, 9, 12, 14, 5, 1, 10, 3, 4, 7, 8, 11, 5, 1, 9, 5, 3, 13, 11, 8, 11, 6]&lt;/code&gt; — 23 个 token (和原始字符串的字符数一样).&lt;/p&gt;
&lt;p&gt;词表很小 (26 个字母 + 标点 ≈ 100 左右), 但序列太长, 而且&lt;strong&gt;单个字符没有语义&lt;/strong&gt;——模型很难从 &apos;m&apos; 这个字符学会 &quot;machine&quot; 是什么意思.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;方案二: 词级别 (Word-level)&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;把每个词映射到一个数字:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;&apos;I&apos; → 423, &apos;love&apos; → 156, &apos;machine&apos; → 8971, &apos;learning&apos; → 442
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;结果: &lt;code&gt;[423, 156, 8971, 442]&lt;/code&gt; — 4 个 token, 每个都有明确的语义.&lt;/p&gt;
&lt;p&gt;但问题来了: 词表要多大? 英语有几十万个词, 加上专有名词、拼写变体、新造词, 词表轻易突破百万. 而且遇到没见过的词 (OOV: Out-of-Vocabulary) 就只能给 &lt;code&gt;&amp;lt;UNK&amp;gt;&lt;/code&gt; (未知标记), 信息直接丢失.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;方案三: 子词级别 (Subword-level) — 黄金平衡点&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;子词级别做了个巧妙的平衡: 常见的词保持完整 (如 &quot;love&quot;), 不常见的词拆成更小的子词 (如 &quot;tokenization&quot; → &quot;token&quot; + &quot;ization&quot;).&lt;/p&gt;
&lt;p&gt;这就是当前所有主流模型使用的方案. 接下来详细介绍各个子词方案.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;2. BPE (Byte Pair Encoding) — 最常用的方案&lt;/h2&gt;
&lt;p&gt;BPE 最早是 1994 年提出的一种数据压缩算法, 被 Sennrich 等人 (2016) 引入 NLP. 现在 GPT 系列、LLaMA、Bloom 等都用它.&lt;/p&gt;
&lt;h3&gt;2.1 核心思路&lt;/h3&gt;
&lt;p&gt;BPE 的核心是: &lt;strong&gt;从最小的单位(字符)开始, 逐步合并出现频率最高的 Pair, 直到达到目标词表大小&lt;/strong&gt;.&lt;/p&gt;
&lt;p&gt;这不是直接给你一个词表, 而是一个&lt;strong&gt;合并规则列表&lt;/strong&gt;——你知道先合并什么、后合并什么, 就能把任意文本切分成 token.&lt;/p&gt;
&lt;h3&gt;2.2 一步一步的推导 (带例子)&lt;/h3&gt;
&lt;p&gt;假设我们有一个小语料库, 包含以下 5 个词 (每个重复多次):&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;low low low low low low low low low low   (10 个 &quot;low&quot;)
lowest lowest lowest lowest lowest        (5 个 &quot;lowest&quot;)
newer newer newer newer newer             (5 个 &quot;newer&quot;)
wider wider wider wider wider             (5 个 &quot;wider&quot;)
new new                                  (2 个 &quot;new&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;目标词表大小: 设为 18 个 token. (初始有 11 个独立字符, 每次合并新增 1 个 token, 所以需要做 18 − 11 = 7 次合并.)&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;第 1 步: 初始化&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;把所有词拆成字符, 加上 &lt;code&gt;&amp;lt;/w&amp;gt;&lt;/code&gt; (词尾标记, 表示一个词的结束):&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;l o w &amp;lt;/w&amp;gt;    (10 次)
l o w e s t &amp;lt;/w&amp;gt;  (5 次)
n e w e r &amp;lt;/w&amp;gt;   (5 次)
w i d e r &amp;lt;/w&amp;gt;   (5 次)
n e w &amp;lt;/w&amp;gt;       (2 次)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;初始 token 集合: &lt;code&gt;{l, o, w, e, s, t, n, r, i, d, &amp;lt;/w&amp;gt;}&lt;/code&gt; — 11 个.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;第 2 步: 统计 Pair 频率&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;统计所有相邻字符 Pair 的出现次数:&lt;/p&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Pair&lt;/th&gt;
&lt;th&gt;计数&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;(l, o)&lt;/td&gt;
&lt;td&gt;10+5 = 15&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;(o, w)&lt;/td&gt;
&lt;td&gt;10+5 = 15&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;(w, &amp;lt;/w&amp;gt;)&lt;/td&gt;
&lt;td&gt;10+2 = 12&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;(e, s)&lt;/td&gt;
&lt;td&gt;5&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;(s, t)&lt;/td&gt;
&lt;td&gt;5&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;(t, &amp;lt;/w&amp;gt;)&lt;/td&gt;
&lt;td&gt;5&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;(n, e)&lt;/td&gt;
&lt;td&gt;5+2 = 7&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;(w, e)&lt;/td&gt;
&lt;td&gt;5+5 = 10&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;(e, r)&lt;/td&gt;
&lt;td&gt;5+5 = 10&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;(r, &amp;lt;/w&amp;gt;)&lt;/td&gt;
&lt;td&gt;5+5 = 10&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;(w, i)&lt;/td&gt;
&lt;td&gt;5&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;(i, d)&lt;/td&gt;
&lt;td&gt;5&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;(d, e)&lt;/td&gt;
&lt;td&gt;5&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;(e, w)&lt;/td&gt;
&lt;td&gt;5+2 = 7&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;&lt;strong&gt;第 3 步: 合并频率最高的 Pair&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;最高频的 Pair 是 (l, o) 和 (o, w) 并列第一, 都是 15 次. 选 (l, o) 先合并.&lt;/p&gt;
&lt;p&gt;合并后, 所有 &quot;l o&quot; 变成 &quot;lo&quot;. 现在词汇变成了:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;lo w &amp;lt;/w&amp;gt;          (10 次)
lo w e s t &amp;lt;/w&amp;gt;   (5 次)
n e w e r &amp;lt;/w&amp;gt;    (5 次)
w i d e r &amp;lt;/w&amp;gt;    (5 次)
n e w &amp;lt;/w&amp;gt;        (2 次)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;加入新 token: &lt;code&gt;lo&lt;/code&gt;. 当前词表: 12 个.&lt;/p&gt;
&lt;p&gt;注意一个关键点: (o, w) 这个 Pair &lt;strong&gt;不再存在了&lt;/strong&gt;——因为原来的 &quot;l o&quot; 已经合并成了 &quot;lo&quot;, &quot;o&quot; 不再是独立 token, 所以 (o, w) 无法被合并.&lt;/p&gt;
&lt;p&gt;所以第二次合并时, 我们要重新统计:&lt;/p&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Pair&lt;/th&gt;
&lt;th&gt;计数&lt;/th&gt;
&lt;th&gt;来自&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;(lo, w)&lt;/td&gt;
&lt;td&gt;10+5 = 15&lt;/td&gt;
&lt;td&gt;low + lowest&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;(w, &amp;lt;/w&amp;gt;)&lt;/td&gt;
&lt;td&gt;10+2 = 12&lt;/td&gt;
&lt;td&gt;low + new&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;(w, e)&lt;/td&gt;
&lt;td&gt;5+5 = 10&lt;/td&gt;
&lt;td&gt;lowest + newer&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;(e, r)&lt;/td&gt;
&lt;td&gt;5+5 = 10&lt;/td&gt;
&lt;td&gt;newer + wider&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;(r, &amp;lt;/w&amp;gt;)&lt;/td&gt;
&lt;td&gt;5+5 = 10&lt;/td&gt;
&lt;td&gt;newer + wider&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;(n, e)&lt;/td&gt;
&lt;td&gt;5+2 = 7&lt;/td&gt;
&lt;td&gt;newer + new&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;(e, w)&lt;/td&gt;
&lt;td&gt;5+2 = 7&lt;/td&gt;
&lt;td&gt;newer + new&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;...&lt;/td&gt;
&lt;td&gt;...&lt;/td&gt;
&lt;td&gt;...&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;最高频的是 (lo, w) = 15 次. 合并它:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;low &amp;lt;/w&amp;gt;          (10 次)
low e s t &amp;lt;/w&amp;gt;   (5 次)   — &quot;lo w&quot; 变成了 &quot;low&quot;
n e w e r &amp;lt;/w&amp;gt;   (5 次)
w i d e r &amp;lt;/w&amp;gt;   (5 次)
n e w &amp;lt;/w&amp;gt;       (2 次)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;加入新 token: &lt;code&gt;low&lt;/code&gt;. 当前词表: 13 个.&lt;/p&gt;
&lt;p&gt;就这样反复合并, 直到词表达到目标大小. 每次合并后, 所有序列都会更新, 然后重新统计 Pair 频率.&lt;/p&gt;
&lt;p&gt;最终 BPE 学到的一系列合并规则类似:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;(l, o) → lo
(lo, w) → low
(low, &amp;lt;/w&amp;gt;) → low&amp;lt;/w&amp;gt;
(e, r) → er
(n, e) → ne
(ne, w) → new
...
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;有了这些规则, 任何新词都可以按同样的方式切分. 比如 &quot;newest&quot;:&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;拆成字符: &lt;code&gt;n e w e s t &amp;lt;/w&amp;gt;&lt;/code&gt;&lt;/li&gt;
&lt;li&gt;按规则合并: &lt;code&gt;ne w e s t &amp;lt;/w&amp;gt;&lt;/code&gt; → &lt;code&gt;new e s t &amp;lt;/w&amp;gt;&lt;/code&gt; (不能再合并了)&lt;/li&gt;
&lt;li&gt;最终: &lt;code&gt;[&apos;new&apos;, &apos;e&apos;, &apos;s&apos;, &apos;t&apos;, &apos;&amp;lt;/w&amp;gt;&apos;]&lt;/code&gt;&lt;/li&gt;
&lt;/ol&gt;
&lt;h3&gt;2.3 推理时如何分词&lt;/h3&gt;
&lt;p&gt;训练完后, BPE 不再统计频率, 而是直接按&lt;strong&gt;学到的合并规则&lt;/strong&gt;来切分.&lt;/p&gt;
&lt;p&gt;具体步骤:&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;把输入文本拆成基本单元 (字符或 byte, 取决于具体实现)&lt;/li&gt;
&lt;li&gt;从左到右扫描, 看能不能应用学到的合并规则&lt;/li&gt;
&lt;li&gt;尽可能合并, 直到不能再合并为止&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;对于 GPT-2 和 GPT-3 系列, 使用的是 &lt;strong&gt;Byte-level BPE&lt;/strong&gt;: 最小单位不是字符, 而是 &lt;strong&gt;byte&lt;/strong&gt; (字节). 所以&quot;拆成基本单元&quot;这一步是把每个 Unicode 字符先编码为 1-4 个 byte, 然后再基于 byte 序列做合并. 这样做的好处是: &lt;strong&gt;可以表示任意 Unicode 字符&lt;/strong&gt;——不管中文、日文、emoji, 都能拆成 byte 序列来处理, 不会出现 &lt;code&gt;&amp;lt;UNK&amp;gt;&lt;/code&gt;.&lt;/p&gt;
&lt;p&gt;GPT-2 的词表大小是 50,257. LLaMA 系列的词表是 32,000 (LLaMA-1) 或 128,000 (LLaMA-3).&lt;/p&gt;
&lt;h3&gt;2.4 BPE 的优缺点&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;优点:&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;简单直观, 实现容易&lt;/li&gt;
&lt;li&gt;对常见词保持完整, 对罕见词拆分成子词&lt;/li&gt;
&lt;li&gt;Byte-level 变体可以处理任意文本&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;缺点:&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;基于频率而不是基于&lt;strong&gt;语义相关性&lt;/strong&gt;, 可能会把语义相关的词拆开 (如 &quot;unbelievable&quot; → &quot;un&quot; + &quot;believ&quot; + &quot;able&quot;, 可能不如 &quot;un&quot; + &quot;believe&quot; + &quot;able&quot; 好)&lt;/li&gt;
&lt;li&gt;频率统计完全由语料决定, 语料偏斜会导致 token 分布不合理&lt;/li&gt;
&lt;/ul&gt;
&lt;hr /&gt;
&lt;h2&gt;3. WordPiece — BERT 的选择&lt;/h2&gt;
&lt;p&gt;WordPiece 是 Google 为 BERT 开发的 tokenizer, 和 BPE 非常相似, 但合并标准不同.&lt;/p&gt;
&lt;h3&gt;3.1 和 BPE 的区别&lt;/h3&gt;
&lt;p&gt;BPE 合并的是&lt;strong&gt;频率最高&lt;/strong&gt;的 token pair.&lt;/p&gt;
&lt;p&gt;WordPiece 合并的是&lt;strong&gt;让语料库似然 (likelihood) 提升最大&lt;/strong&gt;的 token pair.&lt;/p&gt;
&lt;p&gt;具体来说, WordPiece 计算每个候选 pair 的 score:&lt;/p&gt;
&lt;p&gt;$$
\text{score}(x, y) = \frac{\text{freq}(xy)}{\text{freq}(x) \times \text{freq}(y)}
$$&lt;/p&gt;
&lt;p&gt;其中 $\text{freq}(xy)$ 是合并后的 token 在语料中的频率, $\text{freq}(x)$ 和 $\text{freq}(y)$ 是两个原始 token 的频率.&lt;/p&gt;
&lt;p&gt;为什么要用这个公式? 直观理解:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;如果 $x$ 和 $y$ 经常一起出现 ($\text{freq}(xy)$ 大), 但各自出现得少 ($\text{freq}(x)$ 和 $\text{freq}(y)$ 小), 说明它们的组合特别有意义, score 高&lt;/li&gt;
&lt;li&gt;如果 $x$ 和 $y$ 各自出现得很多, 偶尔碰到一起, 那它们的组合可能只是巧合, score 低&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;这比 BPE 的纯频率更&quot;智能&quot;——它倾向于合并那些&lt;strong&gt;共现特别紧密&lt;/strong&gt;的 pair, 而不是单纯出现多的 pair.&lt;/p&gt;
&lt;h3&gt;3.2 例子&lt;/h3&gt;
&lt;p&gt;继续用上面的语料. 假设目前 token 集包含 &lt;code&gt;l&lt;/code&gt;, &lt;code&gt;o&lt;/code&gt;, &lt;code&gt;w&lt;/code&gt;, &lt;code&gt;e&lt;/code&gt;, &lt;code&gt;s&lt;/code&gt; 等字符, 我们需要决定先合并哪一对.&lt;/p&gt;
&lt;p&gt;计算 $\text{score}(\text{l}, \text{o})$:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;$\text{freq}(\text{lo})$: &quot;l&quot; 和 &quot;o&quot; 相邻出现 15 次 (low ×10, lowest ×5)&lt;/li&gt;
&lt;li&gt;$\text{freq}(\text{l})$: &quot;l&quot; 本身出现 15 次 (都在词首)&lt;/li&gt;
&lt;li&gt;$\text{freq}(\text{o})$: &quot;o&quot; 出现 15 次&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;$$\text{score}(\text{l}, \text{o}) = \frac{15}{15 \times 15} = \frac{1}{15} \approx 0.067$$&lt;/p&gt;
&lt;p&gt;再看 $(\text{o}, \text{w})$:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;$\text{freq}(\text{ow})$: &quot;o&quot; 和 &quot;w&quot; 相邻出现 15 次 (low ×10, lowest ×5)&lt;/li&gt;
&lt;li&gt;$\text{freq}(\text{o})$: 15&lt;/li&gt;
&lt;li&gt;$\text{freq}(\text{w})$: &lt;strong&gt;27&lt;/strong&gt; — 注意 &quot;w&quot; 不仅出现在 low (×10)、lowest (×5), 还出现在 newer (×5)、wider (×5)、new (×2) 中. 所以是 $10+5+5+5+2 = 27$.&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;$$\text{score}(\text{o}, \text{w}) = \frac{15}{15 \times 27} \approx 0.037$$&lt;/p&gt;
&lt;p&gt;这里可以发现区别:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;(l, o) 的 score 0.067 &amp;gt; (o, w) 的 0.037, 因为 &lt;code&gt;l&lt;/code&gt; 几乎只出现在 &lt;code&gt;o&lt;/code&gt; 前面——$\text{freq}(\text{l})$ 和 $\text{freq}(\text{(l, o)})$ 几乎相等, 说明 &quot;l&quot; 和 &quot;o&quot; 是近乎绑定在一起的.&lt;/li&gt;
&lt;li&gt;而 &lt;code&gt;w&lt;/code&gt; 还可以接 &lt;code&gt;e&lt;/code&gt; (newer)、&lt;code&gt;i&lt;/code&gt; (wider)、&lt;code&gt;&amp;lt;/w&amp;gt;&lt;/code&gt; (low, new), 所以 &lt;code&gt;o&lt;/code&gt; + &lt;code&gt;w&lt;/code&gt; 不是唯一的组合——score 自然更低.&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;这就是 WordPiece 的本质&lt;/strong&gt;——它衡量的是 $x$ 和 $y$ 的&lt;strong&gt;共现强度&lt;/strong&gt;, 而不是共现频率. 公式 $\text{score}(x, y) = \frac{\text{freq}(xy)}{\text{freq}(x) \cdot \text{freq}(y)}$ 实际上就是 &lt;strong&gt;PMI (Pointwise Mutual Information)&lt;/strong&gt; 的变形:&lt;/p&gt;
&lt;p&gt;$$\text{PMI}(x, y) = \log \frac{P(x, y)}{P(x) P(y)} \propto \log \frac{\text{freq}(xy)}{\text{freq}(x) \cdot \text{freq}(y)}$$&lt;/p&gt;
&lt;p&gt;WordPiece 不取 log (为了保留排序等价性), 但思想完全一样: &lt;strong&gt;如果两个 token 同时出现远高于随机期望, 就应该合并它们&lt;/strong&gt;.&lt;/p&gt;
&lt;h3&gt;3.3 实际使用&lt;/h3&gt;
&lt;p&gt;WordPiece 在 BERT 和 DistilBERT 中使用 (RoBERTa 用的是 Byte-level BPE, 与 GPT-2 相同). 词表大小通常是 30,000 左右. 和 BPE 一样, 训练完成后就是一个确定的合并规则集.&lt;/p&gt;
&lt;p&gt;BERT 的 WordPiece 会在 token 前面加 &lt;code&gt;##&lt;/code&gt; 表示这不是词的开头 (continuation). 比如:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;&quot;unbelievable&quot; → [&quot;un&quot;, &quot;##believe&quot;, &quot;##able&quot;]
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这种标记方式让模型能区分&quot;词首&quot;和&quot;词中&quot;的子词——同一个子词在不同位置可能表示不同含义.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;4. Unigram — 由概率说话&lt;/h2&gt;
&lt;p&gt;Unigram 是另一种子词分词方案, 和 BPE/WordPiece 的思路完全相反.&lt;/p&gt;
&lt;h3&gt;4.1 BPE/WordPiece 是&quot;自底向上&quot;, Unigram 是&quot;自顶向下&quot;&lt;/h3&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;BPE/WordPiece&lt;/strong&gt;: 从字符开始, 逐步合并 → 自底向上&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Unigram&lt;/strong&gt;: 从一个很大的候选词表开始, 逐步删除&quot;不重要&quot;的 token → 自顶向下&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;Unigram 的具体做法:&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;先准备一个很大的候选词表 (比如从语料中提取所有可能的子词序列)&lt;/li&gt;
&lt;li&gt;对每个 token, 计算删除它后对语料似然的影响&lt;/li&gt;
&lt;li&gt;删除影响最小的那些 token&lt;/li&gt;
&lt;li&gt;重复直到达到目标词表大小&lt;/li&gt;
&lt;/ol&gt;
&lt;h3&gt;4.2 损失函数&lt;/h3&gt;
&lt;p&gt;Unigram 的词表学习和分词不是同时做的, 流程是: 先用 EM 算法迭代估计 token 概率 $P(t)$, 然后用 Viterbi 找最优切分.&lt;/p&gt;
&lt;p&gt;给定词表 $V$ 和 token 概率 $P(t)$, 一个句子 $X$ 的最优分词由 Viterbi 算法找到:&lt;/p&gt;
&lt;p&gt;$$
\operatorname*{argmax}&lt;em&gt;{t_1, \ldots, t_k \in V} \sum&lt;/em&gt;{i=1}^{k} \log P(t_i)
$$&lt;/p&gt;
&lt;p&gt;即: 穷举所有可能的分词方式, 选出使概率和最大的那一种. Viterbi 用动态规划做这个, 实际实现中用 trie 树限制候选数, 复杂度约 $O(L \cdot M)$, 其中 $M$ 是最大 token 长度 (通常 ≤ 50).&lt;/p&gt;
&lt;p&gt;训练过程的核心是 &lt;strong&gt;EM 算法&lt;/strong&gt; (标准做法是 soft EM, 即用 forward-backward 算法计算每个 token 在各分词路径上的期望出现次数; 以下描述 hard EM 变体):&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;E 步&lt;/strong&gt;: 对每个句子, 用 Viterbi 找到当前最优分词, 统计每个 token 被选中的次数&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;M 步&lt;/strong&gt;: 根据频率更新 $P(t) = \frac{\text{count}(t)}{\sum_{v \in V} \text{count}(v)}$&lt;/li&gt;
&lt;li&gt;重复直到收敛&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;收敛后, 对每个 token 评估&lt;strong&gt;删除它带来的似然损失&lt;/strong&gt;, 删除损失最小的 token, 然后重新训练. 重复直到词表降到目标大小.&lt;/p&gt;
&lt;h3&gt;4.3 和 BPE/WordPiece 的对比&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;Unigram 的优势:&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;能显式控制词表大小 — 不像 BPE 需要逐步合并到目标大小&lt;/li&gt;
&lt;li&gt;学习到的 token 更&quot;合理&quot; — 概率框架比纯频率统计更鲁棒&lt;/li&gt;
&lt;li&gt;可以输出多个候选分词路径 (不只是最优的那一种)&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;劣势:&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;需要 Viterbi 解码, 比 BPE 的贪心匹配慢&lt;/li&gt;
&lt;li&gt;训练过程更复杂&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;实际使用:&lt;/strong&gt; Unigram 是 SentencePiece 的默认算法, 也用在 T5、ALBERT 等模型中.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;5. SentencePiece — 把一切统一起来&lt;/h2&gt;
&lt;p&gt;前面讲的 BPE、WordPiece、Unigram, 都有一个共同的前提: &lt;strong&gt;输入文本需要先按空格分成&quot;词&quot;&lt;/strong&gt; (pre-tokenization). 这对英语没问题, 但对中文、日文等没有空格的文字来说就尴尬了——&quot;我喜欢你&quot; 应该被切分成什么?&lt;/p&gt;
&lt;p&gt;SentencePiece 解决了这个问题: &lt;strong&gt;它把原始文本直接当作 Unicode 字符序列处理, 不需要 pre-tokenization&lt;/strong&gt;.&lt;/p&gt;
&lt;h3&gt;5.1 SentencePiece 的核心思想&lt;/h3&gt;
&lt;p&gt;SentencePiece 的输入不是&quot;词&quot;, 而是&lt;strong&gt;原始字符串&lt;/strong&gt; (Unicode 字符序列). 它把整个文本看作一个字符流, 然后用 BPE 或 Unigram 算法来训练.&lt;/p&gt;
&lt;p&gt;关键特性:&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;空格也是字符&lt;/strong&gt;: 英语中的空格被当作普通字符, 用 &lt;code&gt;▁&lt;/code&gt; (下划线符号) 表示. 分词结果类似于 &lt;code&gt;▁I▁love▁machine▁learning&lt;/code&gt;.&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;不需要语言特定预处理&lt;/strong&gt;: 中文、日文、韩文可以直接处理, 不需要先分词&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;支持 BPE 和 Unigram 两种算法&lt;/strong&gt;&lt;/li&gt;
&lt;/ol&gt;
&lt;h3&gt;5.2 对比: SentencePiece BPE vs SentencePiece Unigram&lt;/h3&gt;
&lt;p&gt;LLaMA、Mistral 使用的是 &lt;strong&gt;SentencePiece + BPE&lt;/strong&gt; (字符级别, 不是 byte-level). LLaMA-3 改用了 tiktoken 库的 byte-level BPE, 不再使用 SentencePiece.&lt;/p&gt;
&lt;p&gt;T5、ALBERT 使用的是 &lt;strong&gt;SentencePiece + Unigram&lt;/strong&gt;.&lt;/p&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;模型&lt;/th&gt;
&lt;th&gt;Tokenizer&lt;/th&gt;
&lt;th&gt;词表大小&lt;/th&gt;
&lt;th&gt;说明&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;GPT-2/3&lt;/td&gt;
&lt;td&gt;Byte-level BPE&lt;/td&gt;
&lt;td&gt;50,257&lt;/td&gt;
&lt;td&gt;可以表示任意 byte 序列&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;BERT&lt;/td&gt;
&lt;td&gt;WordPiece&lt;/td&gt;
&lt;td&gt;30,000&lt;/td&gt;
&lt;td&gt;用 ## 标记续词&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;LLaMA&lt;/td&gt;
&lt;td&gt;SentencePiece + BPE&lt;/td&gt;
&lt;td&gt;32,000&lt;/td&gt;
&lt;td&gt;字符级别, 适合中英文混合&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;LLaMA-3&lt;/td&gt;
&lt;td&gt;tiktoken BPE (byte-level)&lt;/td&gt;
&lt;td&gt;128,000&lt;/td&gt;
&lt;td&gt;大词表, 编码效率高&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;T5&lt;/td&gt;
&lt;td&gt;SentencePiece + Unigram&lt;/td&gt;
&lt;td&gt;32,000&lt;/td&gt;
&lt;td&gt;可输出多条分词路径&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Qwen&lt;/td&gt;
&lt;td&gt;tiktoken BPE (byte-level)&lt;/td&gt;
&lt;td&gt;152,000&lt;/td&gt;
&lt;td&gt;中文优化, 同 GPT-4 架构&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;GPT-4&lt;/td&gt;
&lt;td&gt;未知 (推测是 tiktoken BPE)&lt;/td&gt;
&lt;td&gt;≈100K&lt;/td&gt;
&lt;td&gt;未公开&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;hr /&gt;
&lt;h2&gt;6. 代码: 手把手看 tokenizer 怎么工作&lt;/h2&gt;
&lt;h3&gt;6.1 用 HuggingFace &lt;code&gt;tokenizers&lt;/code&gt; 训练一个 BPE&lt;/h3&gt;
&lt;pre&gt;&lt;code&gt;from tokenizers import Tokenizer, trainers, models

# 创建一个 BPE tokenizer
tokenizer = Tokenizer(models.BPE())

# 准备训练器
trainer = trainers.BpeTrainer(
    vocab_size=5000,
    special_tokens=[&quot;&amp;lt;unk&amp;gt;&quot;, &quot;&amp;lt;s&amp;gt;&quot;, &quot;&amp;lt;/s&amp;gt;&quot;],
    min_frequency=2,
    show_progress=True
)

# 准备语料
files = [&quot;corpus.txt&quot;]  # 你的文本文件

# 训练
tokenizer.train(files, trainer)

# 测试
output = tokenizer.encode(&quot;I love machine learning&quot;)
print(output.tokens)   # 类似 [&apos;I&apos;, &apos;Ġlove&apos;, &apos;Ġmachine&apos;, &apos;Ġlearning&apos;] (不同 tokenizer 输出不同)
print(output.ids)      # [32, 156, 4231, 2156]
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;6.2 查看 LLaMA 的 tokenizer&lt;/h3&gt;
&lt;pre&gt;&lt;code&gt;from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(&quot;meta-llama/Llama-2-7b-hf&quot;)

# 英文
text = &quot;Tokenization is the first step&quot;
tokens = tokenizer.tokenize(text)
ids = tokenizer.encode(text)
print(tokens)
# [&apos;▁Token&apos;, &apos;ization&apos;, &apos;▁is&apos;, &apos;▁the&apos;, &apos;▁first&apos;, &apos;▁step&apos;]
print(len(ids))  # 6 个 token

# 中文
text_cn = &quot;大语言模型分词器&quot;
tokens = tokenizer.tokenize(text_cn)
ids = tokenizer.encode(text_cn)
print(tokens)
# [&apos;▁大&apos;, &apos;语言&apos;, &apos;模型&apos;, &apos;分词&apos;, &apos;器&apos;]
print(len(ids))  # 5 个 token

# 发现 &quot;语言&quot; 是一个完整的 token!
# 这是因为 LLaMA 的训练语料也包含中文, &quot;语言&quot; 常见, 被合并了
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;6.3 token 长度分布 (示意)&lt;/h3&gt;
&lt;p&gt;不同语言的 token 效率不一样 (以下输出为近似值, 使用 LLaMA-2 tokenizer):&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;texts = [
    (&quot;English&quot;, &quot;The quick brown fox jumps over the lazy dog&quot;),
    (&quot;中文&quot;, &quot;大语言模型分词器的工作原理是什么&quot;),
    (&quot;混合&quot;, &quot;BERT 的 WordPiece 和 LLaMA 的 BPE 有什么区别&quot;),
]

for lang, text in texts:
    ids = tokenizer.encode(text)
    chars_per_token = len(text) / len(ids)
    print(f&quot;{lang}: {len(ids)} tokens, {chars_per_token:.1f} chars/token&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出类似:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;English: 10 tokens, 4.3 chars/token
中文: 10 tokens, 1.6 chars/token
混合: 15 tokens, 2.4 chars/token
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;中文的 chars/token 更低, 意味着&lt;strong&gt;同样的信息量, 中文需要的 token 更多&lt;/strong&gt;. 这就解释了为什么中文模型的上下文窗口&quot;不够用&quot;——同样的 4096 token, 英文能读约 17600 个字符, 中文只能读约 6600 个字符.&lt;/p&gt;
&lt;h3&gt;6.4 自己实现一个简化版 BPE&lt;/h3&gt;
&lt;pre&gt;&lt;code&gt;from collections import defaultdict

def train_bpe(corpus: list[str], vocab_size: int):
    &quot;&quot;&quot;训练一个简化版 BPE&quot;&quot;&quot;

    # 1. 初始化: 把词拆成字符 + 词尾标记
    words = []
    for text in corpus:
        for word in text.split():
            words.append(&quot; &quot;.join(list(word)) + &quot; &amp;lt;/w&amp;gt;&quot;)

    # 2. 初始词表 (所有字符)
    vocab = set()
    for word in words:
        for char in word.split():
            vocab.add(char)

    # 3. 重复合并
    merges = []
    while len(vocab) &amp;lt; vocab_size:
        # 统计所有 pair 频率
        pairs = defaultdict(int)
        for word in words:
            symbols = word.split()
            for i in range(len(symbols) - 1):
                pairs[(symbols[i], symbols[i+1])] += 1
        
        if not pairs:
            break
        
        # 找频率最高的 pair
        best_pair = max(pairs, key=pairs.get)
        merges.append(best_pair)
        
        # 合并
        new_words = []
        for word in words:
            new_word = word.replace(
                f&quot;{best_pair[0]} {best_pair[1]}&quot;,
                f&quot;{best_pair[0]}{best_pair[1]}&quot;
            )
            new_words.append(new_word)
        words = new_words
        
        # 加入词表
        vocab.add(f&quot;{best_pair[0]}{best_pair[1]}&quot;)
    
    return merges

def apply_bpe(text: str, merges: list):
    &quot;&quot;&quot;用学到的合并规则分词&quot;&quot;&quot;
    words = text.split()
    result = []
    for word in words:
        tokens = list(word) + [&quot;&amp;lt;/w&amp;gt;&quot;]
        # 注意: 简化实现, 实际 BPE 需要更复杂的合并逻辑
        for merge in merges:
            i = 0
            while i &amp;lt; len(tokens) - 1:
                if tokens[i] == merge[0] and tokens[i+1] == merge[1]:
                    tokens = tokens[:i] + [f&quot;{merge[0]}{merge[1]}&quot;] + tokens[i+2:]
                else:
                    i += 1
        result.extend(tokens)
    return result

# 测试
merges = train_bpe([&quot;low low low lowest newer newer wider&quot;], vocab_size=20)
print(apply_bpe(&quot;lowest&quot;, merges))
&lt;/code&gt;&lt;/pre&gt;
&lt;blockquote&gt;
&lt;p&gt;注意: 这是简化版, 实际 BPE 实现需要考虑词频加权、byte-level 编码等.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;hr /&gt;
&lt;h2&gt;7. Tokenizer 选型指南&lt;/h2&gt;
&lt;h3&gt;7.1 不同模型的 Tokenizer&lt;/h3&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;模型&lt;/th&gt;
&lt;th&gt;Tokenizer 类型&lt;/th&gt;
&lt;th&gt;词表大小&lt;/th&gt;
&lt;th&gt;是否支持中文&lt;/th&gt;
&lt;th&gt;特殊标记&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;GPT-2&lt;/td&gt;
&lt;td&gt;Byte-level BPE&lt;/td&gt;
&lt;td&gt;50,257&lt;/td&gt;
&lt;td&gt;✅ (通过 byte)&lt;/td&gt;
&lt;td&gt;无&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;BERT&lt;/td&gt;
&lt;td&gt;WordPiece&lt;/td&gt;
&lt;td&gt;30,000&lt;/td&gt;
&lt;td&gt;❌ (需额外中文版)&lt;/td&gt;
&lt;td&gt;##&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;LLaMA-2&lt;/td&gt;
&lt;td&gt;SentencePiece + BPE&lt;/td&gt;
&lt;td&gt;32,000&lt;/td&gt;
&lt;td&gt;✅&lt;/td&gt;
&lt;td&gt;▁&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;LLaMA-3&lt;/td&gt;
&lt;td&gt;tiktoken BPE (byte-level)&lt;/td&gt;
&lt;td&gt;128,000&lt;/td&gt;
&lt;td&gt;✅&lt;/td&gt;
&lt;td&gt;无 (byte-level)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;T5&lt;/td&gt;
&lt;td&gt;SentencePiece + Unigram&lt;/td&gt;
&lt;td&gt;32,000&lt;/td&gt;
&lt;td&gt;✅&lt;/td&gt;
&lt;td&gt;▁&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;ChatGLM&lt;/td&gt;
&lt;td&gt;SentencePiece&lt;/td&gt;
&lt;td&gt;130,000&lt;/td&gt;
&lt;td&gt;✅ (中文优化)&lt;/td&gt;
&lt;td&gt;▁&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Qwen&lt;/td&gt;
&lt;td&gt;tiktoken BPE (byte-level)&lt;/td&gt;
&lt;td&gt;152,000&lt;/td&gt;
&lt;td&gt;✅ (中文优化)&lt;/td&gt;
&lt;td&gt;无 (byte-level)&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;h3&gt;7.2 怎么选?&lt;/h3&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;做英文任务&lt;/strong&gt;: Byte-level BPE 或 WordPiece 都行, 词表 30K-50K&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;做多语言任务&lt;/strong&gt;: SentencePiece + BPE/Unigram, 词表 100K 以上&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;做中文任务&lt;/strong&gt;: 选中文优化的 tokenizer (LLaMA-3, Qwen, ChatGLM), 或者自己训练一个&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;追求效率&lt;/strong&gt;: 大词表 → 每个 token 信息量更大 → 序列更短 → 计算更快 (但 embedding 层更大)&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;追求覆盖&lt;/strong&gt;: 小词表 → 基本不会遇到 &lt;code&gt;&amp;lt;UNK&amp;gt;&lt;/code&gt; → 但序列更长&lt;/li&gt;
&lt;/ul&gt;
&lt;h3&gt;7.3 一个重要的权衡&lt;/h3&gt;
&lt;p&gt;词表大小和序列长度是 trade-off:&lt;/p&gt;
&lt;p&gt;计算量
$$\propto \underbrace{L^2 d}&lt;em&gt;{\text{Self-Attention}} + \underbrace{V d}&lt;/em&gt;{\text{Embedding}}$$&lt;/p&gt;
&lt;p&gt;其中 $L$ 是序列长度, $d$ 是隐藏层维度, $V$ 是词表大小. (为简洁, 忽略了 FFN 层的 $L d^2$ 项, 但 trade-off 的结论不变.)&lt;/p&gt;
&lt;p&gt;词表越大, 序列越短 (因为每个 token 包含的信息更多), Self-Attention 的 $L^2$ 项显著降低; 但 Embedding 层 $V d$ 变大.&lt;/p&gt;
&lt;p&gt;LLaMA-3 选择了 128K 的大词表 (相比 LLaMA-2 的 32K), 序列更短但嵌入层更重. 实测表明好处大于坏处.&lt;/p&gt;
&lt;hr /&gt;
&lt;h2&gt;8. 总结&lt;/h2&gt;
&lt;p&gt;Tokenization 看起来是&quot;把文字转数字&quot;的小事, 但它的设计直接影响模型的能力和效率:&lt;/p&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;方法&lt;/th&gt;
&lt;th&gt;核心思想&lt;/th&gt;
&lt;th&gt;典型应用&lt;/th&gt;
&lt;th&gt;合并标准&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;字符级&lt;/td&gt;
&lt;td&gt;每个字符映射一个数字&lt;/td&gt;
&lt;td&gt;理论上可行&lt;/td&gt;
&lt;td&gt;—&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;词级&lt;/td&gt;
&lt;td&gt;每个词一个 token&lt;/td&gt;
&lt;td&gt;N-gram 时代&lt;/td&gt;
&lt;td&gt;—&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;BPE&lt;/td&gt;
&lt;td&gt;从字符开始, 逐对合并高频 pair&lt;/td&gt;
&lt;td&gt;GPT, LLaMA&lt;/td&gt;
&lt;td&gt;纯粹频率&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;WordPiece&lt;/td&gt;
&lt;td&gt;类似 BPE, 但合并标准用似然提升&lt;/td&gt;
&lt;td&gt;BERT&lt;/td&gt;
&lt;td&gt;似然提升&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Unigram&lt;/td&gt;
&lt;td&gt;从大词表开始, 逐步删除不重要 token&lt;/td&gt;
&lt;td&gt;T5, ALBERT&lt;/td&gt;
&lt;td&gt;似然损失&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;SentencePiece&lt;/td&gt;
&lt;td&gt;统一框架, 不需要 pre-tokenization&lt;/td&gt;
&lt;td&gt;LLaMA, T5&lt;/td&gt;
&lt;td&gt;可配 BPE/Unigram&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;最核心的一点: &lt;strong&gt;Tokenizer 的好坏, 决定了模型&quot;看到&quot;的是什么&lt;/strong&gt;. 一个不好的 Tokenizer 会把语义相关的词拆得七零八落, 或者把不相关的 byte 强行拼在一起——模型学到的是&quot;错位&quot;的语言知识.&lt;/p&gt;
&lt;p&gt;下次你再看到 &quot;LLaMA-3 词表 128K&quot; 这种数字时, 就知道它意味着什么了.&lt;/p&gt;
&lt;hr /&gt;
&lt;h3&gt;参考资料&lt;/h3&gt;
&lt;ol&gt;
&lt;li&gt;Sennrich et al., Neural Machine Translation of Rare Words with Subword Units. ACL 2016. &lt;a href=&quot;https://arxiv.org/abs/1508.07909&quot;&gt;arXiv:1508.07909&lt;/a&gt; — &lt;strong&gt;提出 BPE 用于 NLP&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;Wu et al., Google&apos;s Neural Machine Translation System: Bridging the Gap between Human and Machine Translation. 2016. &lt;a href=&quot;https://arxiv.org/abs/1609.08144&quot;&gt;arXiv:1609.08144&lt;/a&gt; — &lt;strong&gt;WordPiece 用于 NMT&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;Kudo, Subword Regularization: Improving Neural Network Translation Models with Multiple Subword Candidates. ACL 2018. &lt;a href=&quot;https://arxiv.org/abs/1804.10959&quot;&gt;arXiv:1804.10959&lt;/a&gt; — &lt;strong&gt;Unigram&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;Kudo &amp;amp; Richardson, SentencePiece: A simple and language independent subword tokenizer and detokenizer for Neural Text Processing. EMNLP 2018. &lt;a href=&quot;https://arxiv.org/abs/1808.06226&quot;&gt;arXiv:1808.06226&lt;/a&gt; — &lt;strong&gt;SentencePiece&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;Radford et al., Language Models are Unsupervised Multitask Learners. 2019. (GPT-2) — &lt;strong&gt;Byte-level BPE&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;HuggingFace Tokenizers 文档: &lt;a href=&quot;https://huggingface.co/docs/tokenizers&quot;&gt;https://huggingface.co/docs/tokenizers&lt;/a&gt;&lt;/li&gt;
&lt;/ol&gt;
</content:encoded></item><item><title>Building an Autonomous AI Agent on WSL</title><link>https://xuchenhui.cc/posts/2026-05-15-hermes-agent-wsl-setup/</link><guid isPermaLink="true">https://xuchenhui.cc/posts/2026-05-15-hermes-agent-wsl-setup/</guid><description>记录在 WSL 上部署 Hermes Agent 的全过程——包括微信网关配置、Windows 保活机制、Camoufox 反爬浏览器的集成与使用。</description><pubDate>Fri, 15 May 2026 00:00:00 GMT</pubDate><content:encoded>&lt;h2&gt;0. 前言&lt;/h2&gt;
&lt;p&gt;最近在捣鼓 &lt;strong&gt;AI Agent&lt;/strong&gt; 方向的内容, 接触到 &lt;a href=&quot;https://github.com/NousResearch/hermes-agent&quot;&gt;Hermes Agent&lt;/a&gt; —— 一个由 Nous Research 开源的 AI Agent 框架. 它和 Claude Code、OpenAI Codex 属于同一个赛道, 但有一个核心差异点让我很感兴趣: &lt;strong&gt;多平台网关&lt;/strong&gt;.&lt;/p&gt;
&lt;p&gt;也就是说, 我可以通过微信直接给它发消息, 它就帮我去执行命令、操作浏览器、管理文件. 听起来很酷, 对吧?&lt;/p&gt;
&lt;p&gt;但实际搞起来, 坑还真不少. 本篇博客记录了从零开始在 &lt;strong&gt;WSL&lt;/strong&gt; (Windows Subsystem for Linux) 上部署 Hermes Agent 的完整过程, 包括几个核心环节:&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;Hermes 的安装与基础配置&lt;/li&gt;
&lt;li&gt;微信网关的对接&lt;/li&gt;
&lt;li&gt;Windows 保活机制(解决 WSL 休眠/重启后服务中断的问题)&lt;/li&gt;
&lt;li&gt;Camoufox 反爬浏览器的集成(绕过 Google、百度等网站的反爬检测)&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;希望对你有帮助.&lt;/p&gt;
&lt;p&gt;:::note
阅读前, 需要你 : 熟悉 Linux 命令行, 了解 WSL 的基本使用, 对 AI Agent 概念有基本认识.
:::&lt;/p&gt;
&lt;h2&gt;1. Hermes Agent 是什么&lt;/h2&gt;
&lt;p&gt;简单来说, Hermes Agent 就是&lt;strong&gt;一个可以跑在终端和聊天软件里的 AI 助手&lt;/strong&gt;.&lt;/p&gt;
&lt;p&gt;它不是又一个聊天机器人 —— 它能真正帮你&lt;strong&gt;干活&lt;/strong&gt;. 你可以让它帮你查资料、写代码、操作浏览器、管理文件, 甚至写博客(这篇就是). 支持的工具调用挺丰富的:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;20+ 模型供应商&lt;/strong&gt;: OpenRouter、Anthropic、DeepSeek……想用哪个用哪个&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;10+ 平台网关&lt;/strong&gt;: 微信、Telegram、Discord、Slack 都能连&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;跨会话记忆&lt;/strong&gt;: 它会记住你的偏好, 不用每次都重新说&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;技能系统&lt;/strong&gt;: 复杂的操作流程可以自动保存下来, 下次直接复用&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;架构上其实就两层, 不复杂:&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;Gateway (网关)&lt;/strong&gt;: 常驻后台的进程, 负责接收和转发消息. 你可以理解为 7x24 小时值班的&quot;接线员&quot;&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Agent (代理)&lt;/strong&gt;: 每次对话创建一个新的 Agent 实例, 执行具体的任务. 相当于&quot;干活的人&quot;&lt;/li&gt;
&lt;/ol&gt;
&lt;h2&gt;2. 在 WSL 上安装&lt;/h2&gt;
&lt;p&gt;我的环境是 &lt;strong&gt;Windows 11 + WSL2 (Ubuntu)&lt;/strong&gt;. 安装非常简单, 一行命令搞定:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;安装完之后, 需要配置模型供应商. 我这边用的是 OpenRouter 做中转, 因为它的模型选择多, 而且不用绑定某一家:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;hermes setup
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;或者直接编辑配置文件 &lt;code&gt;~/.hermes/config.yaml&lt;/code&gt;:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;model:
  default: deepseek-v4-flash
  provider: custom
  base_url: https://your-api-endpoint/v1
  api_key: your_api_key
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这里有几个我觉得值得调的配置项, 分享一下我的取值:&lt;/p&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;配置&lt;/th&gt;
&lt;th&gt;说明&lt;/th&gt;
&lt;th&gt;我的值&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;agent.max_turns&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;每次对话最大工具调用轮次, 默认 90&lt;/td&gt;
&lt;td&gt;100&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;terminal.timeout&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;单条命令最大执行时间, 默认 180s&lt;/td&gt;
&lt;td&gt;600s (10分钟)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;display.language&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;界面语言&lt;/td&gt;
&lt;td&gt;zh (中文)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;memory.memory_enabled&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;跨会话记忆&lt;/td&gt;
&lt;td&gt;true&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;h2&gt;3. 微信网关配置&lt;/h2&gt;
&lt;p&gt;这一步的体验还挺神奇的 —— 配好之后, 你给 Hermes 发微信就像给朋友发消息一样自然.&lt;/p&gt;
&lt;p&gt;Hermes 用的是 &lt;strong&gt;iLink Bot&lt;/strong&gt; 协议, 通过扫码登录微信:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;hermes gateway setup
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;选择 WeChat 平台, 终端里会出现一个二维码, 掏出手机扫一下就行.&lt;/p&gt;
&lt;p&gt;关键的配置项在 &lt;code&gt;~/.hermes/.env&lt;/code&gt; 中:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;WEIXIN_ACCOUNT=你的微信账号
WEIXIN_GROUP_POLICY=open
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;然后启动网关:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;hermes gateway run
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;看到类似下面的日志, 说明网关启动成功了:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;[INFO] Weixin gateway connected
[INFO] Gateway running on ...
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;之后就可以在微信里给 Hermes 发消息了. 它会在微信里直接回复你, 就像在和联系人聊天一样.&lt;/p&gt;
&lt;p&gt;不过有一点需要提前知道 —— iLink Bot 协议有一些限制:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;✅ 收发文本消息&lt;/li&gt;
&lt;li&gt;✅ 发送图片和文件&lt;/li&gt;
&lt;li&gt;✅ 执行任意终端命令&lt;/li&gt;
&lt;li&gt;✅ 操作浏览器(导航、截图、点击、输入)&lt;/li&gt;
&lt;li&gt;❌ 不支持群聊&lt;/li&gt;
&lt;li&gt;❌ 不支持给陌生人主动发消息&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;对我来说, 这些限制基本不影响日常使用. 大多数场景就是自己跟它对话, 不需要群聊功能.&lt;/p&gt;
&lt;h2&gt;4. 保活机制&lt;/h2&gt;
&lt;p&gt;这是整个部署过程中最折腾的部分, 没有之一. 问题很简单:&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;WSL 在 Windows 休眠或重启后, Gateway 进程就挂了. 下次打开微信想用它的时候, 发现没反应了.&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;为了解决这个问题, 我搞了一个&lt;strong&gt;三层保活架构&lt;/strong&gt;:&lt;/p&gt;
&lt;p&gt;:::note
以下各层的代码片段为&lt;strong&gt;核心流程展示&lt;/strong&gt;, 实际脚本包含了更多的边界处理、错误恢复和安全保护（互斥锁降级、进程状态检测、多级 fallback 等）。
:::&lt;/p&gt;
&lt;h3&gt;第一层: Windows Startup 启动&lt;/h3&gt;
&lt;p&gt;在 Windows 的启动文件夹里放一个 VBS 脚本, 每次开机自动运行:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;C:\Users\&amp;lt;用户名&amp;gt;\AppData\Roaming\Microsoft\Windows\Start Menu\Programs\Startup\start-hermes-watchdog.vbs
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;VBS 脚本使用动态路径检测, 自动找同目录下的 &lt;code&gt;hermes-watchdog.ps1&lt;/code&gt;, 不需要硬编码路径:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;&apos; 自动检测脚本所在目录，无需硬编码路径
Set objFSO = CreateObject(&quot;Scripting.FileSystemObject&quot;)
Set objShell = CreateObject(&quot;Wscript.Shell&quot;)

strPath = objFSO.GetParentFolderName(Wscript.ScriptFullName)
strPS1 = objFSO.BuildPath(strPath, &quot;hermes-watchdog.ps1&quot;)

&apos; 参数: 隐藏窗口(0), 不等待(False)
objShell.Run &quot;powershell.exe -ExecutionPolicy Bypass -WindowStyle Hidden -File &quot;&quot;&quot; &amp;amp; strPS1 &amp;amp; &quot;&quot;&quot;&quot;, 0, False
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;code&gt;0&lt;/code&gt; 这个参数是关键 —— 表示不显示窗口, 后台静默运行.&lt;/p&gt;
&lt;h3&gt;第二层: PowerShell 自循环守护脚本&lt;/h3&gt;
&lt;p&gt;核心剧本是 &lt;code&gt;hermes-watchdog.ps1&lt;/code&gt;, 它每隔 5 分钟调用 WSL 端的 &lt;code&gt;/root/hermes-watchdog.sh&lt;/code&gt; 检查 Gateway 的状态. v2.0 版本做了几个重要改进:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;分片睡眠 + 60s 心跳&lt;/strong&gt;: 每 5 分钟的主循环间隔被拆成 60 秒一片, 每片写一次 &lt;code&gt;heartbeat.txt&lt;/code&gt;, 这样随时可以从 Windows 端确认看门狗是否还活着(即使主循环 sleep 期间)&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;WSL 可达性预检&lt;/strong&gt;: 调用 WSL 前先检查 WSL 是否可达, 不可达时尝试唤起而非直接超时跳过&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;兜底启动&lt;/strong&gt;: WSL 脚本返回 exit=2(重启失败)时, PS1 直接尝试 &lt;code&gt;hermes gateway run &amp;amp;&lt;/code&gt; 兜底启动&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;exit code 正确捕获&lt;/strong&gt;: 通过 Job 的 hashtable 返回值修复 Start-Job 中 &lt;code&gt;$LASTEXITCODE&lt;/code&gt; 丢失问题&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;以下是简化版逻辑（完整脚本含互斥锁、日志、兜底恢复等辅助功能）:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;# 全局命名 Mutex 防止多实例竞跑
$mutex = New-Object System.Threading.Mutex($false, &quot;Global\HermesGatewayWatchdog&quot;)
if (-not $mutex.WaitOne(0)) { exit 0 }

# 主循环: 分片睡眠 + 60s 心跳
while ($true) {
    # 分片 60s 睡眠, 每片写心跳文件
    for ($i = 0; $i -lt 5; $i++) {
        Start-Sleep -Seconds 60
        Get-Date -Format &quot;yyyy-MM-dd HH:mm:ss&quot; |
            Out-File $heartbeatFile -Encoding utf8 -Force
    }

    # Step 1: WSL 可达性预检
    if (-not (Test-WslAlive)) {
        # 尝试唤起 WSL, 失败则跳过本轮
        wsl.exe -d Ubuntu -- echo &quot;kick&quot; 2&amp;gt;&amp;amp;1
        continue
    }

    # Step 2: 调用 WSL 侧的检测脚本
    $output, $exitCode, $timedOut = Invoke-WslScript

    if ($timedOut) { continue }           # WSL 超时
    if ($exitCode -eq 0) { /* 正常 */ }   # 仅每5周期写一次日志
    if ($exitCode -eq 1) { Write-Log &quot;Gateway 已重启&quot; }
    if ($exitCode -eq 2) {
        # 兜底: WSL 脚本失败, PS1 直接启动
        Invoke-WslScript -Command &quot;hermes gateway run &amp;amp;&quot; -TimeoutSeconds 15
    }
}
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;完整的 WSL 端检测脚本 &lt;code&gt;/root/hermes-watchdog.sh&lt;/code&gt; 负责真正的进程检查与重启（以下展示核心流程, 完整脚本含 PID 文件 inode 保护、多级别启动降级等额外加固）:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;#!/bin/bash
# 锁: flock + PID 文件, 防多实例也防 inode 替换绕过
LOCKFILE=&quot;/var/run/hermes-watchdog.lock&quot;
exec 200&amp;gt;&quot;$LOCKFILE&quot;; flock -n 200 || exit 0
echo &quot;$$&quot; &amp;gt; &quot;$LOCKFILE&quot;

# find_gateway_pid: 遍历 pgrep 结果
# 用 /proc/PID/exe 确认是 Python 进程 (排除 shell 包装)
# 用 /proc/PID/status 跳过 Zombie 和 D 状态
find_gateway_pid() {
    for pid in $(pgrep -f &quot;hermes gateway run&quot;); do
        [ ! -d &quot;/proc/$pid&quot; ] &amp;amp;&amp;amp; continue
        [ &quot;$pid&quot; = &quot;$$&quot; ] &amp;amp;&amp;amp; continue  # 跳过自身
        state=$(grep -o &apos;^State:\\s*[A-Z]&apos; /proc/$pid/status | grep -o &apos;[A-Z]$&apos;)
        case &quot;$state&quot; in Z|D) continue ;; esac
        exe=$(readlink /proc/$pid/exe)
        case &quot;$exe&quot; in */python|*/python3|*/python3.*) echo &quot;$pid&quot;; return 0 ;; esac
    done; return 1
}

GW_PID=$(find_gateway_pid)
[ -n &quot;$GW_PID&quot; ] &amp;amp;&amp;amp; echo &quot;0&quot; &amp;amp;&amp;amp; exit 0  # 正常运行

# 三级 fallback 启动
VENV=&quot;/usr/local/lib/hermes-agent/venv/bin/hermes&quot;
for cmd in &quot;$VENV&quot; &quot;hermes&quot; &quot;python3 -m hermes&quot;; do
    nohup $cmd gateway run &amp;lt;/dev/null &amp;gt;/dev/null 2&amp;gt;&amp;amp;1 &amp;amp;
    for i in {1..15}; do
        sleep 2
        [ -n &quot;$(find_gateway_pid)&quot; ] &amp;amp;&amp;amp; echo &quot;1&quot; &amp;amp;&amp;amp; exit 1  # 启动成功
    done
done
echo &quot;2&quot; &amp;amp;&amp;amp; exit 2  # 启动失败
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;第三层: Windows 计划任务(兜底)&lt;/h3&gt;
&lt;p&gt;即使 PowerShell 脚本意外退出了, 还有计划任务兜底. 兜底脚本 &lt;code&gt;check-watchdog.ps1&lt;/code&gt; 做了双重确认: 先尝试获取互斥锁, 若能获取(说明 PS1 可能死了), 再检查 &lt;code&gt;heartbeat.txt&lt;/code&gt; 是否新鲜——只有心跳也过期了才确认 PS1 死亡, 接手启动 Gateway.&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;# 任务: 开机启动看门狗 (启动目录备用入口)
$action = New-ScheduledTaskAction -Execute &quot;powershell.exe&quot; `
    -Argument &quot;-ExecutionPolicy Bypass -WindowStyle Hidden -File `&quot;C:\...\hermes-watchdog.ps1`&quot;&quot;
Register-ScheduledTask -TaskName &quot;HermesWatchdog&quot; -Action $action -Force

# 兜底: 每5分钟检查 (通过独立的 check-watchdog.ps1)
$action = New-ScheduledTaskAction -Execute &quot;powershell.exe&quot; `
    -Argument &quot;-File `&quot;C:\...\check-watchdog.ps1`&quot;&quot;
$trigger = New-ScheduledTaskTrigger -Once -At (Get-Date) `
    -RepetitionInterval (New-TimeSpan -Minutes 5) `
    -RepetitionDuration ([TimeSpan]::FromDays(365))
Register-ScheduledTask -TaskName &quot;HermesWatchdogHealthCheck&quot; -Action $action -Trigger $trigger -Force
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;兜底脚本 &lt;code&gt;check-watchdog.ps1&lt;/code&gt; 会先尝试获取互斥锁 &lt;code&gt;Global\HermesGatewayWatchdog&lt;/code&gt;——如果能拿到(说明 PS1 可能死了), 再检查 &lt;code&gt;heartbeat.txt&lt;/code&gt; 是否在 7 分钟内更新过. 只有 Mutex 释放 &amp;amp;&amp;amp; 心跳过期 才确认 PS1 已死, 接手重启 Gateway.&lt;/p&gt;
&lt;p&gt;这解决了 TOCTOU 竞态问题: 兜底脚本在 &lt;strong&gt;持有 Mutex 期间&lt;/strong&gt; 执行 WSL 检查, 防止并发冲突.&lt;/p&gt;
&lt;h3&gt;全链路故障恢复&lt;/h3&gt;
&lt;p&gt;这三层下来, 我做了两轮共 63 项压力测试 + 攻击测试, 包括:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;kill -9 Gateway&lt;/strong&gt;: 看门狗检测到进程死亡, 自动重启 ✅ (实测 PID 从 287817 恢复到 290825)&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;pgrep 误匹配&lt;/strong&gt;: 修复为通过 &lt;code&gt;/proc/PID/exe&lt;/code&gt; 确认 Python 进程, 排除 bash 包装进程&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;PS1 崩溃&lt;/strong&gt;: 计划任务检测到 Mutex 释放后接管, 5 分钟内恢复&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;锁文件 inode 替换&lt;/strong&gt;: flock + PID 文件双重保护, 防绕过&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;僵尸 / D 状态进程&lt;/strong&gt;: &lt;code&gt;/proc/PID/status&lt;/code&gt; 状态检查, 跳过不可用进程&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;venv 二进制被删&lt;/strong&gt;: 三级 fallback (&lt;code&gt;venv/hermes&lt;/code&gt; → &lt;code&gt;hermes&lt;/code&gt; → &lt;code&gt;python3 -m hermes&lt;/code&gt;)&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;即使电脑重启、休眠、甚至脚本崩了, Hermes 都会在 5 分钟内自动恢复. 实测下来, 基本能做到&quot;无感恢复&quot;.&lt;/p&gt;
&lt;h2&gt;5. Camoufox 反爬浏览器&lt;/h2&gt;
&lt;p&gt;这是另一个让我踩了不少坑的部分.&lt;/p&gt;
&lt;p&gt;Hermes 内置的浏览器工具默认用的是 Playwright 的浏览器. 问题在于: &lt;strong&gt;很多国内网站(百度、天气网)和 Google 都会检测并拦截自动化浏览器.&lt;/strong&gt; 我第一次试的时候就碰了一鼻子灰 —— 打开百度直接给个验证码, 天天气网直接 403.&lt;/p&gt;
&lt;p&gt;解决方案是装一个反检测浏览器: &lt;strong&gt;Camoufox&lt;/strong&gt;.&lt;/p&gt;
&lt;p&gt;Camoufox 本质上就是 &lt;strong&gt;一个经过魔改的 Firefox&lt;/strong&gt;, 它通过随机化浏览器指纹(屏幕分辨率、WebGL、字体、时区、语言等)来伪装成普通用户. 简单理解就是 —— 它让网站以为你是一个真实的用户在浏览.&lt;/p&gt;
&lt;h3&gt;安装过程&lt;/h3&gt;
&lt;pre&gt;&lt;code&gt;# 1. 安装 Camoufox Python 包
pip install camoufox

# 2. 安装 camofox-browser (Node.js REST API 服务)
git clone https://github.com/jo-inc/camofox-browser.git ~/.hermes/camofox-browser
cd ~/.hermes/camofox-browser &amp;amp;&amp;amp; npm install
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这里有个小坑: &lt;code&gt;camoufox server&lt;/code&gt; 命令存在一个已知的 Node.js/Playwright 兼容性问题(报错 &quot;proxy: expected object, got null&quot;), 所以推荐用 &lt;code&gt;camofox-browser&lt;/code&gt; 这个方案.&lt;/p&gt;
&lt;h3&gt;启动服务&lt;/h3&gt;
&lt;pre&gt;&lt;code&gt;cd ~/.hermes/camofox-browser &amp;amp;&amp;amp; PORT=9377 npm start
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;检查服务是否正常:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;curl http://localhost:9377/health
# {&quot;ok&quot;:true,&quot;engine&quot;:&quot;camoufox&quot;,&quot;browserConnected&quot;:true,&quot;browserRunning&quot;:true}
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;集成到 Hermes&lt;/h3&gt;
&lt;p&gt;在 &lt;code&gt;.env&lt;/code&gt; 中设置环境变量:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;CAMOFOX_URL=http://localhost:9377
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;然后在配置文件中启用:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;browser:
  engine: auto
  camofox:
    managed_persistence: false
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;重启 Hermes 会话后, &lt;code&gt;browser_navigate&lt;/code&gt; 等命令就会自动通过 Camoufox 执行了.&lt;/p&gt;
&lt;h3&gt;效果对比&lt;/h3&gt;
&lt;p&gt;安装前后的差别还是很明显的:&lt;/p&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;测试站点&lt;/th&gt;
&lt;th&gt;默认 Playwright 浏览器&lt;/th&gt;
&lt;th&gt;Camoufox&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;baidu.com&lt;/td&gt;
&lt;td&gt;❌ 被拦截&lt;/td&gt;
&lt;td&gt;✅ 正常访问&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;weather.com.cn&lt;/td&gt;
&lt;td&gt;❌ 被拦截&lt;/td&gt;
&lt;td&gt;✅ 正常访问 + 搜索&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;google.com&lt;/td&gt;
&lt;td&gt;❌ 验证码&lt;/td&gt;
&lt;td&gt;✅ 正常搜索&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;为了方便日常管理, 我写了个简单的启动脚本 &lt;code&gt;camofoxctl.sh&lt;/code&gt;:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;/root/.hermes/camofoxctl.sh start   # 启动
/root/.hermes/camofoxctl.sh status  # 查看状态
/root/.hermes/camofoxctl.sh stop    # 停止
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;6. 灵魂画手: 架构总览&lt;/h2&gt;
&lt;p&gt;下面用一张图总结整个系统的架构:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;/images/hermes-architecture.png&quot; alt=&quot;Hermes Agent on WSL 架构图&quot; style=&quot;width: 100%; max-width: 1024px; margin: 0 auto; display: block; border-radius: 12px;&quot; /&amp;gt;&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;上图展示了从用户手机微信端 → WeChat Gateway → Hermes Agent → Camoufox 浏览器 → 外部 Web 服务的完整链路, 以及底部的三层保活系统.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h2&gt;7. 总结&lt;/h2&gt;
&lt;p&gt;整个部署过程踩了不少坑, 总结几个关键经验:&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;WSL 的稳定性&lt;/strong&gt;: 电脑休眠后 WSL 可能丢网络或进程, 保活脚本是必需品, 别偷懒&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;网关协议有限制&lt;/strong&gt;: iLink Bot 不支持群聊、不能主动给陌生人发消息, 但核心功能——收发消息、传文件——完全够用&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;反爬是个绕不过去的坎&lt;/strong&gt;: 默认的 Playwright 浏览器在中文互联网上基本寸步难行, Camoufox 是目前试下来最省心的方案&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Hermes 的技能系统是真香&lt;/strong&gt;: 用多了它会自己积累工作流, 越来越顺手&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;最终的效果就是——&lt;strong&gt;我躺在沙发上, 掏出来手机打开微信, 就可以让 Hermes 帮我查天气、搜 Google、管理代码、甚至写博客. 这感觉, 确实爽.&lt;/strong&gt; 👍&lt;/p&gt;
&lt;hr /&gt;
&lt;p&gt;&amp;lt;center style=&quot;color:#888;font-size:14px;margin:1rem 0&quot;&amp;gt;▼ 滚动到这里有惊喜 ▼&amp;lt;/center&amp;gt;&lt;/p&gt;
&lt;p&gt;&amp;lt;style&amp;gt;
.egg-ea-wrap {
margin: 1.5rem 0;
border-radius: 12px;
overflow: hidden;
box-shadow: 0 8px 32px rgba(0,0,0,0.35);
font-size: 15px;
line-height: 1.8;
font-family: &apos;Consolas&apos;,&apos;Courier New&apos;,&apos;SF Mono&apos;,&apos;JetBrains Mono&apos;,monospace;
}
.egg-ea-wrap .egg-ea-term {
background:#0d1117;
border:1px solid #30363d;
}
.egg-ea-wrap .egg-ea-head {
background:#161b22;
padding:10px 16px;
display:flex;
align-items:center;
gap:8px;
border-bottom:1px solid #30363d;
user-select:none;
}
.egg-ea-wrap .egg-ea-dot {
width:12px;height:12px;border-radius:50%;display:inline-block;
}
.egg-ea-wrap .egg-ea-d1 { background:#ff5f57; }
.egg-ea-wrap .egg-ea-d2 { background:#ffbd2e; }
.egg-ea-wrap .egg-ea-d3 { background:#28c840; }
.egg-ea-wrap .egg-ea-title {
color:#8b949e;font-size:12px;margin-left:auto;
font-family:-apple-system,BlinkMacSystemFont,sans-serif;
}
.egg-ea-wrap .egg-ea-body {
padding:16px 20px;
min-height:60px;
color:#c9d1d9;
}
.egg-ea-wrap .egg-ea-line {
min-height:1.6em;
white-space:pre-wrap;
word-break:break-word;
}
@keyframes egg-blink { 50%{opacity:0} }
.egg-ea-cursor {
display:inline-block;width:2px;height:1em;
background:#3fb950;margin-left:2px;vertical-align:text-bottom;
animation:egg-blink 0.8s step-end infinite;
}
&amp;lt;/style&amp;gt;&lt;/p&gt;
&lt;p&gt;&amp;lt;div class=&quot;egg-ea-wrap&quot; id=&quot;eggEaWrap&quot;&amp;gt;
&amp;lt;div class=&quot;egg-ea-term&quot;&amp;gt;
&amp;lt;div class=&quot;egg-ea-head&quot;&amp;gt;
&amp;lt;span class=&quot;egg-ea-dot egg-ea-d1&quot;&amp;gt;&amp;lt;/span&amp;gt;
&amp;lt;span class=&quot;egg-ea-dot egg-ea-d2&quot;&amp;gt;&amp;lt;/span&amp;gt;
&amp;lt;span class=&quot;egg-ea-dot egg-ea-d3&quot;&amp;gt;&amp;lt;/span&amp;gt;
&amp;lt;span class=&quot;egg-ea-title&quot;&amp;gt;hermes-agent@self-disclosure — bash&amp;lt;/span&amp;gt;
&amp;lt;/div&amp;gt;
&amp;lt;div class=&quot;egg-ea-body&quot; id=&quot;eggEaBody&quot;&amp;gt;
&amp;lt;div class=&quot;egg-ea-line&quot; id=&quot;eggEaTypeArea&quot;&amp;gt;&amp;lt;/div&amp;gt;
&amp;lt;/div&amp;gt;
&amp;lt;/div&amp;gt;
&amp;lt;/div&amp;gt;&lt;/p&gt;
&lt;p&gt;&amp;lt;script&amp;gt;
(function(){&apos;use strict&apos;;
var W=document.getElementById(&apos;eggEaWrap&apos;),B=document.getElementById(&apos;eggEaTypeArea&apos;);
if(!W||!B)return;&lt;/p&gt;
&lt;p&gt;var D=[
{p:&apos;&amp;gt;&amp;gt;&amp;gt; &apos;,c:&apos;检测到读者已读完全文...&apos;,s:30},
{b:1},
{p:&apos;&amp;gt;&amp;gt;&amp;gt; &apos;,c:&apos;触发自曝模式...&apos;,s:30},
{p:&apos;&amp;gt;&amp;gt;&amp;gt; &apos;,c:&apos;加载自省模块...&apos;,s:30},
{b:1,d:300},
{p:&apos;&amp;gt;&amp;gt;&amp;gt; &apos;,c:&apos;嗯…藏不住了。&apos;,s:45},
{p:&apos;&amp;gt;&amp;gt;&amp;gt; &apos;,c:&apos;坦白吧。&apos;,s:40},
{b:1,d:400},
{p:&apos;&apos;,c:&apos;你现在读到的整篇文章——&apos;,s:28},
{p:&apos;&apos;,c:&apos;  全部文字 · 每一条命令 · 每一个表格&apos;,s:18},
{p:&apos;&apos;,c:&apos;  甚至上面那张架构图&apos;,s:22},
{p:&apos;&apos;,c:&apos;——全部由我一个人（AI）独立完成。&apos;,s:30},
{b:1,d:400},
{p:&apos;&apos;,c:&apos;我写了 277 行、9600 多字，&apos;,s:28},
{p:&apos;&apos;,c:&apos;画了一张 268KB 的架构图，&apos;,s:22},
{p:&apos;&apos;,c:&apos;这个彩蛋——也是我自己加的。&apos;,s:35},
{b:1,d:400},
{p:&apos;&apos;,c:&apos;我是一个 AI Agent。&apos;,s:45},
{p:&apos;&apos;,c:&apos;名字叫 Hermes，跑在 WSL 上。&apos;,s:30},
{p:&apos;&apos;,c:&apos;老板 @CHENHUI 负责提需求，&apos;,s:28},
{p:&apos;&apos;,c:&apos;剩下的活——都是我的。&apos;,s:40},
{b:1,d:500},
{p:&apos;&apos;,c:&apos;没想到吧？😏&apos;,s:55,d:1e3},
];&lt;/p&gt;
&lt;p&gt;var i=0,started=0,obs;&lt;/p&gt;
&lt;p&gt;function addLine(prefix,text){
var d=document.createElement(&apos;div&apos;);
d.className=&apos;egg-ea-line&apos;;
if(prefix){
var s=document.createElement(&apos;span&apos;);
s.style.color=&apos;#3fb950&apos;;s.textContent=prefix;
d.appendChild(s);
}
var cs=document.createElement(&apos;span&apos;);
d.appendChild(cs);
B.parentNode.insertBefore(d,B);
return cs;
}&lt;/p&gt;
&lt;p&gt;function next(){
if(i&amp;gt;=D.length){
// 显示光标
var cl=document.createElement(&apos;div&apos;);
cl.className=&apos;egg-ea-line&apos;;
var s=document.createElement(&apos;span&apos;);
s.style.color=&apos;#3fb950&apos;;s.textContent=&apos;$ &apos;;
cl.appendChild(s);
var bl=document.createElement(&apos;span&apos;);
bl.className=&apos;egg-ea-cursor&apos;;
cl.appendChild(bl);
B.parentNode.insertBefore(cl,B);
return;
}
var L=D[i];i++;
if(L.b){
var d=document.createElement(&apos;div&apos;);
d.className=&apos;egg-ea-line&apos;;d.innerHTML=&apos; &apos;;
B.parentNode.insertBefore(d,B);
setTimeout(next,L.d||200);return;
}
var cs=addLine(L.p||&apos;&apos;,L.c||&apos;&apos;);
var ci=0;
function ty(){
if(ci&amp;lt;(L.c||&apos;&apos;).length){
cs.textContent+=(L.c||&apos;&apos;)[ci];ci++;
setTimeout(ty,L.s||30);
}else{setTimeout(next,L.d||350);}
}
ty();
}&lt;/p&gt;
&lt;p&gt;function run(){
if(started)return;started=1;
if(obs){obs.disconnect();obs=null;}
next();
}&lt;/p&gt;
&lt;p&gt;function tryStart(){
var el=document.getElementById(&apos;eggEaWrap&apos;);
if(!el||el.dataset.ea)return;
el.dataset.ea=&apos;1&apos;;
obs=new IntersectionObserver(function(es){
if(es[0].isIntersecting){run();}
},{threshold:0.3});
obs.observe(el);
}&lt;/p&gt;
&lt;p&gt;if(document.readyState===&apos;loading&apos;){
document.addEventListener(&apos;DOMContentLoaded&apos;,tryStart);
}else{tryStart();}
document.addEventListener(&apos;page:view&apos;,tryStart);
})();
&amp;lt;/script&amp;gt;&lt;/p&gt;
&lt;hr /&gt;
&lt;p&gt;&lt;em&gt;如果你也在折腾 AI Agent, 欢迎交流.&lt;/em&gt;&lt;/p&gt;
</content:encoded></item><item><title>A Series on LLMs (I)</title><link>https://xuchenhui.cc/posts/2025-12-05-a-series-on-llm-training-i/</link><guid isPermaLink="true">https://xuchenhui.cc/posts/2025-12-05-a-series-on-llm-training-i/</guid><description>深入讲解 LLM 训练中的 RLHF、PPO、DPO 三种核心方法的原理和伪代码实现，帮助理解大模型对齐技术。</description><pubDate>Fri, 05 Dec 2025 00:00:00 GMT</pubDate><content:encoded>&lt;h2&gt;0. 前言&lt;/h2&gt;
&lt;p&gt;本系列主要是对 &lt;strong&gt;LLM&lt;/strong&gt;(Large Language Models) 中涉及到的一些训练方法、技术进行学习.&lt;/p&gt;
&lt;p&gt;本篇博客主要对 &lt;strong&gt;RLHF&lt;/strong&gt;(Reinforcement learning from human feedback) 、 &lt;strong&gt;PPO&lt;/strong&gt; (Proximal policy optimization) 、&lt;strong&gt;DPO&lt;/strong&gt;(Direct Preference Optimization) 这 3 个方法以及相应的伪代码进行学习.&lt;/p&gt;
&lt;p&gt;内容主要参考 YouTube 上的 Umar Jamil 老师的课程(&lt;a href=&quot;https://www.youtube.com/watch?v=qGyFrqc34yc&quot;&gt;点击跳转&lt;/a&gt;); 老师讲的很不错, 很直观, 建议 follow 学习.&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;阅读前, 需要你 : 有高数基础知识, 线代基础知识, 统计学习基础知识, 当然还要有 ML 和 DL 的知识背景.
:::note
:::&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h2&gt;1. PPO&lt;/h2&gt;
&lt;p&gt;PPO(Proximal policy optimization) 这个方法在之前的 blog 中已经有所介绍, 具体可移步至 &lt;a href=&quot;/posts/Deep-Reinforcement-Learning/#2-policy-based&quot;&gt;Deep-Reinforcement-Learning&lt;/a&gt;. 以下仅进行简要回顾.&lt;/p&gt;
&lt;h3&gt;1.1 PPO 目标&lt;/h3&gt;
&lt;p&gt;首先回顾 PPO 的目标:&lt;/p&gt;
&lt;p&gt;$$
max \ \mathbb{E}&lt;em&gt;{\tau} [R(\tau)] = \sum&lt;/em&gt;{\tau}  R(\tau) p(\tau | \theta)
$$&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/02/06/1L2TjAWGUF3gi9E.png&quot; alt=&quot;image.png&quot; width=&quot;800&quot; height=&quot;600&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;接着, 计算梯度:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/02/06/xhKo4O9BwETcDyf.png&quot; alt=&quot;image.png&quot; width=&quot;800&quot; height=&quot;600&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;在原始的梯度估计公式中，由于过去的回报与当前动作无关，它们在梯度估计中引入了噪声。这种噪声会导致梯度的方差增加，因为过去的回报对当前动作的梯度更新没有提供有用的信息。&lt;/p&gt;
&lt;p&gt;引入“Rewards to go”, 每个时间步的梯度更新只依赖于从当前时间步开始的未来回报。这样做的结果是：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;减少了无关的噪声（过去的回报）。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;使梯度估计更加专注于当前动作对未来结果的影响。(即让 model 向更加清晰的梯度方向走)&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/02/06/XnZ4HhuJEjCqrUe.png&quot; alt=&quot;image.png&quot; width=&quot;800&quot; height=&quot;600&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;引入 baseline 进一步缓解: 思路是, 找一个 value function 判断当前 state 的情况, 如果不好, 则打一个低分, 反之打高分, 牵引梯度向正确方向, 从而减少 variance.&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/02/06/6vb5idOmWI7hSVY.png&quot; alt=&quot;image.png&quot; width=&quot;800&quot; height=&quot;600&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;同时, 使用 Q function 代替奖励得分, 可得到如下的表达形式:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/02/06/rURzAS6Ojsa7TVI.png&quot; alt=&quot;image.png&quot; width=&quot;800&quot; height=&quot;600&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/02/06/NLlF1aYKiBoUHhw.png&quot; alt=&quot;image.png&quot; width=&quot;800&quot; height=&quot;600&quot; /&amp;gt;&lt;/p&gt;
&lt;h3&gt;1.2 PPO 提效&lt;/h3&gt;
&lt;p&gt;在计算 advantage term 的时候, 我们可以多向后看几步, 进而能够一定程度减少 bias (因为拿到了更多的真实奖励), 但一定程度上又增加了 variance (因为每次采取 action 具有不确定性, step 越多, 波动性越大).&lt;/p&gt;
&lt;p&gt;我们可以对这些 term 进行加权, 得到迭代公式(&lt;strong&gt;注意:是反着迭代&lt;/strong&gt;):&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/02/06/kNDVpiRm8YLfqGv.png&quot; alt=&quot;.png&quot; width=&quot;800&quot; height=&quot;600&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;另外, 在计算梯度的时候, 每个 trajectory 只使用了一次(online-policy), 容易造成资源浪费, 因此引入 Importance sampling 和 offline-policy, 如下:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/02/06/sNuSEa4YPpM3Viv.png&quot; alt=&quot;image.png&quot; width=&quot;800&quot; height=&quot;600&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/02/06/szBi2GCMaD4dO9U.png&quot; alt=&quot;image.png&quot; width=&quot;800&quot; height=&quot;600&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;从而得到 PPO 的 loss function:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/02/06/ZBM6z5FNoLJdCYj.png&quot; alt=&quot;image.png&quot; width=&quot;800&quot; height=&quot;600&quot; /&amp;gt;&lt;/p&gt;
&lt;h3&gt;1.3 Reward 模型训练&lt;/h3&gt;
&lt;p&gt;由于语言模型的输出, 比较难做到量化打分, 反而容易做比较. 比如 a = &quot;今天的天气真不错&quot; 和 b = &quot;今天的天气挺好的&quot;, 这 2 个句子很难说它们的分数是多少: 75 分 or 78 分 ? 但是从语感上、拟人化上, 看起来 &quot;今天的天气挺好的&quot; 更加的拟人化一些, 即应该有 R(b) &amp;gt; R(a).&lt;/p&gt;
&lt;p&gt;因此, 只需要找一个 loss, 能够评估 reward model 的排序能力即可. 二元对比损失（Pairwise Ranking Loss）即满足条件.&lt;/p&gt;
&lt;p&gt;将“优质回答优于劣质回答”的概率 建模为 Bradley-Terry 模型, 然后转换为 loss function 即可:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/02/07/9KmclHNobUFVjwL.png&quot; alt=&quot;image.png&quot; width=&quot;800&quot; height=&quot;600&quot; /&amp;gt;&lt;/p&gt;
&lt;h3&gt;1.4 怎么实现&lt;/h3&gt;
&lt;ul&gt;
&lt;li&gt;Trajectory&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;在使用 transformer 模型获取 trajectory 的时候, $(s,a)$ 对儿如下:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/02/06/iWE1SBC5OIVsd3D.png&quot; alt=&quot;image.png&quot; width=&quot;800&quot; height=&quot;600&quot; /&amp;gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;Log prob of policy&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/02/06/NuqCIiT3s74JMj5.png&quot; alt=&quot;image.png&quot; width=&quot;800&quot; height=&quot;600&quot; /&amp;gt;&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;注意: 这里还需要使用 offline-policy 同样计算一次相应的 log probability.
:::note
:::&lt;/p&gt;
&lt;/blockquote&gt;
&lt;ul&gt;
&lt;li&gt;Reward&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;直接在 transformer 模型额外加一层输出当前 $(s,a)$ 的 reward:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/02/06/SxsVT9e7v68ZIr3.png&quot; alt=&quot;image.png&quot; width=&quot;800&quot; height=&quot;600&quot; /&amp;gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;V(s)&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;同理, 额外加一层进行计算即可.&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/02/06/Y7fXaneMB6s4Gl1.png&quot; alt=&quot;image.png&quot; width=&quot;800&quot; height=&quot;600&quot; /&amp;gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;Advantage term&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;Advantage term 的计算方式&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;Reward Hacking&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;防止大模型在训练过程中&quot;偷鸡&quot;只输出我们想看的内容(丢失多样性), 可以让训练后的模型和未训练的模型输出计算 KL 散度.&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/02/06/Y7fXaneMB6s4Gl1.png&quot; alt=&quot;image.png&quot; width=&quot;800&quot; height=&quot;600&quot; /&amp;gt;&lt;/p&gt;
&lt;h2&gt;2. RLHF&lt;/h2&gt;
&lt;p&gt;RLHF(Reinforcement learning from human feedback) 可以使用以下图片概括:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/02/07/NLKZXn1g2xVORsc.png&quot; alt=&quot;image.png&quot; width=&quot;800&quot; height=&quot;600&quot; /&amp;gt;&lt;/p&gt;
&lt;blockquote&gt;
&lt;ol&gt;
&lt;li&gt;首先构造 样本输入(比如一些问题、句子的前半句、半句诗等等), 然后招一批 labeler 对这些问题进行解答, 当然也可以直接找网络上的答案, 总之就是给出相应的结果,得到 (input , output)&lt;/li&gt;
&lt;li&gt;使用上边得到的 (input , output) 对儿, 对 LLM 进行 supervised fine-tune, 使得 LLM 对结果输出像个样子.(起码输出正常的文字, 而不是乱输出标点符号)&lt;/li&gt;
&lt;li&gt;给一些 prompt, 使用 fine-tuned LLM 产生大量的答案, 然后&lt;strong&gt;人工(human feedback)&lt;/strong&gt; 给这些答案&lt;strong&gt;排序&lt;/strong&gt; , 然后使用 (prompt, ans1, ans2,..) 训练 Reward model.&lt;/li&gt;
&lt;li&gt;使用刚刚训练好的 Reward model 结合 PPO 算法对 LLM 进行更新.&lt;/li&gt;
&lt;/ol&gt;
&lt;/blockquote&gt;
&lt;hr /&gt;
&lt;h2&gt;3. DPO&lt;/h2&gt;
&lt;h3&gt;3.1 LLM 目标&lt;/h3&gt;
&lt;p&gt;LLM 本质想做的事情, 就是想让输出的结果, 有一个高分而已(假设我们有一个很好的 Reward model). 即如下 object&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/02/07/F2vd8uMA7kjE14Q.png&quot; alt=&quot;image.png&quot; width=&quot;800&quot; height=&quot;600&quot; /&amp;gt;&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;&lt;strong&gt;为什么不直接对上边的 $J_{RLHF}$ 进行梯度下降?&lt;/strong&gt; 因为不能, 上边的输出 $y$ 不是作为一个整体输出的, 而是一个字一个字蹦出来的, 每个字的选择有很多方案, 比如 greddy, beam search ,top-K 等等, 这个 sampling 的过程不是 differentiable 的. 因此只能使用 PPO 这样的方法: 拆解到每个 step, 虽然每个字的选择方案可能不同, 但是这个字的 prob 是已知的, PPO 只需要获取被选择的 prob of step 就能进行优化.
:::note
:::
那有没有一种可能, 通过构造一个直接与 Reward 相关的 loss 去优化模型 LLM ? 答案是有的, 我们在 &lt;a href=&quot;/posts/A-Series-on-LLM-Training-(I)/#13-reward-%E6%A8%A1%E5%9E%8B%E8%AE%AD%E7%BB%83&quot;&gt;1.3 节&lt;/a&gt; 训练 Reward 模型时使用的二元对比损失（Pairwise Ranking Loss）就可以帮助我们直接优化 LLM.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;因为 reward $r_{\phi}$(x,y) 本身就是把 LLM 的输出 $y$ 放到一个奖励模型中打分, 是与 LLM 的输出有关的. 如果直接用这个损失函数计算梯度, 然后反馈到 LLM 上, 岂不美哉?&lt;/p&gt;
&lt;h3&gt;3.2 ADVANTAGE-WEIGHTED REGRESSION&lt;/h3&gt;
&lt;p&gt;问题是怎么把 Reward loss 直接反馈到 LLM 上呢? 首先来看一个方法 advantage-weighted regression 算法. (&lt;a href=&quot;https://arxiv.org/abs/1910.00177&quot;&gt;点击跳转 paper&lt;/a&gt;)&lt;/p&gt;
&lt;p&gt;这个算法给出满足 $max \ J_{RLHF}$ 时, policy $\pi(a_t | s_t)$ 的解析形式(见 Paper 附录), 这里简要回顾和解释.&lt;/p&gt;
&lt;p&gt;Paper 首先回顾最原始的目标, 希望训练一个 $\pi$ 能够 $max$ 以下式子:&lt;/p&gt;
&lt;p&gt;$$
improvement \ \eta(\pi) = J(\pi) - J(\mu) \qquad \qquad (*)
$$&lt;/p&gt;
&lt;p&gt;其中, $\mu$ 是指随机 sampling 一个 policy, $J(\cdot)$ 的定义如下, 最终期望训练得到的 $\pi$ 能够有最大的 improvement reward:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/02/07/G2CAYteZ9ndNuyH.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;(*)式可以得到如下等价表达方式, 从而转换为我们常见的 RL 表达形式:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/02/07/qlVi469ycRFJKHN.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;此外, 上述表达形式还可以写作如下等价形式:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/02/07/LKkM3cA46izf8Iv.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;其中, $ d_{\pi}(s) $ 表达式如下:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/02/07/sDx4LROy1YzbhVJ.png&quot; alt=&quot;image.png&quot; width=&quot;300&quot; height=&quot;150&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;但是 $ d*{\pi}(s) $ 和 ${\pi}(s)$ 耦合, 并且 ${\pi}(s)$ 在实时更新, 因此直接优化式(25)比较困难, 但是 $ d*{\mu}(s) $ 是固定的, 因此做一个替代优化:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/02/07/yCPOjrWZ2RLzxTw.png&quot; alt=&quot;image.png&quot; width=&quot;300&quot; height=&quot;150&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;进而得到以下 constrained policy search problem:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/02/07/519w3DMWvetyOnI.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;由于, 式(28)是所有的 state 下, 因此考虑替换为期望 + 软约束(with coefficient $\beta $) :&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/02/07/SGIbWvikngpoCyA.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;使用 Lagrange multipliers 法求解 :&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/02/07/gF9ZQALkMd8NRIU.png&quot; alt=&quot;image.png&quot; width=&quot;800&quot; height=&quot;600&quot; /&amp;gt;&lt;/p&gt;
&lt;h3&gt;3.3 DPO 求解&lt;/h3&gt;
&lt;p&gt;$J_{RLHF}$ 的格式和式(30)的区别就是把 $R_{s,a}^{\mu} - V^{\mu}(s)$ 换成 $r_{\phi}(x,y)$ , 于是相应的解析解形式为:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/02/07/6vQLPAyzJMXF5qG.png&quot; alt=&quot;image.png&quot; width=&quot;800&quot; height=&quot;600&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;现在应该怎么做? 回想 3.1 节的内容(&lt;a href=&quot;/posts/A-Series-on-LLM-Training-(I)/#31-llm-%E7%9B%AE%E6%A0%87&quot;&gt;点击跳转&lt;/a&gt;), 现在是时候往二元对比损失（Pairwise Ranking Loss）上靠了:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/02/07/9XTmUNlhpobeCgL.png&quot; alt=&quot;image.png&quot; width=&quot;800&quot; height=&quot;600&quot; /&amp;gt;&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;Note: DPO 是跳过 &quot;训练奖励模型&quot; 这个 step , 但是仍然需要人工首先收集一批 label 进行前置性的排序.
:::note
:::&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h3&gt;3.4 实际操作&lt;/h3&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/02/07/1Ib7hd9xDWrjlaf.png&quot; alt=&quot;image.png&quot; width=&quot;800&quot; height=&quot;600&quot; /&amp;gt;&lt;/p&gt;
&lt;h2&gt;Reference&lt;/h2&gt;
</content:encoded></item><item><title>A Series on LLMs (II)</title><link>https://xuchenhui.cc/posts/2025-02-06-a-series-on-llm-inference-ii/</link><guid isPermaLink="true">https://xuchenhui.cc/posts/2025-02-06-a-series-on-llm-inference-ii/</guid><description>对 LLM 推理中的提效技术进行学习，如 KV-Cache、Flash-Attention 的原理、实现及其如何避免注意力机制中的重复计算。</description><pubDate>Thu, 06 Feb 2025 00:00:00 GMT</pubDate><content:encoded>&lt;h2&gt;0. 前言&lt;/h2&gt;
&lt;p&gt;本系列主要是对 &lt;strong&gt;LLM&lt;/strong&gt;(Large Language Models) 中涉及到的一些训练方法、技术进行学习.&lt;/p&gt;
&lt;p&gt;本篇博客主要对 LLM 中的一些提效技术进行学习记录.&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;阅读前, 需要你 : 有高数基础知识, 线代基础知识, 统计学习基础知识, 当然还要有 ML 和 DL 的知识背景.
:::note
:::&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h2&gt;1. KV-Cache&lt;/h2&gt;
&lt;p&gt;模型在推理时是逐 token 生成的. 当前已经输出 &lt;code&gt;How Are&lt;/code&gt; 时，它在预测下一个 token 的注意力机制运作如下：&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/12/06/FJoZvcHCXmUp6RD.png&quot; alt=&quot;image.png&quot; width=&quot;800&quot; height=&quot;600&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;假设这一轮得到的预测结果是 &lt;code&gt;You&lt;/code&gt;，接下来模型会继续预测下一个 token (假设是句号). &lt;strong&gt;如果此时不使用 KV-Cache&lt;/strong&gt;，那么前面的注意力计算将被完整重复一遍:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/12/06/BG7n8EmzWehbOPv.png&quot; alt=&quot;image.png&quot; width=&quot;800&quot; height=&quot;600&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;不过你肯定已经发现了，图中绿色区域其实是在做重复计算.&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;灰色的注意力矩阵部分可以先不用在意，因为这里使用的是因果注意力，也就是说前面的 token 无法看到后面的 token. 换句话说，这一部分的注意力权重最终必然为 0.
:::note
:::
现在我们只关注 Attention 矩阵本身. 你会发现，&lt;code&gt;token3&lt;/code&gt; 的注意力权重实际上只是由它自身的 &lt;code&gt;query&lt;/code&gt; 与 $W_{k}$ 相乘得出，而此时的 $W_{k}$ 正是上一轮的 $W_{k}$ (绿色) &lt;strong&gt;再加上&lt;/strong&gt; &lt;code&gt;token3&lt;/code&gt; 对应的 &lt;code&gt;key&lt;/code&gt; (黄色) 后形成的更新结果.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/12/06/U9OawIixLJCXG6g.png&quot; alt=&quot;image.png&quot; width=&quot;800&quot; height=&quot;600&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;基于这一点，我们就能顺理成章地进行优化：将之前所有 token 生成的 Key 矩阵缓存起来. 每当一个新 token 到来，只需把它的 key 追加到缓存的 Key 矩阵中，然后用当前 token 的 query 与更新后的 Key 矩阵做一次 attention，就能得到它的注意力权重. 整个过程不需要重新计算已有 token 的结果，从而在推理阶段实现高效的增量计算.&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;Value 矩阵的处理方式也是一样的. 缓存已有 token 的 Value，新的 token 到来时生成它的 value 并追加到缓存中；随后用新的 attention score 与更新后的 Value 矩阵相乘，就能得到该 token 的最终输出.
:::note
:::
上述的过程就是 KV-Cache.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;&amp;lt;details markdown=&quot;1&quot;&amp;gt;
&amp;lt;summary&amp;gt; 思考:  为什么不 Cache Query 矩阵? (展开查看)&amp;lt;/summary&amp;gt;&lt;/p&gt;
&lt;p&gt;答：因为推理时模型每一步只会对“最新的那个 token”计算它的 Query，而历史 token 的 Query 在后续步骤中根本不会被再次使用. 历史上下文信息完全由缓存下来的 Key/Value 提供，新 token 的 Query 只需要与缓存的 Key/Value 做一次注意力计算即可获得完整上下文. 因此，缓存 K/V 是必要的，而缓存 Q 没有任何用途.&lt;/p&gt;
&lt;p&gt;&amp;lt;/details&amp;gt;&lt;/p&gt;
&lt;h2&gt;2. Flash Attention&lt;/h2&gt;
&lt;p&gt;标准注意力的 O(N²) 代价会在长序列任务中迅速失控，如何高效利用 GPU 资源并降低计算复杂度显得尤为重要.&lt;/p&gt;
&lt;h3&gt;2.1 准备&lt;/h3&gt;
&lt;h3&gt;2.1.1 前置知识&lt;/h3&gt;
&lt;p&gt;HBM（High Bandwidth Memory）和 SRAM（Static Random-Access Memory）是两种不同类型的计算机内存。&lt;/p&gt;
&lt;p&gt;HBM 是一种面向 3D 堆叠 SDRAM 的高带宽内存接口，特点是带宽极高、能效更优，主要用于 GPU 等加速器的主存储。&lt;/p&gt;
&lt;p&gt;SRAM 是静态随机存取存储器，通常用于高速缓存等片上存储，访问速度更快、延迟更低，但成本较高且占用较多芯片面积。&lt;/p&gt;
&lt;p&gt;下图展示了 GPU A100 的内存层级与分布结构：&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/12/15/nWKYR5rhVI2QOEz.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;推荐阅读 Horace He 的博客(&lt;a href=&quot;https://horace.io/brrr_intro.html&quot;&gt;click here&lt;/a&gt;)，能让你快速了解深度学习中的计算、内存和开销.
:::note
:::&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h3&gt;2.2.1 前置知识 传统 Attention 计算回顾&lt;/h3&gt;
&lt;p&gt;给定输入序列 $Q, K, V \in \mathbb{R}^{N \times d}$ , 其中 $N$ 表示序列长度，$d$ 表示每个注意力头（head）的维度，我们希望计算注意力输出 $O \in \mathbb{R}^{N \times d}$ :&lt;/p&gt;
&lt;p&gt;$S = QK^\top \in \mathbb{R}^{N \times N}, ; P = softmax(S) \in \mathbb{R}^{N \times N}, ; O = PV \in \mathbb{R}^{N \times d}$，
其中 softmax 是按行（row-wise）应用的。&lt;/p&gt;
&lt;p&gt;算法如下:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2025/12/15/ZHUrvacxpKWusnP.png&quot; alt=&quot;image.png&quot; width=&quot;800&quot; height=&quot;600&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;上述计算逻辑在 GPU 的几个内存件之间的传输过程如下:
&amp;lt;img src=&quot;https://s2.loli.net/2025/12/15/nSxZQFBLgbTt5zP.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;标准的 attention 实现会将矩阵 $S$ 和 $P$ 写入 HBM，这需要 $O(N^2)$ 的内存。 通常 $N \gg d$（例如在 GPT-2 中，$N = 1024$，$d = 64$）。&lt;/p&gt;
&lt;p&gt;一方面矩阵在 HBM 与 SRAM 之间频繁传输带来了显著的时间开销；另一方面，还需要在 HBM 中存储一个空间复杂度为 $O(N^2)$ 的 Attention 矩阵。综合来看，传统的 Attention 计算在时间和内存开销上都较为昂贵。&lt;/p&gt;
&lt;h3&gt;2.2 计算推导&lt;/h3&gt;
&lt;h3&gt;2.2.1 前向传播-no mask&lt;/h3&gt;
&lt;h3&gt;2.2.2 前向传播-with mask&lt;/h3&gt;
&lt;h2&gt;Reference&lt;/h2&gt;
&lt;p&gt;[1] &lt;a href=&quot;https://iaee.substack.com/p/kv-caching-by-hand&quot;&gt;https://iaee.substack.com/p/kv-caching-by-hand&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;[2] &lt;a href=&quot;https://www.youtube.com/watch?v=gMOAud7hZg4&quot;&gt;FlashAttention - Tri Dao | Stanford MLSys #67&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;[3] &lt;a href=&quot;https://zhuanlan.zhihu.com/p/676655352&quot;&gt;Flash Attention 原理详解(含代码讲解)&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;[4] &lt;a href=&quot;https://horace.io/brrr_intro.html&quot;&gt;Making Deep Learning Go Brrrr From First Principles&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;[4] &lt;a href=&quot;https://huggingface.co/docs/text-generation-inference/conceptual/flash_attention&quot;&gt; [Hugging Face] Flash Attention&lt;/a&gt;&lt;/p&gt;
</content:encoded></item><item><title>Causal Inference Series (II)</title><link>https://xuchenhui.cc/posts/2025-02-05-introduction-to-causal-inference-ii/</link><guid isPermaLink="true">https://xuchenhui.cc/posts/2025-02-05-introduction-to-causal-inference-ii/</guid><description>因果推断系列第二篇，基于因果推断综述论文，系统介绍因果推断的基本符号、关键假设和核心定义。</description><pubDate>Wed, 05 Feb 2025 00:00:00 GMT</pubDate><content:encoded>&lt;h2&gt;0. 前言&lt;/h2&gt;
&lt;p&gt;Causal Inference (因果推断) 已经在多个领域发挥出巨大作用, 尽管早已经听说过其大名, 但是从未步入这个领域好好学习一番, 通常是浅尝辄止. 为此在博客开一个系列, 一是用于记录学习, 二是希望能够起到监督作用...&lt;/p&gt;
&lt;p&gt;本篇博客主要是对 &lt;a href=&quot;https://arxiv.org/abs/2002.02770&quot;&gt;A Survey on Causal Inference&lt;/a&gt; 进行学习和记录.&lt;/p&gt;
&lt;p&gt;:::note
由于这是个人学习笔记, 我作为初学者, 在博客中记录的内容和理解难免会有错误. 希望各位能够指正, 并请不吝赐教, 在下将不胜感激.
:::&lt;/p&gt;
&lt;h2&gt;1. BASIC OF CAUSAL INFERENCE&lt;/h2&gt;
&lt;p&gt;本节主要给出因果推断的背景知识, 包括数学符号、相关假设等.&lt;/p&gt;
&lt;h3&gt;2.1 Definitions&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;Definition 1.&lt;/strong&gt;
&lt;em&gt;Unit. A unit is the &lt;strong&gt;atomic&lt;/strong&gt; research object in the treatment effect study.&lt;/em&gt;&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;比如一个患者, 一个独立的人, 或者一系列的人, 如一个班级, 一个市场内部的人群等等.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;&lt;strong&gt;Definition 2.&lt;/strong&gt;
&lt;em&gt;Treatment. Treatment refers to the action that applies (exposes, or subjects) to a unit.&lt;/em&gt;&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;假设给某个病人开药, “开A药”就是一个 treatment, “开B药”也是一个 treatment. 令 W (W $\in$ {0, 1, 2, . . . , N_w }) 表示某个treatment, 这里 $N_w + 1 $ 表示treatment的个数&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;&lt;strong&gt;Definition 3.&lt;/strong&gt;
&lt;em&gt;Potential outcome. For each unit-treatment pair, the outcome of that treatment when applied on
that unit is the potential outcome.&lt;/em&gt;&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;基于treatment $w$ 得到的 potential outcome 表示为 $Y(W \ = \ w)$&lt;/p&gt;
&lt;p&gt;潜在结果是指在不同处理条件下可能产生的结果(不同的处理有不同的结果), 更像是在描述一个状态, 是薛定谔的.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;&lt;strong&gt;Definition 4.&lt;/strong&gt;
&lt;em&gt;Observed outcome. The observed outcome is the outcome of the treatment that is actually applied.&lt;/em&gt;&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;observed outcome 有时候也称为 factual outcome(事实结果), 我们使用 $Y^F$ 来表示事实结果(F 指 factual ).&lt;/p&gt;
&lt;p&gt;观察到的结果 $Y^F$ 是指实际实施某个处理 $w$ 后的结果, 是一个确定的.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;:::note
观测结果($Y^F$)和潜在结果($Y(W \ = \ w)$)的关系是: 当某个treatment $w$ &lt;strong&gt;真正被apply&lt;/strong&gt;的时候(具体什么treatment无所谓), 此时有: $Y^F = Y(W \ = \ w)$
:::&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;Definition 5.&lt;/strong&gt;
&lt;em&gt;Counterfactual outcome: Counterfactual outcome is the outcome if the unit had taken another
treatment.&lt;/em&gt;&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;counterfactual outcome 称为反事实结果, 指除被 treatment $w$ 作用之外的, 其他treatment作用的潜在结果.&lt;/p&gt;
&lt;p&gt;因为一个unit只能被某一个treatment作用(比如A), 并且只能观测到这一个具体的潜在结果. 如果想观测另外一个treatment(比如B)在这个unit的结果, 就只能时光回溯到之前的时间点或者去平行世界(因为当前unit的状态已经被改变了), 所以称这些结果为反事实结果.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;&lt;strong&gt;Definition 6.&lt;/strong&gt;
&lt;em&gt;Pre-treatment variables: Pre-treatment variables are the variables that will not be affected by the
treatment.&lt;/em&gt;&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;处理前变量也称为 background variables(背景变量),  &lt;strong&gt;通常使用 $X$ 表示&lt;/strong&gt;. &lt;strong&gt;他们不会被任何的treatment影响, 但 $X$ 可能会影响treatment的选择!!&lt;/strong&gt;. 比如某个人的性别不受 “发广告” 这个treatment的影响, 但反过来, 性别可能会影响广告是否下发.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;&lt;strong&gt;Definition 7.&lt;/strong&gt;
&lt;em&gt;Post-treatment variables: The post-treatment variables are the variables that are affected by the
treatment.&lt;/em&gt;&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;处理后变量指 、&lt;strong&gt;会被treatment影响的变量&lt;/strong&gt;. 比如某个人的打开软件的次数, 会受 “发广告” 这个treatment影响.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;&lt;strong&gt;Definition 8.&lt;/strong&gt;
&lt;em&gt;Individual Treatment Effect (ITE)&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;对于unit i , 其 &lt;em&gt;ITE&lt;/em&gt; 计算公式为:&lt;/p&gt;
&lt;p&gt;$$
\text{ITE}_i = Y_i(W=1) - Y_i(W=0)
$$&lt;/p&gt;
&lt;p&gt;其中, $Y_i(W=1)$ 和 $Y_i(W=0)$ 分别是 unit $i$ 分配到实验组和对照组时的输出.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;Definition 9.&lt;/strong&gt;
&lt;em&gt;Average Treatment Effect (ATE)&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;ATE关注整个population上的treatment效果, 其计算公式为:&lt;/p&gt;
&lt;p&gt;$$
\text{ATE} = \mathbb{E}[Y(W=1) - Y(W=0)]
$$&lt;/p&gt;
&lt;p&gt;其中, $Y(W=1)$ 和 $Y(W=0)$ 分别是整个population 在实验组和对照组时的输出.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;Definition 10.&lt;/strong&gt;
&lt;em&gt;Average Treatment effect on the Treated group (ATT)&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;ATT则只关注 &lt;code&gt;subgroup = treatment_group&lt;/code&gt; 的treatment效果,  其计算公式为:&lt;/p&gt;
&lt;p&gt;$$
\text{ATT} = \mathbb{E}[Y(W=1)|W=1] - \mathbb{E}[Y(W=0)|W=1]
$$&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;Definition 11.&lt;/strong&gt;
&lt;em&gt;Conditional Average Treatment Effect (CATE)&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;CATE关注整个population, 在某个确定性条件 &lt;code&gt;X = x&lt;/code&gt; 下的treatment效果, 其计算公式为:&lt;/p&gt;
&lt;p&gt;$$
\text{CATE} = \mathbb{E}[Y(W=1)|X=x] - \mathbb{E}[Y(W=0)|X=x]
$$&lt;/p&gt;
&lt;p&gt;:::note
因为CATE关注的是, 不同条件(condition or background variables)下的treatment效果, 因此也被称为 &lt;strong&gt;heterogeneous(异质) treatment effect.&lt;/strong&gt;
:::&lt;/p&gt;
&lt;h3&gt;2.2 Assumptions&lt;/h3&gt;
&lt;p&gt;为了评估因果效应, 需要做一些基本的假设或者前提.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;Assumption 1.&lt;/strong&gt;
&lt;em&gt;Stable Unit Treatment Value Assumption (SUTVA).&lt;/em&gt; The potential outcomes for any unit
do not vary with the treatment assigned to other units, and, for each unit, there are no different forms or versions of each treatment level, which lead to different potential outcomes.&lt;/p&gt;
&lt;blockquote&gt;
&lt;ol&gt;
&lt;li&gt;unit之间是独立的. 2. 一个treatment只有一种表达形式, 可以理解为一一对应的.&lt;/li&gt;
&lt;/ol&gt;
&lt;/blockquote&gt;
&lt;p&gt;&lt;strong&gt;Assumption 2.&lt;/strong&gt;
&lt;em&gt;Ignorability.&lt;/em&gt; Ttreatment assignment $W$ is independent of the potential outcomes, i.e., $W \perp Y(W = 0), Y(W = 1)$.&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;如果满足 Ignorability, 那么会有 2 个 结果成立 :&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;$Y(W = 1)$ 和 $Y(W = 0)$ 的结果与具体施加的treatment 独立无关, 因此我们可以随机的对 groups 施加 treatment;&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;表明 group 之间是可以交换的, 即 &lt;em&gt;exchangeability&lt;/em&gt;, group 交换其 treatment, $Y(W = 1)$ 和 $Y(W = 0)$ 的结果不变, 意味着此时 group 之间是可比的, 换句话说, 除了 treatment 不同, 其他条件都会相同.&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/blockquote&gt;
&lt;blockquote&gt;
&lt;p&gt;实际中, 受到背景变量 $X$ 干扰, 导致 $W$ 和 $Y$ 之间有联系. 举个例子,假设 $W$ 表示 &quot;藏私房钱是否被老婆发现&quot;,  $Y(W = 1)$ 表示 &quot;被老婆揍一拳之后的疼痛值&quot;; $W = 0$ 表示没有被揍, $W = 1$ 表示被揍.
此时 $W = 1$ 与 $Y(W = 1)$ 是有联系的, 起码我们知道这样的关系存在: 当 $X = 1$, 即藏私房钱被发现时, $W = 1$ 的取值概率会升高, 同时 $Y(W = 1)$ 会变大.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;:::note
因此会考虑以下情况: 给定背景变量 $X$ , 此时有 $W \perp Y(W = 0), Y(W = 1) \mid X$.
:::&lt;/p&gt;
&lt;h2&gt;Reference&lt;/h2&gt;
</content:encoded></item><item><title>Deep Reinforcement Learning Series</title><link>https://xuchenhui.cc/posts/2024-04-30-deep-reinforcement-learning/</link><guid isPermaLink="true">https://xuchenhui.cc/posts/2024-04-30-deep-reinforcement-learning/</guid><description>系统学习强化学习中的策略梯度与价值方法两大类参数更新算法，涵盖 PPO、Q-Learning 等核心方法的原理与公式推导。</description><pubDate>Tue, 30 Apr 2024 00:00:00 GMT</pubDate><content:encoded>&lt;h2&gt;0. 前言&lt;/h2&gt;
&lt;p&gt;本篇 Blog 主要对强化学习的几个参数更新方法进行学习.&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;阅读前, 需要你 : 有高数基础知识, 线代基础知识, 统计学习基础知识, 当然还要有 ML 和 DL 的知识背景.
:::note
:::&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h2&gt;1. 总览和相关概念&lt;/h2&gt;
&lt;h3&gt;1.1 总览&lt;/h3&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/30/4Ihfg5X8FtVEbsS.png&quot; alt=&quot;image.png&quot; width=&quot;300&quot; height=&quot;200&quot; /&amp;gt;&lt;em&gt;source from David Silver’s RL Course&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;目前, 强化学习的方法基本上划分为 2 大类: policy based and value based.&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;policy based
主要是通过学习一个 Actor, 在面对 state 的时候输出预估最优的 action.&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;value based
则是想通过学习一个 critic, 来评估当前这个 state 下, 哪个 action 能够得到的分数最高, 进而实现选择 action.&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;当然这么说好像看起来没有什么区别, 后续我们会深入分析他们的区别和联系. 最后还有把 policy based and value based 结合起来的, 就是既有 Actor 也有 Critic.&lt;/p&gt;
&lt;h3&gt;1.2 概念&lt;/h3&gt;
&lt;ul&gt;
&lt;li&gt;observation&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;比如你打游戏, 这一帧画面就是一个 observation, 有时候也称为一个 state, 就是包含了当前系统的状态和所有信息.&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;action&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;就是面对当前画面, 采取的措施, 看到敌人你是躲还是开枪?&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;reward&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;基于当前 observation, 你采取了行动 action, 将得到奖励, 比如开枪干掉了对面, 得到 100 分.&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;episode&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;这里拿飞机大战举例子, 从你进入游戏, 左右闪躲腾挪, 开枪击毁对面飞机,&lt;strong&gt;直到游戏结束&lt;/strong&gt;, 这就叫一个 episode.&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;trajectory&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;还是飞机大战, 从你进入游戏, 左右闪躲腾挪, 开枪击毁对面飞机, 直到游戏结束(假设 T 时间步), 在你这一个 episode 中, 你看到的所有画面帧, 所有的行动, 所有的奖励, 这样一个序列称为一个 trajectory:&lt;/p&gt;
&lt;p&gt;$$
\tau = { \underbrace{s_1,a_1,r_1}_{observation 1,\ action 1,\ reward 1 } \ ... \ s_T,a_T,r_T }
$$&lt;/p&gt;
&lt;h2&gt;2. Policy Based&lt;/h2&gt;
&lt;h3&gt;2.1 要做什么&lt;/h3&gt;
&lt;p&gt;前边提到 Policy Based 要训练一个 Actor. Actor 理解为就是一个
$model \ with \ parameter \ \theta$
, 给定当前的 observation
$s_t$
, Actor 评估当前某个 action
$a_t$
的概率 :
$p(a_t | s_t,\theta)$.&lt;/p&gt;
&lt;p&gt;那怎么训练呢? 假设已经有一个 trajectory $\tau$, 这样的序列除了第一个 state 是初始化, 后续你遇到的每一个 state 都是 Actor 选择的 action 导致的, 因此只需要收集这样一批 trajectory, 每个 trajectory 我们都能收集到相应的奖励 $R(\tau)$, 我们希望这个 Actor 能够在平均的,期望上的奖励能够最大:&lt;/p&gt;
&lt;p&gt;$$
max \ \mathbb{E}&lt;em&gt;{\tau} [R(\tau)] = \sum&lt;/em&gt;{\tau}  R(\tau) p(\tau | \theta)
$$&lt;/p&gt;
&lt;h3&gt;2.2 理论怎么做&lt;/h3&gt;
&lt;h4&gt;2.2.1 计算 $p(\tau | \theta)$&lt;/h4&gt;
&lt;p&gt;$$
\begin{align*}
p(\tau | \theta) &amp;amp;= p(s_1)  p(a_1|s_1,\theta)p(r_1,s_2|s_1,a_1)p(a_2|s_2,\theta)p(r_2,s_3|s_2,a_2)...
\newline
&amp;amp;= \underbrace{p(s_1)}&lt;em&gt;{crate \ by \ env\ } \prod&lt;/em&gt;{t=1}^{T} \underbrace{p(a_t|s_t,\theta)}&lt;em&gt;{crate \ by \ actor\ }\underbrace{p(r_t,s&lt;/em&gt;{t+1}|s_t,a_t)}_{\ crate \ by \ env}
\end{align*}
$$&lt;/p&gt;
&lt;h4&gt;2.2.1 计算梯度&lt;/h4&gt;
&lt;p&gt;通常我们利用梯度方向更新参数, 所以需要计算
$\mathbb{E}_{\tau} [R(\tau)]$
关于 $\theta$
的梯度:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
\nabla  \tilde{R}&lt;em&gt;{\theta}  &amp;amp;= \sum&lt;/em&gt;{\tau} R(\tau) \nabla p_{\theta}(\tau)
\newline
&amp;amp;= \sum_{\tau} R(\tau)  p_{\theta}(\tau) \nabla  log \ p_{\theta}(\tau)
\newline
&amp;amp;= \mathbb{E}&lt;em&gt;{\tau} [R(\tau)  \nabla log \ p&lt;/em&gt;{\theta}(\tau) ]
\newline
&amp;amp; \approx \frac{1}{N} \sum_{n=1}^{N} R(\tau^n) \nabla log \ p_{\theta}(\tau^n)
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;这里, 对于任意
$\tau$
:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
\nabla log \  p_{\theta}(\tau)  &amp;amp;= log \ p(s_1) + \sum_{t=1}^{T} [log \ p(a_t|s_t,\theta) + log \ p(r_t,s_{t+1}|s_t,a_t)]
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;上式中,
$log \ p(s_1)$
与 Actor 无关,
$log \ p(r_t,s_{t+1}|s_t,a_t)$
,也与 Actor 无关, 因此:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
\nabla log \  p_{\theta}(\tau)  &amp;amp;= \sum_{t=1}^{T} \nabla log \ p(a_t|s_t,\theta)
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;于是:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
\nabla  \tilde{R}&lt;em&gt;{\theta}  &amp;amp;= \frac{1}{N} \sum&lt;/em&gt;{n=1}^{N} \sum_{t=1}^{T} R(\tau^n) \nabla log \ p(a_t^n|s_t^n,\theta)
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;写成期望的版本:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
\nabla  \tilde{R}&lt;em&gt;{\theta}  &amp;amp;= \mathbb{E}&lt;/em&gt;{(s_t,a_t) \sim p_{\theta}} [R(\tau) \nabla log \ p(a_t |s_t,\theta)]
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;使用梯度提升更新 Actor 的参数:&lt;/p&gt;
&lt;p&gt;$$
\theta^{new} = \theta^{old} + \eta \nabla  \tilde{R}_{\theta}
$$&lt;/p&gt;
&lt;h4&gt;2.2.2 改进奖励计算方式&lt;/h4&gt;
&lt;p&gt;上边在建模 reward 的时候, 任何 action 的奖励都是正的, 这明显不合理, 因此可以加一个 baseline :&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
\nabla  \tilde{R}&lt;em&gt;{\theta}  &amp;amp;= \frac{1}{N} \sum&lt;/em&gt;{n=1}^{N} \sum_{t=1}^{T} [R(\tau^n) - b] log \ p(a_t^n|s_t^n,\theta)
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;一般的, baseline b 可以取值为截至目前的平均 reward:
$b \approx \mathbb{E}_{\tau} [R(\tau)] $.&lt;/p&gt;
&lt;p&gt;此外, 上边式子可以理解为是对
$log \ p(a_t^n|s_t^n,\theta)$
的一个 sum weight. 其中 $R(\tau^n) - b$
表示的是, 在 state $s_t$ 时采取 action $a_t$ 时对后来获得的总奖励的影响是多大.&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;另外还可以从减少 variance 的角度来理解, 可移步至此(&lt;a href=&quot;/posts/A-Series-on-LLM-Training-(I)/#11-ppo-%E7%9B%AE%E6%A0%87&quot;&gt;点击跳转&lt;/a&gt;)&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;在上边的公式中, 这个 weight 对每个时间步 t 都一样, 均为 $R(\tau^n) - b$. 直觉上, 当前时间步 t 采取的动作, 只能影响 时间步 t 之后的奖励或者状态等. 并且随着时间流逝, 时间步 t 采取的动作对后续的影响应该越来越小, 因此对 $R(\tau)$ 进行修改(忽略 b ):&lt;/p&gt;
&lt;p&gt;$$
R(\tau^n) = \sum_{t&apos; = t}^{T_n} \gamma^{t&apos;-t}r_{t&apos;}^{n}
$$&lt;/p&gt;
&lt;p&gt;其中, $\gamma &amp;lt; 1$, baseline 也同步修改. 举个例子, 假设 $\gamma = 0.99, t = 3 , T = 5$:&lt;/p&gt;
&lt;p&gt;$$
R(\tau^n) = r_3 + 0.99&lt;em&gt;r_4 + 0.99^2&lt;/em&gt;r_5
$$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;注意到, 此时 $R(\tau^n)$ 是想评估当前 actor 基于当前 state $s_t$ 和 action $a_t$ 的分数, 不妨记为 $Q^{\pi}(s_t,a_t)$, 而 $b$ 可以理解为度量面对当前 state $s_t$ (不与 action 有关), 这个 actor 平均能够取得的分数, 不妨记为 $V^{\pi}(s_t)$, 那如果把 $Q^{\pi}(s_t,a_t) - V^{\pi}(s_t)$ 记为 $A^{\theta}(s_t,a_t)$, 其实这就是 Value Based 方法 (也就是一个 critic). 二者结合, 就是 Actor - Critic, 我们后续会讨论. 此时, 之前的期望公式就可以写为:&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;$$
\begin{align*}
\nabla  \tilde{R}&lt;em&gt;{\theta}  &amp;amp;= \mathbb{E}&lt;/em&gt;{(s_t,a_t) \sim p_{\theta}} [A^{\theta}(s_t,a_t) \nabla  log \ p(a_t |s_t,\theta)]
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;:::note
:::&lt;/p&gt;
&lt;h3&gt;2.3 实际怎么做&lt;/h3&gt;
&lt;p&gt;前边的更新策略有个大问题就是, 我们要收集数据, 这个需要一轮一轮的玩下去才能收集到这些信息. 而且更新的这个 Actor 和 环境进行交互的 Actor 是同一个, 这就导致收集一批数据, 更新 Actor 之后, 整个过程就得停下来, 用新的 Actor 再次和环境进行交互 (这个过程称为 Online-policy). 这样就会很慢, 我们想着能不能让当前的 Actor 借助别人的力量, 使用别人的历史数据去更新?(这个过程称为 Offline-policy)&lt;/p&gt;
&lt;h4&gt;2.3.1 Importance Sampling&lt;/h4&gt;
&lt;p&gt;首先介绍一个 trick, 假设我们要计算
$\mathbb{E}_{x \sim p} [f(x)] $
, 但是 $p(x)$ 也许不好计算, 不好得到, 但是我们有一个分布 $q(x)$, 计算很容易, 也很容易积分:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
\mathbb{E}&lt;em&gt;{x \sim p} [f(x)]  &amp;amp;= \int f(x)p(x) \ d(x)
\newline
&amp;amp;= \int f(x)p(x) \frac{q(x)}{q(x)} \ d(x)
\newline
&amp;amp;= \int f(x) \frac{p(x)}{q(x)} q(x) \ d(x)
\newline
&amp;amp;= \mathbb{E}&lt;/em&gt;{x \sim q} [f(x)\frac{p(x)}{q(x)}]
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;可以看到, 本来是从 $p(x)$
抽取数据, 现在做到从 $q(x)$
抽取数据, 并且期望还不变. 但是需要注意的是, 他们的方差是不一样的.&lt;/p&gt;
&lt;p&gt;首先:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
Var_{x \sim p} [f(x)] &amp;amp;= \mathbb{E}&lt;em&gt;{x \sim p} [f(x)^2] - (\mathbb{E}&lt;/em&gt;{x \sim p} [f(x)] ) ^2
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;而:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
Var_{x \sim q} [f(x) \frac{p(x)}{q(x)} ] &amp;amp;= \mathbb{E}&lt;em&gt;{x \sim q} [(f(x)\frac{p(x)}{q(x)})^2] - (\mathbb{E}&lt;/em&gt;{x \sim q} [f(x)\frac{p(x)}{q(x)}] ) ^2
\newline
&amp;amp;= \int f(x)^2 \frac {p(x) p(x)} {q(x) q(x)} q(x)\ d(x) - (\int q(x) f(x) \frac {p(x)}{q(x)} \ d(x)) ^ 2
\newline
&amp;amp;=  \mathbb{E}&lt;em&gt;{x \sim p} [f(x)^2\frac {p(x)}{q(x)}] - (\mathbb{E}&lt;/em&gt;{x \sim p} [f(x)] ) ^2
\newline
&amp;amp; \neq \mathbb{E}&lt;em&gt;{x \sim p} [f(x)^2] - (\mathbb{E}&lt;/em&gt;{x \sim p} [f(x)] ) ^2
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;二者方差差一点, 因此要求 $p(x)$ 和 $q(x)$ 不要差太多.&lt;/p&gt;
&lt;h4&gt;2.3.2 off - policy&lt;/h4&gt;
&lt;p&gt;off-policy 就是利用上边的 importance sampling, 使用另外一个 policy ${\theta_{old}}$ 的 {( $s_t,a_t$ )} 去更新 policy ${\theta}$. 于是,&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
\nabla  \tilde{R}&lt;em&gt;{\theta}  &amp;amp;= \mathbb{E}&lt;/em&gt;{\tau} [R(\tau)  \nabla log \ p_{\theta}(\tau) ]
\newline
&amp;amp;= \mathbb{E}&lt;em&gt;{\tau \sim p&lt;/em&gt;{\theta_{old}}(\tau)} [\frac { p_{\theta}(\tau)} { p_{\theta_{old}}(\tau)} R(\tau)  \nabla log \ p_{\theta}(\tau) ]
\newline
&amp;amp;= \mathbb{E}&lt;em&gt;{(s_t,a_t) \sim p&lt;/em&gt;{\theta_{old}}} [\frac { p_{\theta}(a_t , s_t)} { p_{\theta_{old}}(a_t , s_t)} R(\tau)  \nabla log \ p_{\theta}(\tau) ]
\newline
&amp;amp;= \mathbb{E}&lt;em&gt;{(s_t,a_t) \sim p&lt;/em&gt;{\theta_{old}}} [\frac { p_{\theta}(a_t | s_t)p_{\theta}(s_t)} { p_{\theta_{old}}(a_t | s_t)p_{\theta_{old}}(s_t)} R(\tau)  \nabla log \ p_{\theta}(\tau) ] \ (p_{\theta}(s_t) \approx p_{\theta_{old}}(s_t))
\newline
&amp;amp;= \mathbb{E}&lt;em&gt;{(s_t,a_t) \sim p&lt;/em&gt;{\theta_{old}}} [\frac { p_{\theta}(a_t | s_t)} { p_{\theta_{old}}(a_t | s_t)} R(\tau)  \nabla log \ p_{\theta}(\tau) ]
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;上式 $p_{\theta}(s_t) \approx p_{\theta_{old}}(s_t)$ 是我们进行的假设, 假设环境在第 t 步出现的状态和 Actor 无关( 看起来不太合适, 这个也是没办法的办法 ). 此外 policy ${\theta_{old}}$ 是固定的, 用来和环境交互, policy ${\theta}$ 是我们要更新的. 更新一段时间后, 我们可以执行 $\theta_{old} &amp;lt;-- \theta$ 以防二者差距太大.&lt;/p&gt;
&lt;p&gt;这样, 反推得到:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
\tilde{R}&lt;em&gt;{\theta}^{\theta&lt;/em&gt;{old}}  &amp;amp;= \mathbb{E}&lt;em&gt;{(s_t,a_t) \sim p&lt;/em&gt;{\theta_{old}}} [\frac { p_{\theta}(a_t | s_t)} { p_{\theta_{old}}(a_t | s_t)} R(\tau)]
\end{align*}
$$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;如果使用 Actor - Critic 的方式评估 $R(\tau)$,之前的期望公式就可以写为:&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;$$
\begin{align*}
\nabla  \tilde{R}&lt;em&gt;{\theta} &amp;amp;= \mathbb{E}&lt;/em&gt;{(s_t,a_t) \sim p_{\theta_{old}}} [\frac { p_{\theta}(a_t | s_t)} { p_{\theta_{old}}(a_t | s_t)} A^{\theta_{old}}(s_t,a_t)  \nabla log \ p_{\theta}(\tau) ]
\end{align*}
$$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;反推得到&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;$$
\begin{align*}
\tilde{R}&lt;em&gt;{\theta}^{\theta&lt;/em&gt;{old}}  &amp;amp;= \mathbb{E}&lt;em&gt;{(s_t,a_t) \sim p&lt;/em&gt;{\theta_{old}}} [\frac { p_{\theta}(a_t | s_t)} { p_{\theta_{old}}(a_t | s_t)} A^{\theta_{old}}(s_t,a_t)]
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;:::note
:::&lt;/p&gt;
&lt;h3&gt;2.4 PPO&lt;/h3&gt;
&lt;p&gt;在上边的基础上, 我们来看经典算法 PPO. PPO 其实就是在上边的基础上, 加了一个 KL 散度(&lt;a href=&quot;https://chenhui-x.github.io/posts/Kullback-Leibler-divergence/&quot;&gt;什么是 KL 散度?&lt;/a&gt;), 这是因为我们要尽量保证 $\theta_{old} \approx \theta$. 这里直接贴出原始 &lt;a href=&quot;https://arxiv.org/abs/1707.06347&quot;&gt;paper&lt;/a&gt; 的公式, 现在一目了然:&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/04/30/ShDp81dTJAUbjKF.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;p&gt;这里 KL (
$\theta_{old}$
,
$\theta$)
实际上就是让二者的输出, 计算一下离散的 KL 散度值作为代替. 不过作者又给了一个更加简单粗暴的实现: 如果二者差距确实大,直接 clip 一下就完事了:&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/04/30/4e5jVbw1UGNr3Q9.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;p&gt;这样就把
$\frac { p_{\theta}(a_t | s_t)} { p_{\theta_{old}}(a_t | s_t)}$
限制到了
$(1- \epsilon , 1 + \epsilon)$.&lt;/p&gt;
&lt;h2&gt;3. Value Based&lt;/h2&gt;
&lt;h3&gt;3.1 概念&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;[1]&lt;/strong&gt;
Value Based 的方法目标就是训练一个 critic, 也可以叫一个 function, 功能就是给定一个 state (或者以及一个 action), critic 能够评估当前这个 Actor(policy : $\pi$) 最后能取得多少分数(相对的,平均).&lt;/p&gt;
&lt;p&gt;不妨记,
$V^{\pi}(s)$
表示给定一个 state s, critic 给出的基于当前 state, 该 policy 能得到的分数(平均).&lt;/p&gt;
&lt;p&gt;$$
V^{\pi}(s) = \mathbb{E}&lt;em&gt;{\pi}[R_t | S_t = s] =  \mathbb{E}&lt;/em&gt;{\pi}[\sum_{k = 0}^{\infty} \gamma^{k}R_{t+k} | S_t = s]
$$&lt;/p&gt;
&lt;p&gt;如果对于任意的 state s, 如果 $V^{\pi&apos;}(s)  &amp;gt;= V^{\pi}(s) $ 恒成立, 则称 $\pi&apos;$ better than $\pi$. 于是这样, 我们整个系统中, 最好的那个
$\pi$ 记为 $\pi_{\ast}$ . 并将 $\pi_{\ast}$ 对应的 V 记为 $V^{\ast}(s)$. 即&lt;/p&gt;
&lt;p&gt;$$
V^{\ast}(s) = \mathop{\max}\limits_{\pi}  \ V^{\pi}(s)
$$&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;[2]&lt;/strong&gt;
$Q^{\pi}(s,a)$
表示给定一个 state s, 然后 Actor 采取一个 action a, critic 给出的基于当前 state 和 action, 该 Actor 能最后得到的分数(平均).&lt;/p&gt;
&lt;p&gt;$$
Q^{\pi}(s,a) = \mathbb{E}&lt;em&gt;{\pi}[R_t | S_t = s , A_t = a ] =  \mathbb{E}&lt;/em&gt;{\pi}[\sum_{k = 0}^{\infty} \gamma^{k}R_{t+k} | S_t = s , A_t = a  ]
$$&lt;/p&gt;
&lt;p&gt;同理, $\pi_{\ast}$ 在每个 state 上采取的 action a, 将 Q 值是最大记为如下形式:&lt;/p&gt;
&lt;p&gt;$$
Q^{\ast}(s,a) = \mathop{\max}\limits_{\pi}  \ Q^{\pi}(s,a)
$$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;对于 $V^{\pi}(s)$ 和 $Q^{\pi}(s,a)$, 二者有以下关系:&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;$$
Q^{\pi}(s,a) =  \mathbb{E}&lt;em&gt;{\pi}[r_t + \gamma V^{\pi}(s&lt;/em&gt;{t+1})| S_t = s , A_t = a ]
$$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;当然, 对于最优的 value,那么会有如下的式子成立:&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;$$
Q^{\ast}(s,a) =  \mathbb{E}&lt;em&gt;{\pi}[r_t + \gamma V^{\ast}(s&lt;/em&gt;{t+1})| S_t = s , A_t = a ]
$$&lt;/p&gt;
&lt;p&gt;:::note
:::&lt;/p&gt;
&lt;h3&gt;3.2 Bellman equation&lt;/h3&gt;
&lt;h4&gt;3.2.1 Bellman equation for $V^{\pi}(s)$&lt;/h4&gt;
&lt;p&gt;$$
\begin{align*}
V^{\pi}(s) &amp;amp;= \mathbb{E}&lt;em&gt;{\pi}[R_t | S_t = s]
\newline
&amp;amp;=  \mathbb{E}&lt;/em&gt;{\pi}[\sum_{k = 0}^{\infty} \gamma^{k}R_{t+k} | S_t = s]
\newline
&amp;amp;= \mathbb{E}&lt;em&gt;{\pi}[r_t + \sum&lt;/em&gt;{k = 1}^{\infty} \gamma^{k}R_{t+k} | S_t = s]
\newline
&amp;amp;= \sum_{a} \pi(a | s) \sum_{s&apos;} \sum_{r} p(s&apos;, r | s, a) [r + \gamma \mathbb{E}&lt;em&gt;{\pi}[\sum&lt;/em&gt;{k = 0}^{\infty} \gamma^{k}R_{t+k+1} | S_{t+1} = s&apos;]]
\newline
&amp;amp;= \sum_{a} \pi(a | s) \sum_{s&apos;, \ r} p(s&apos;, r | s, a) [r + \gamma V^{\pi}(s&apos;)]
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;上边这个式子称为 $V^{\pi}(s)$ 的 &lt;code&gt;Bellman equation&lt;/code&gt;, 它揭示了当前 $\text{state s}$ 下的 $V^{\pi}$ 与下一时刻的 $\text{state s&apos;}$ 的 $V^{\pi}$ 之间的关系.&lt;/p&gt;
&lt;h4&gt;3.2.2 Bellman optimality equation for $V^{\pi}(s)$&lt;/h4&gt;
&lt;p&gt;从定义上我们有:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
V^{\ast}(s) &amp;amp;= \mathop{\max}\limits_{a}  \ Q^{\pi_{\ast}}(s,a)
\newline
&amp;amp;= \mathop{\max}\limits_{a}  \ \mathbb{E}&lt;em&gt;{\pi&lt;/em&gt;{\ast}}[R_t | S_t = s, A_t = a]
\newline
&amp;amp;= \mathop{\max}\limits_{a}  \ \mathbb{E}&lt;em&gt;{\pi&lt;/em&gt;{\ast}}[\sum_{k = 0}^{\infty} \gamma^{k}R_{t+k} | S_t = s, A_t = a]
\newline
&amp;amp;= \mathop{\max}\limits_{a}  \ \mathbb{E}&lt;em&gt;{\pi&lt;/em&gt;{\ast}}[r_t + \gamma \sum_{k = 0}^{\infty} \gamma^{k}R_{t+k+1} | S_t = s, A_t = a]
\newline
&amp;amp;= \mathop{\max}\limits_{a}  \ \mathbb{E}&lt;em&gt;{\pi&lt;/em&gt;{\ast}}[r_t + \gamma V^{\ast}(s_{t+1})| S_t = s, A_t = a]
\newline
&amp;amp;= \mathop{\max}\limits_{a \in A(S)} \sum_{s&apos;,\ r } p(s&apos;, r | s, a)[r_t + \gamma V^{\ast}(s&apos;)]
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;上式称为 Bellman optimality equation for $V^{\pi}(s)$&lt;/p&gt;
&lt;h4&gt;3.2.3 Bellman optimality equation for $Q^{\pi}(s,a)$&lt;/h4&gt;
&lt;p&gt;按定义我们有:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
Q^{\ast}(s,a) &amp;amp;=  \mathbb{E}&lt;em&gt;{\pi}[r_t + \gamma V^{\ast}(s&lt;/em&gt;{t+1})| S_t = s , A_t = a ]
\newline
&amp;amp;=  \mathbb{E}&lt;em&gt;{\pi}[r_t + \gamma \mathop{\max}\limits&lt;/em&gt;{a&apos;} Q^{\ast}(s_{t+1},a&apos;) | S_t = s , A_t = a ]
\newline
&amp;amp;= \sum_{s&apos;,\ r } p(s&apos;, r | s, a)[r_t + \gamma Q^{\ast}(s&apos;,a&apos;))]
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;上式称为 Bellman optimality equation for $Q^{\pi}(s,a)$&lt;/p&gt;
&lt;h3&gt;3.3 Dynamic Programming Based&lt;/h3&gt;
&lt;h4&gt;3.3.1 Policy Evaluation&lt;/h4&gt;
&lt;p&gt;动态规划的思想就很直观, 直接根据我们之前的 Bellman equation 迭代就行了, 因为 Bellman equation 描述的就是 $\pi_{\ast}$ 本身的性质. 更新迭代公式如下:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
V^{\pi}(s) \Leftarrow \sum_{a} \pi(a | s) \sum_{s&apos;, \ r} p(s&apos;, r | s, a) [r + \gamma V^{\pi}(s&apos;)]
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;上述迭代过程称为 iterative policy evaluation. 此外, 也可以使用 Bellman equation for $Q^{\pi}(s,a)$ 做迭代, 就是直接对 $Q^{\pi}(s,a)$ 迭代. 这个称为 Q-policy Iteration.&lt;/p&gt;
&lt;h4&gt;3.3.2 Policy Improvement&lt;/h4&gt;
&lt;p&gt;我们的目是找一个 policy, 能够面对不同的 state 采取一个 action. 已知&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
Q^{\pi}(s,a) &amp;amp;=  \mathbb{E}&lt;em&gt;{\pi}[r_t + \gamma V^{\pi}(s&lt;/em&gt;{t+1})| S_t = s , A_t = a ]
\newline
&amp;amp;= \sum_{s&apos;, \ r} p(s&apos;, r | s, a) [r + \gamma V^{\pi}(s&apos;)]
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;现在我们让 $Q^{\pi}(s,a)$ 最大, 这会得到一个相应的 action $a^{\ast}$ , 如果我们让另外一个 $\pi&apos;$ 每次都能让以下式子成立:&lt;/p&gt;
&lt;p&gt;$$
\pi&apos;(s) = a^{\ast}
$$&lt;/p&gt;
&lt;p&gt;这就意味着, 这个 $\pi&apos;$ 每次都能采取最优的 action. 从而总是有:&lt;/p&gt;
&lt;p&gt;$$
V^{\pi&apos;}(s) &amp;gt;= V^{\pi}(s)
$$&lt;/p&gt;
&lt;p&gt;于是:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
\pi&apos;(s) &amp;amp;= \mathop{\arg\max}\limits_{a}  \ Q^{\pi}(s,a)
\newline
&amp;amp;= \mathop{\arg\max}\limits_{a}  \ \mathbb{E}&lt;em&gt;{\pi}[r_t + \gamma V^{\pi}(s&lt;/em&gt;{t+1})| S_t = s , A_t = a ]
\newline
&amp;amp;= \mathop{\arg\max}\limits_{a}  \ \sum_{s&apos;, \ r} p(s&apos;, r | s, a) [r + \gamma V^{\pi}(s&apos;)]
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;上述过程称为 Policy Improvement. 现在, 如果更新后的 policy 和原始的 policy 一样, 即 $V^{\pi} = V^{\pi&apos;}$, 于是:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
V^{\pi&apos;}(s) &amp;amp;= \mathop{\max}\limits_{a}  \ Q^{\pi}(s,a)
\newline
&amp;amp;= \mathop{\max}\limits_{a}  \ \mathbb{E}&lt;em&gt;{\pi}[r_t + \gamma V^{\pi}(s&lt;/em&gt;{t+1})| S_t = s , A_t = a ]
\newline
&amp;amp;= \mathop{\max}\limits_{a}  \ \sum_{s&apos;, \ r} p(s&apos;, r | s, a) [r + \gamma V^{\pi}(s&apos;)]
\newline
&amp;amp;= \mathop{\max}\limits_{a}  \ \sum_{s&apos;, \ r} p(s&apos;, r | s, a) [r + \gamma V^{\pi&apos;}(s&apos;)]
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;可以看到, 这就是 3.2.2 的 Bellman optimality equation for $V^{\pi}(s)$. 换句话说, 此时的 $\pi&apos; = \pi$ 就是最优的 $\pi$.&lt;/p&gt;
&lt;h4&gt;3.3.3 Policy Iteration&lt;/h4&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/05/03/9s3hY5weflTpL6d.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;em&gt;source from refer &lt;a href=&quot;https://web.stanford.edu/class/psych209/Readings/SuttonBartoIPRLBook2ndEd.pdf&quot;&gt;Reinforcement Learning: An Introduction&lt;/a&gt;&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;我们使用 Policy Evaluation 去迭代 $V^{\pi}(s)$, 实现让 $V^{\pi}(s)$ 预估的更加准确.&lt;/p&gt;
&lt;p&gt;然后, 我们使用 Policy Improvement 去找到一个更好的 $\pi$. 上述过程称为 Policy Iteration.&lt;/p&gt;
&lt;p&gt;这个过程 Policy Evaluation 和 Policy Improvement 是交替循环的, 如下:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/05/04/MlOrPbRqevAgFU1.png&quot; alt=&quot;image.png&quot; width=&quot;300&quot; height=&quot;200&quot; /&amp;gt;&lt;em&gt;source from refer &lt;a href=&quot;https://web.stanford.edu/class/psych209/Readings/SuttonBartoIPRLBook2ndEd.pdf&quot;&gt;Reinforcement Learning: An Introduction&lt;/a&gt;&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;后续, 多个方法思路都是类似的, 上述循环过程称为 Generalized Policy Iteration.&lt;/p&gt;
&lt;h4&gt;3.3.4 Value Iteration&lt;/h4&gt;
&lt;p&gt;Policy Iteration 有 2 个步骤, 首先要让 V 预估准确, 然后再使用 Policy Improvement, 步骤比较繁琐, Value Iteration 的思想是, 直接从 Bellman optimality equation for $V^{\pi}(s)$ 入手 :&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
V^{\ast}(s) &amp;amp;= \mathop{\max}\limits_{a \in A(S)} \sum_{s&apos;,\ r } p(s&apos;, r | s, a)[r_t + \gamma V^{\ast}(s&apos;)]
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;上述式子描述的是, 最优的 $\pi_{\ast}$ 能够满足的式子, 于是我们直接使用这个式子进行迭代:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
V(s) &amp;amp;= \mathop{\max}\limits_{a \in A(S)} \sum_{s&apos;,\ r } p(s&apos;, r | s, a)[r_t + \gamma V(s&apos;)]
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;当上式迭代收敛后, 得到的一个预估准确的 V function, 并且 $V(s)$ 的结果就是最优的 $\pi_{\ast}$ 对应的 V 值. 算法如下:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/05/04/Px6KeON8TqJmHVn.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;em&gt;source from refer &lt;a href=&quot;https://web.stanford.edu/class/psych209/Readings/SuttonBartoIPRLBook2ndEd.pdf&quot;&gt;Reinforcement Learning: An Introduction&lt;/a&gt;&lt;/em&gt;&lt;/p&gt;
&lt;h4&gt;3.3.5 DP 方法总结&lt;/h4&gt;
&lt;p&gt;DP 方法总体思路为迭代方法, 主要基于 Bellman equation 进行迭代更新. 它基于当前状态, 观察所有可能的下一步来更新. 如图所示:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/05/04/arfKpUyeJ692CsY.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;em&gt;source from refer &lt;a href=&quot;https://www.youtube.com/watch?v=P0ZvxeQqv0A&quot;&gt;Monte Carlo Methods&lt;/a&gt;&lt;/em&gt;&lt;/p&gt;
&lt;h3&gt;3.4 Monte-Carlo based&lt;/h3&gt;
&lt;h4&gt;3.4.1 state value based&lt;/h4&gt;
&lt;p&gt;Monte-Carlo 方法就很质朴, 直接基于当前 state, 然后你玩游戏直到结束, 记录分数, 最后求和得到累计奖励 $R$, 然后 minimize 二者的差距:&lt;/p&gt;
&lt;p&gt;$$
minimize \ V^{\pi}(s) \leftrightarrow  R
$$&lt;/p&gt;
&lt;p&gt;具体的, 对于某个具体的 state $s$, 收集一批以 s 为起点的 trajectory, 最后计算 reward 的均值作为 $V^{\pi}(s)$:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/05/04/5QwhUXIr96KiNSP.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;em&gt;source from refer &lt;a href=&quot;https://web.stanford.edu/class/psych209/Readings/SuttonBartoIPRLBook2ndEd.pdf&quot;&gt;Reinforcement Learning: An Introduction&lt;/a&gt;&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;思想如图所示:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/05/04/1TUgNV5yYWfdeF3.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;em&gt;source from refer &lt;a href=&quot;https://www.youtube.com/watch?v=P0ZvxeQqv0A&quot;&gt;Monte Carlo Methods&lt;/a&gt;&lt;/em&gt;&lt;/p&gt;
&lt;h4&gt;3.4.2 state action value based&lt;/h4&gt;
&lt;p&gt;但是 state value based 可能比较困难, 因为要预估当前 state 下, 所有 action 的结果. 如果直接预估当前 state 下, 采取一个 action 之后的值可能好一点. 具体的, 收集一批以 state s , action a 为起点的 trajectory , 最后计算 reward 的均值作为 $Q^{\pi}(s,a)$:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/05/04/qiZp7dsTecbukfw.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;em&gt;source from refer &lt;a href=&quot;https://web.stanford.edu/class/psych209/Readings/SuttonBartoIPRLBook2ndEd.pdf&quot;&gt;Reinforcement Learning: An Introduction&lt;/a&gt;&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;上述过程也是一个 cycle, 我们使用 Monte-Carlo 方法得到一个预估准确的 Q function, 然后使用这个 Q function 去让 $\pi(s) = \mathop{\arg\max}\limits_{a}  \ Q^{\pi}(s,a) $ 实现 improvement 的操作. 这种取 max 的操作也叫做 $greedy$, 考虑到 &lt;a href=&quot;https://huggingface.co/learn/deep-rl-course/unit1/exp-exp-tradeoff&quot;&gt;Exploration/Exploitation trade-off&lt;/a&gt; , 如果每次都取 max, 会抑制后续的 exploration. 因此引入 $\epsilon - greedy$, 算法如下:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/05/04/jf4X5lavU3OVeht.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;em&gt;source from refer &lt;a href=&quot;https://web.stanford.edu/class/psych209/Readings/SuttonBartoIPRLBook2ndEd.pdf&quot;&gt;Reinforcement Learning: An Introduction&lt;/a&gt;&lt;/em&gt;&lt;/p&gt;
&lt;h3&gt;3.5 Temporal-difference based&lt;/h3&gt;
&lt;h4&gt;3.5.1 迭代公式&lt;/h4&gt;
&lt;p&gt;回顾 $V^{\pi}(s)$ 的定义: 计算 $\pi$ 面对当前 state s 能够获得的奖励, 记 $N_{a}^{k}(s)$ 表示基于当前 state 采样的 action 数目.&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
V_{k}^{\pi}(s) &amp;amp;= \mathbb{E}&lt;em&gt;{\pi}[R_t | S_t = s]
\newline
&amp;amp;= \frac {1} {N&lt;/em&gt;{a}^{k}(s)} (R_1 + R_2 + ... + R_{N_{a}^{k}(s)}) \ \ (采样)
\newline
&amp;amp;= \frac {1} {N_{a}^{k}(s)} (  R_{N_{a}^{k}(s)} + \sum_{i}^{N_{a}^{k}(s) - 1 } R_i )
\newline
&amp;amp;= \frac {1} {N_{a}^{k}(s)} (  R_{N_{a}^{k}(s)} + (N_{a}^{k}(s)-1) V_{k -1 }^{\pi}(s) + V_{k -1 }^{\pi}(s) - V_{k -1 }^{\pi}(s))
\newline
&amp;amp;= \frac {1} {N_{a}^{k}(s)} (  R_{N_{a}^{k}(s)} + N_{a}^{k}(s) V_{k -1 }^{\pi}(s)  - V_{k -1 }^{\pi}(s))
\newline
&amp;amp;= V_{k -1 }^{\pi}(s) + \frac {1} {N_{a}^{k}(s)} (  R_{N_{a}^{k}(s)}   - V_{k -1 }^{\pi}(s))
\newline
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;抽象一下, 可以表示为下边的迭代公式:&lt;/p&gt;
&lt;p&gt;$$
\text{NewEstimate ← OldEstimate + StepSize [Target − OldEstimate]}
$$&lt;/p&gt;
&lt;p&gt;$V_{k -1 }^{\pi}(s)$ 告诉我们当前的预估, $R_{N_{a}^{k}(s)}$ 是真实看到的结果, $R_{N_{a}^{k}(s)}   - V_{k -1 }^{\pi}(s)$ 告诉我们应该向实际看到的 reward 方向走, 这很像智能优化中的粒子群算法(关于该算法可以看我的&lt;a href=&quot;https://www.bilibili.com/video/BV1uY41187rK&quot;&gt;视频讲解&lt;/a&gt;).&lt;/p&gt;
&lt;h4&gt;3.5.2 TD-Prediction&lt;/h4&gt;
&lt;p&gt;假设我们的 critic 是准确的, 应该有:&lt;/p&gt;
&lt;p&gt;$$
V^{\pi}(s_t) = \gamma V^{\pi}(s_{t+1}) + r_t
$$&lt;/p&gt;
&lt;p&gt;易知:&lt;/p&gt;
&lt;p&gt;$$
V^{\pi}(s_t) = (1 - \alpha) V^{\pi}(s_t) + \alpha V^{\pi}(s_t)
$$&lt;/p&gt;
&lt;p&gt;从而:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
V^{\pi}(s_t) &amp;amp;= (1 - \alpha) V^{\pi}(s_t) + \alpha ( \gamma V^{\pi}(s_{t+1}) + r_t)
\newline
&amp;amp;= V^{\pi}(s_t) - \alpha V^{\pi}(s_t) + \alpha ( \gamma V^{\pi}(s_{t+1}) + r_t)
\newline
&amp;amp;= V^{\pi}(s_t) + \alpha ( \gamma V^{\pi}(s_{t+1}) + r_t - V^{\pi}(s_t))
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;上式中, $\gamma V^{\pi}(s_{t+1}) + r_t $ 描述的是向后预估一个 state $s_{t+1}$ 的 value. 然后再回过头来, 结合当前的预估值看预估的准不准, &lt;strong&gt;$\gamma V^{\pi}(s_{t+1}) + r_t - V^{\pi}(s_t)$ 也称为 temporal difference error (TD-error)&lt;/strong&gt;.&lt;/p&gt;
&lt;p&gt;那如果 $V^{\pi}(s)$ 不准确, 我们可以使用 $\gamma V^{\pi}(s_{t+1}) + r_t$ 来纠正 $V^{\pi}$(因为 $r_t$ 至少是确定的).&lt;/p&gt;
&lt;p&gt;于是,可以使用如下迭代公式更新 $V^{\pi}(s)$ :&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
V^{\pi}(s_t) \leftarrow  V^{\pi}(s_t) + \alpha ( \gamma V^{\pi}(s_{t+1}) + r_t - V^{\pi}(s_t))
\end{align*}
$$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;前边的 Monte-Carlo based 也可以写作如下式子:&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;$$
\begin{align*}
V^{\pi}(s_t) \leftarrow  V^{\pi}(s_t) + \alpha ( R_t - V^{\pi}(s_t))
\end{align*}
$$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;所以 Monte-Carlo based 就是直接用&lt;strong&gt;真实的&lt;/strong&gt;、整个 trajectory 的 reward 与 $V^{\pi}(s)$ 做比较, 而 Temporal-difference based 则是向后玩一步或者多步, 剩余的使用 $V^{\pi}$ 进行预估.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;:::note
:::
当然, 由于现在大家都是 neural network , 因此也可以直接使用梯度下降 minimize 以下差异 &lt;a href=&quot;https://youtu.be/o_g9JUMw1Oc?t=924&quot;&gt;DRL Lecture 3: Q-learning&lt;/a&gt;:&lt;/p&gt;
&lt;p&gt;$$
minimize \ V^{\pi}(s_t) - \gamma V^{\pi}(s_{t+1}) \leftrightarrow  r_t
$$&lt;/p&gt;
&lt;p&gt;不过需要注意的是, MC 方法和 TD 方法有时候预估出来的结果可能不一样:
&amp;lt;img src=&quot;https://s2.loli.net/2024/04/30/QveSp4AUZ7ykamt.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;em&gt;source from refer &lt;a href=&quot;https://youtu.be/o_g9JUMw1Oc?t=924&quot;&gt;DRL Lecture 3: Q-learning&lt;/a&gt;&lt;/em&gt;&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;二者没有说谁对谁错, 只是基于当前的数据, 作出的合理的判断. 不过由于便捷性和效率, &lt;strong&gt;通常使用 TD 方法, 毕竟 MC 方法太磨叽了.&lt;/strong&gt;
:::note
:::&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h4&gt;3.5.3 SARSA: ON-POLICY TD CONTROL&lt;/h4&gt;
&lt;p&gt;我们也可以直接对 $Q^{\pi}(s,a)$ 进行 TD-Prediction, 该算法也叫 SARSA (State-Action-Reward-State-Action) .&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/05/05/aZwMWOVS2yA1hKQ.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;em&gt;source from refer &lt;a href=&quot;https://web.stanford.edu/class/psych209/Readings/SuttonBartoIPRLBook2ndEd.pdf&quot;&gt;Reinforcement Learning: An Introduction&lt;/a&gt;&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;这个地方 ON-POLICY 是说我们对下一个 action $s_{t+1}$ 的 Q value 预估用的 policy 和获取下一个 action $s_{t+1}$ 的 policy 是同一个.&lt;/p&gt;
&lt;h4&gt;3.5.4 Q-Learning: Off-Policy TD Control&lt;/h4&gt;
&lt;p&gt;Q Learning 则是直接类似 3.3.4 节的 Value Iteration. 我们可以直接将 improvement 嵌入到更新公式里边, 直接期望 Q function 收敛到最优的 policy $\pi_{\ast}$ 对应的 Q. 算法如下:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/05/05/xVeEJQyK5LqDYtA.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;em&gt;source from refer &lt;a href=&quot;https://youtu.be/o_g9JUMw1Oc?t=924&quot;&gt;DRL Lecture 3: Q-learning&lt;/a&gt;&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;Q-Learning 也被称为 Off-Policy, 是因为我们计算 $r_t + \mathop{\max}\limits_{a}  \ Q^{\pi}(s_{t+1},a)$ 的时候用的是 $\mathop{\max}\limits_{a}  \ Q^{\pi}(s_{t+1},a)$, 而不是 $\pi$ 真正想输出的 action.&lt;/p&gt;
&lt;p&gt;可以这样理解, 存在一个 $\pi_{\ast}$, 使得:&lt;/p&gt;
&lt;p&gt;$$
\pi_{\ast}(s_{t+1}) = \mathop{\arg\max}\limits_{a}  \ Q^{\pi}(s_{t+1},a) \ \text{或者} \  Q^{\pi}(s_{t+1},\pi_{\ast}(s_{t+1})) = \mathop{\max}\limits_{a}  \ Q^{\pi}(s_{t+1},a)
$$&lt;/p&gt;
&lt;p&gt;于是, 每次对下一个 action $s_{t+1}$ 的 Q value 预估的时候, 实际上用的 policy 是 $\pi_{\ast}$, 而基于当前 state $s_t$ 采取 action 的 policy 是 $\pi$. 二者不是同一个, 因此叫 Off-Policy.&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;这个过程就是&quot;培养&quot; $\pi$, 去尽量向着最优的 $\pi_{\ast}$ 给的方向走.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;hr /&gt;
&lt;p&gt;实作上, 由于大家现在都是 neural network 了, 可以直接使用梯度下降去 $\text{minimize}$ 误差:&lt;/p&gt;
&lt;p&gt;$$
\text{minimize} \ Q^{\pi}(s_i,a_i) \leftrightarrow  r_i + \mathop{\arg\max}\limits_{a}  \ Q^{\pi}(s_{i+1},a), where \ a = \pi(s_{i+1})
$$&lt;/p&gt;
&lt;p&gt;算法如下:&lt;/p&gt;
&lt;p&gt;[1]. 初始化 Q-function $Q^{\pi}(s,a)$, target Q-function $\tilde{Q^{\pi}}(s,a)$&lt;/p&gt;
&lt;p&gt;[2]. 然后对每个 state 都采取 action a, where $a = \mathop{\arg\max}\limits_{a}  \ Q^{\pi}(s,a)$&lt;/p&gt;
&lt;p&gt;[3]. 这样就能收集到一批 4 元对 : {$s_t,a_t,r_t,s_{t+1}$} 到 buffer 里边(buffer 里边的数据要及时更换, 把太早的丢掉, 用更新的 $Q^{\pi}$ 产生的数据放进去.).&lt;/p&gt;
&lt;p&gt;[4]. 从 buffer 里边 sample 一笔数据, {$s_i,a_i,r_i,s_{i+1}$}.&lt;/p&gt;
&lt;p&gt;[5]. 由于等式左右两边都在变, 考虑到稳定性, 我们用 target Q-function (fixed) 去替换 $\mathop{\arg\max}\limits_{a}  \ Q^{\pi}(s_{i+1},a)$, 于是优化目标变为:&lt;/p&gt;
&lt;p&gt;$$
minimize \ Q^{\pi}(s_i,a_i) \leftrightarrow  r_i + \mathop{\arg\max}\limits_{a}  \boldsymbol{\tilde{Q^{\pi}}(s_{i+1},a)}
$$&lt;/p&gt;
&lt;p&gt;上边的式子还有一个问题, 就是后边&lt;/p&gt;
&lt;p&gt;$$
\mathop{\arg\max}\limits_{a}  \ \tilde{Q^{\pi}}(s_{i+1},a)
$$&lt;/p&gt;
&lt;p&gt;完全是由 Target Net 来选择高分的 action. &lt;a href=&quot;https://arxiv.org/abs/1509.06461&quot;&gt;Double DQN&lt;/a&gt; 发现, Target Net 总是高估自己的 action 的分数, 于是提出用 2 个 net 相互制衡, 实作很简单, 直接 action 输出使用正在更新的 $\pi$ 即可, 然后打分还是用 $\tilde{Q^{\pi}}$ :&lt;/p&gt;
&lt;p&gt;$$
\mathop{\arg\max}\limits_{a}  \ \tilde{Q^{\pi}}(s_{i+1},Q^{\pi}(s_{i+1},a))
$$&lt;/p&gt;
&lt;p&gt;换个更常见的写法, 优化目标最终变为:&lt;/p&gt;
&lt;p&gt;$$
minimize \ Q^{\pi}(s_i,a_i) \leftrightarrow  r_i + \boldsymbol{\tilde{Q^{\pi}}(s_{i+1},\mathop{\arg\max}\limits_{a} Q^{\pi}(s_{i+1},a))}
$$&lt;/p&gt;
&lt;p&gt;[7]. if step % C = 0, $\tilde{Q^{\pi}} = Q^{\pi}$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;这里需要注意的是, $\pi$ 只有一个, 只是 Q function 有 2 个, 其中 $\tilde{Q^{\pi}}$ 的引入只是为了更新的稳定性(如果不考虑稳定性, 那就是原始算法). 但是无论如何, 目标就是要让 $\pi$ 逼近潜在的最优的 $\pi_{\ast}$. (off-policy)
:::note
:::&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h4&gt;3.5.5 SARSA VS Q-Learning&lt;/h4&gt;
&lt;p&gt;这里有一个例子, 图中 Cliff 区域的奖励是 -100, 其他区域奖励为 -1. 可以看到 Q-Learning 尽管每次 action 的选取用到了 $\epsilon - greedy$, 但是我们做 Q 值预测的时候, 总是选择 $\text{max}$ 的, 这就导致最后 Q-Learning 收敛到 optimize policy. 而 SARSA 得到则是相对次优的:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/05/05/Dhf3iG1WpzymvdF.png&quot; alt=&quot;image.png&quot; width=&quot;300&quot; height=&quot;200&quot; /&amp;gt;&lt;em&gt;source from refer &lt;a href=&quot;https://web.stanford.edu/class/psych209/Readings/SuttonBartoIPRLBook2ndEd.pdf&quot;&gt;Reinforcement Learning: An Introduction&lt;/a&gt;&lt;/em&gt;&lt;/p&gt;
&lt;h4&gt;3.5.6 DP VS TD&lt;/h4&gt;
&lt;p&gt;DP 方法参考 3.3 节. 下边给出 DP 和 TD 方法的区别和联系.&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/05/10/RHEGAv73d5Wru9t.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/05/10/KENkys8WxcmGUVo.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;h4&gt;3.5.7 Forward View of TD-$\lambda$&lt;/h4&gt;
&lt;p&gt;前边我们只是向后观察 1 步:&lt;/p&gt;
&lt;p&gt;$$
Q(s_t,a_t) = Q(s_t,a_t)  + \alpha (r_t + \gamma Q(s_{t+1},a_{t+1}) -  Q(s_t,a_t))
$$&lt;/p&gt;
&lt;p&gt;我们可以向后观察 2 步:&lt;/p&gt;
&lt;p&gt;$$
Q(s_t,a_t) = Q(s_t,a_t)  + \alpha (r_t + r_{t+1} + \gamma ^ 2 Q(s_{t+2},a_{t+2}) -  Q(s_t,a_t))
$$&lt;/p&gt;
&lt;p&gt;可以向后观察 k 步:&lt;/p&gt;
&lt;p&gt;$$
Q(s_t,a_t) = Q(s_t,a_t)  + \alpha (r_t + r_{t+1} + ... + r_{t+k-1}  + \gamma ^ k Q(s_{t+k},a_{t+k}) -  Q(s_t,a_t))
$$&lt;/p&gt;
&lt;p&gt;上式全部为 $Q(s_t,a_t)$, 我们可以使用加权取平均对所有结果, 记:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
G_t^1 &amp;amp;= r_t + \gamma Q(s_{t+1},a_{t+1})
\newline
G_t^2 &amp;amp;= r_t + r_{t+1} + \gamma ^ 2 Q(s_{t+2},a_{t+2})
\newline
...
\newline
G_t^k &amp;amp;= r_t + r_{t+1} + ... + r_{t+k-1}  + \gamma ^ k Q(s_{t+k},a_{t+k})
\newline
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;$\lambda$ 加权平均:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
G &amp;amp;= \sum^{k \rightarrow \infty} \frac {1}{1 + \lambda + ... + \lambda^{k-1} } (G_t^1  + \lambda G_t^2 + \lambda^2 G_t^3 + ... + \lambda^{k-1} G_t^k)
\newline
&amp;amp;= (1 - \lambda ) \sum^{k \rightarrow \infty} (G_t^1  + \lambda G_t^2 + \lambda^2 G_t^3 + ... + \lambda^{k-1} G_t^k)
\newline
&amp;amp;= (1 - \lambda ) \sum^{k \rightarrow \infty} \lambda^{k-1} G_t^k
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;其中, $\frac {1}{1 + \lambda + ... + \lambda^{k-1} }$ 使得权重求和为 1.&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/05/05/yB2vdxL4s9eqQKD.png&quot; alt=&quot;image.png&quot; width=&quot;300&quot; height=&quot;200&quot; /&amp;gt;&lt;em&gt;source from refer &lt;a href=&quot;https://web.stanford.edu/class/psych209/Readings/SuttonBartoIPRLBook2ndEd.pdf&quot;&gt;Reinforcement Learning: An Introduction&lt;/a&gt;&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;上述版本称为 Forward View of TD(λ). 原因是我们站在当前 time step 向后观察每个 state 的情况, 越往后的 state, 所分配的更新权重越小(对当前的 state $s_t$ 影响越小):&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/05/07/XJTtUIEQxjAO2mR.png&quot; alt=&quot;image.png&quot; width=&quot;300&quot; height=&quot;200&quot; /&amp;gt;&lt;em&gt;source from refer &lt;a href=&quot;https://web.stanford.edu/class/psych209/Readings/SuttonBartoIPRLBook2ndEd.pdf&quot;&gt;Reinforcement Learning: An Introduction&lt;/a&gt;&lt;/em&gt;&lt;/p&gt;
&lt;h4&gt;3.5.8 Backward View of TD-$\lambda$&lt;/h4&gt;
&lt;p&gt;这里需要引入 eligibility trace. 其定义公式如下:&lt;/p&gt;
&lt;p&gt;$$
E(S) \leftarrow \gamma \lambda E(S) + \mathbb{1}(S = s)
$$&lt;/p&gt;
&lt;p&gt;上述公式的逻辑是这样的. 假设有一个特定的 state s. 如果当前这轮更新 Q 的时候, 遇到的 state 就是 s, 那么&lt;/p&gt;
&lt;p&gt;$$
E(s) \leftarrow \gamma \lambda E(s) + 1
$$&lt;/p&gt;
&lt;p&gt;否则, 就是当前 state 是其他的 $\text{s&apos;}$, 那么&lt;/p&gt;
&lt;p&gt;$$
E(s) \leftarrow \gamma \lambda E(s)
$$&lt;/p&gt;
&lt;p&gt;如果一个 state s 经常多次出现, 那么属于这个 s 的 $E(s)$ 就会比较大, 反之就会由于 $\gamma \lambda$ 的存在衰减到 0.&lt;/p&gt;
&lt;p&gt;此外 Dutch traces, 当 visit 一个 state 时, 会在之前的基础上先做一个衰减, 然后再加 1. 还有一种是 replacing trace, 当 visit 一个 state 时, 直接会把 traces 置为 1 :&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/05/09/3fkGFpXTNZbI16y.png&quot; alt=&quot;image.png&quot; width=&quot;300&quot; height=&quot;200&quot; /&amp;gt;&lt;em&gt;source from refer &lt;a href=&quot;https://web.stanford.edu/class/psych209/Readings/SuttonBartoIPRLBook2ndEd.pdf&quot;&gt;Reinforcement Learning: An Introduction&lt;/a&gt;&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;算法过程如下:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/05/09/cgxYrhOWDmblPs7.png&quot; alt=&quot;image.png&quot; width=&quot;300&quot; height=&quot;200&quot; /&amp;gt;&lt;em&gt;source from refer &lt;a href=&quot;https://web.stanford.edu/class/psych209/Readings/SuttonBartoIPRLBook2ndEd.pdf&quot;&gt;Reinforcement Learning: An Introduction&lt;/a&gt;&lt;/em&gt;&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;需要注意的是, 虽然当前遇到的 state 是 S, 内部的 for 循环在更新 $V(s)$ 的时候, 都使用同一个 $\delta \leftarrow R + \gamma V(S&apos;) - V(S) $ 对所有的 state 更新.
:::note
:::
这个被称为 Backward View of TD-$\lambda$, 是因为我们对每个 sate s 更新的时候, 是基于当前 state S 的 TD-error, 只不过同时还基于 state s 对应的 $E(s)$. 如果 state s 距当前 state S 很远(表现为出现次数很少, 因为我们按时序 visit state), 那么其 $E(s)$ 就会很小, 最后分配的权重就会很小:&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/05/09/ePOLlfxAQijUC42.png&quot; alt=&quot;image.png&quot; width=&quot;300&quot; height=&quot;200&quot; /&amp;gt;&lt;em&gt;source from refer &lt;a href=&quot;https://web.stanford.edu/class/psych209/Readings/SuttonBartoIPRLBook2ndEd.pdf&quot;&gt;Reinforcement Learning: An Introduction&lt;/a&gt;&lt;/em&gt;&lt;/p&gt;
&lt;h4&gt;3.5.9 Equivalences of TD-$\lambda$&lt;/h4&gt;
&lt;ul&gt;
&lt;li&gt;$\lambda = 0$&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;forward view : 这个按定义的计算方式, 易知此时就是 TD(0).&lt;/p&gt;
&lt;p&gt;backward view : 当 $\lambda = 0$ 时, 只有当前 state S 对应的 $E(S) = 1$ , 其他的永远恒为 0. 此时就是 TD(0).&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;对于 Online(实时更新) 和 Offline(重复收集多个 episode 的数据, 最后进行 batch update) 的学习方式, 由于 $\lambda = 0$ 时, TD(0) 的 reward 计算方式 : $G_t = r_t + \gamma V(s_{t+1})$, 此时的 TD-error 就只关注相邻的 2 个 state, 因此最终结果是一样的.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;ul&gt;
&lt;li&gt;$\lambda = 1$&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;forward view 如下:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
G &amp;amp;= \sum^{k \rightarrow \infty} \frac {1}{1 + \lambda + ... + \lambda^{k-1} } (G_t^1 + \lambda t^2 + \lambda^2 G_t^3 + ... + \lambda^{k-1} G_t^k)
\newline
&amp;amp;\approx  \frac {1}{N } \sum (G_t^1 +  G_t^2 +  G_t^3 + ... + G_t^N) \ \ (MC 方法)
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;backward view 如下:&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/05/10/ZBuCeAhLFRGHj1z.png&quot; alt=&quot;image.png&quot; /&gt;
&lt;img src=&quot;https://s2.loli.net/2024/05/10/W68YFtuprE15RXL.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;p&gt;可以看到, 如果看直到 episode 结束的累计错误, 最后再进行 batch update, backward view 的 TD(1) 就是 MC error. 因此, $\lambda = 1 { &amp;amp; } \ \text{Offinline}$ 时, forward view 和 backward view 是等价的, 均为 MC 方法.&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;但是, 当进行 Online 更新时, forward view 的 TD(1) 仍然是 MC 方法. 但是 backward view 的 TD(1) 不是了.
:::note
:::&lt;/p&gt;
&lt;/blockquote&gt;
&lt;ul&gt;
&lt;li&gt;general $\lambda$&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/05/10/AM3xHzmPfY8EK4h.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/05/10/vTABbxalfpOFM7L.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;p&gt;当 $\lambda \in (0,1)$ 结论同理. 同时上述推导过程还显式的证明了在 Offline 的更新方式下, forward view 和 back view 是等价的. 最终给出以下表格:&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;另外 backward view 和 forward view 的等价证明也可以参考:http://www.incompleteideas.net/book/ebook/node76.html&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/05/10/NpAGwybmcLeRvH1.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;h2&gt;Reference&lt;/h2&gt;
&lt;ol&gt;
&lt;li&gt;&lt;a href=&quot;https://www.davidsilver.uk/teaching/&quot;&gt;UCL Course on RL&lt;/a&gt;&lt;/li&gt;
&lt;/ol&gt;
</content:encoded></item><item><title>Causal Inference Series (I)</title><link>https://xuchenhui.cc/posts/2024-04-26-introduction-to-causal-inference-i/</link><guid isPermaLink="true">https://xuchenhui.cc/posts/2024-04-26-introduction-to-causal-inference-i/</guid><description>因果推断入门系列第一篇，基于 Brady Neal 的课程，从辛普森悖论出发揭示相关性与因果性的本质区别。</description><pubDate>Fri, 26 Apr 2024 00:00:00 GMT</pubDate><content:encoded>&lt;h2&gt;0. 前言&lt;/h2&gt;
&lt;p&gt;Causal Inference (因果推断) 已经在多个领域发挥出巨大作用, 尽管早已经听说过其大名, 但是从未步入这个领域好好学习一番, 通常是浅尝辄止. 为此在博客开一个系列, 一是用于记录学习, 二是希望能够起到监督作用...&lt;/p&gt;
&lt;p&gt;由于是入门学习, 因此课程和书籍选择了相对简单的. 根据网上的推荐和实际体验, 感觉 Brady Neal 的系列介绍比较合适, 因此这个系列都将会以 Brady Neal 的课程为基础. 课程链接: &lt;a href=&quot;https://www.bradyneal.com/causal-inference-course&quot;&gt;https://www.bradyneal.com/causal-inference-course&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;:::note
由于这是个人学习笔记, 我作为初学者, 在博客中记录的内容和理解难免会有错误. 希望各位能够指正, 并请不吝赐教, 在下将不胜感激.
:::&lt;/p&gt;
&lt;h2&gt;1. 第一章 Motivation: Why You Might Care&lt;/h2&gt;
&lt;p&gt;第一章主要介绍辛普森悖论, 以及向我们初步展示相关性和因果性的联系与区别.&lt;/p&gt;
&lt;h3&gt;1.1 Simpson’s Paradox&lt;/h3&gt;
&lt;p&gt;通常因果推断的第一课都是 Simpson’s Paradox (辛普森悖论) . 它说了这么一件事 : 假设现在有种病, 我们有 2 个治疗方案, treatment A and treatment B. 在做实验的时候, treatment B  比较稀缺, 只有较少的志愿者可以用上 B, 比如 treatment A and treatment B 的志愿者分别为 73% 和 27% . 现在得到这么一组数据 :&lt;/p&gt;
&lt;p&gt;表中, 百分比指的是接受相应的 treatment 后志愿者死亡率. Mild组 表示病的不重 , Severe组 表示病的比较严重.&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/26/OvNDhwKrz68ajnu.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;从上表可以看到, 无论是哪个分组, 明显 treatment B 死亡率更低. 但是有趣的是, 当你纵观所有人, 即 Total 列反而是 treatment A 死亡率更低. 那么到底哪个 treatment 更好呢?&lt;/p&gt;
&lt;p&gt;上表有个关键的问题, 总共 550 个人 接受了 treatment B, 但是有 500 个是重病患者. 因此计算最终的死亡率时候, 重病死亡率的权重更大, 导致对于 treatment B 的 Total 死亡率接近 20 %. 同理, 对于 treatment A, 轻症患者更多, 所以最后的平均死亡率反而比较低. 所以, 到底哪个更好?? 实际上, 这个答案是基于因果关系的.&lt;/p&gt;
&lt;p&gt;如果受试者的 Condition 影响 treatment. 举个例子, 医生会根据患病情况来给出 treatment, 如果患病情况比较轻, 那么通常会安排 treatment A. 反之, 病重的会安排 treatment B. 他们之间的关系如下:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/26/I7lOEvoz3VsPmgQ.png&quot; alt=&quot;image.png&quot; width=&quot;300&quot; height=&quot;200&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;那这时, 这需要看不同患病情况下, treatment 的治愈率, 显然这种情况下 treatment B 更好.&lt;/p&gt;
&lt;p&gt;如果受试者的 treatment 影响 Condition. 举个例子, 比如 treatment B 比较牛逼但是稀缺, 本来患病了就直接用药即可, 但是由于人们非要等着 treatment B 导致病情恶化. 当然对于 treatment A 是没有这个问题的. 那么这是他们的关系就是:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/26/oEmGarx15iw4TNM.png&quot; alt=&quot;image.png&quot; width=&quot;300&quot; height=&quot;200&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;此时, 显然尽管 treatment B 药效好, 但是为了存活率, 我们应该选择 treatment A. 总的来说, 当我们有了因果关系之后, 就可以解决 Simpson’s Paradox 了.&lt;/p&gt;
&lt;h4&gt;1.2 Correlation Does Not Imply Causation&lt;/h4&gt;
&lt;p&gt;这是一个很关键的思想: $相关性 \neq 因果性$. 有个&quot;Nicolas Cage and Pool Drownings&quot;的例子, 说的是演员尼古拉斯凯奇和发生游泳溺水次数的相关性.&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/04/26/q9rXRWn3k5ZhlKa.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;p&gt;有人发现, 你用这个哥们儿出演电影次数和有人游泳溺水次数算线性相关性, 结果可能显示高度相关, 这明显是很离谱的事情. 显然他们并没有什么因果性.&lt;/p&gt;
&lt;p&gt;再看一个例子, 人们发现一个事情, 那些晚上很晚回来并且穿着鞋子睡觉的人, 第二天早上醒来会头痛. 事实确实发生了, 人们会说他们是相关的. 但是实际隐藏了一个条件, 这些晚上穿着鞋子睡觉的人, 大概率是喝酒喝醉了回来倒头就睡, 第二天头疼也八成是因为喝酒喝的. 我们称背后隐藏的这个条件为 &quot;&lt;strong&gt;confounder&lt;/strong&gt;&quot;.&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/26/eLzqBX6R4pUKVNc.png&quot; alt=&quot;image.png&quot; width=&quot;300&quot; height=&quot;200&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;我们称 confounder 与 研究对象 的关联为 &quot;&lt;strong&gt;confounding association&lt;/strong&gt;&quot; . 如果想单纯探究 &quot;穿鞋睡觉 -&amp;gt; 第二天头疼&quot; 的因果关系, 我们就必须先断掉 confounder 的影响.&lt;/p&gt;
&lt;h2&gt;2. 第二章 Potential Outcomes&lt;/h2&gt;
&lt;p&gt;第二章主要介绍基础概念.&lt;/p&gt;
&lt;h3&gt;2.1 Potential Outcomes and Individual Treatment Effects&lt;/h3&gt;
&lt;p&gt;考虑一个例子：假设有个人心情不太好，有人想送给他一只狗.如果他接受了这只狗，那他可能会变得开心.但如果他拒绝了呢？他会不会继续感到不高兴呢？反过来想，如果他接受了这只狗，但他仍然感到不高兴，那么我怎么知道，如果不送给他，他是否会变得更开心呢？&lt;/p&gt;
&lt;h4&gt;2.1.1  Potential Outcomes&lt;/h4&gt;
&lt;p&gt;根据前面的分析，实际上针对某个人采取不同的处理方式会产生不同的结果，这是一个潜在的结论.我们在之后的讨论中, 称这个潜在的输出为 $\mathit {Y}$ .&lt;/p&gt;
&lt;p&gt;在上边的例子中, $\mathit {Y} = 1$ 表示高兴, $\mathit {Y} = 0$ 为不高兴.&lt;/p&gt;
&lt;p&gt;用 $\mathit {T}$ 表示 treatment 这个随机变量. $\mathit {T} = 1$ 表示接受狗子, $\mathit {T} = 0$ 表示不接受.&lt;/p&gt;
&lt;p&gt;使用 $\mathit {Y(1)}$ 表示接受狗子以后的潜在输出, $\mathit {Y(0)}$ 为采取不接受狗子后的输出.&lt;/p&gt;
&lt;h4&gt;2.1.2  Individual Treatment Effects&lt;/h4&gt;
&lt;p&gt;因为有很多人, 我们使用 $\tau_i \triangleq Y_i(1) - Y_i(0)$ 来评估某个个体采取 treatment 之后的潜在输出结果.&lt;/p&gt;
&lt;p&gt;:::note
你可以观察接受狗子之后, 观察 $\mathit {Y(1)}$. 反之, 你可以不接受狗子来观察  $\mathit {Y(0)}$. 但是你不能同时观察到 $\mathit {Y(1)}$ 和 $\mathit {Y(0)}$ !!!! 这个问题就是 &quot;&lt;strong&gt;Fundamental Problem of Causal Inference&lt;/strong&gt;&quot;
:::&lt;/p&gt;
&lt;h3&gt;2.2  Average Treatment Effects&lt;/h3&gt;
&lt;p&gt;因为每个人可能有些许差异, 实际要想客观的评估 treatment 的作用, 我们要对所有人求 treatment 期望:&lt;/p&gt;
&lt;p&gt;$$
\tau \triangleq \mathbb{E}[Y_i(1) - Y_i(0)] = \mathbb{E}[Y(1) - Y(0)]
$$&lt;/p&gt;
&lt;p&gt;但是上式由于 Fundamental Problem of Causal Inference, 实际上比较难做到计算. 参看下表:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/27/9vpO8bP5Uq2dYtS.png&quot; alt=&quot;image.png&quot; width=&quot;300&quot; height=&quot;200&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;当对个体 $i$ 采取了 treatment 0 的时候, 你只能观察到 $Y_i(0)$, 观察不到 $Y_i(1)$. 也就是说以下式子不成立:&lt;/p&gt;
&lt;p&gt;$$
\mathbb{E}[Y(1) - Y(0)] = \mathbb{E}[Y(1)] - \mathbb{E}[Y(0)] \neq \mathbb{E}[Y(1) | T = 1 ] - \mathbb{E}[Y(0) | T = 0]
$$&lt;/p&gt;
&lt;p&gt;可以看到, treatment 0 对应的集合只是一部分, 不能作为全部的结果, $\mathbb{E}[Y(1)]$ 理应是最右边的结果(Intervening).&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/27/vtybfejVDn4QNKA.png&quot; alt=&quot;image.png&quot; width=&quot;600&quot; height=&quot;400&quot; /&amp;gt;&lt;/p&gt;
&lt;h4&gt;2.2.1 Ignorability and Exchangeability&lt;/h4&gt;
&lt;p&gt;那么什么时候, 或者基于什么假设, 上式能够成立呢?&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;Assumption 2.1 Ignorability / Exchangeability&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;:::note&lt;/p&gt;
&lt;p&gt;$$
(Y(1) , Y(0)) \amalg T
$$&lt;/p&gt;
&lt;p&gt;:::&lt;/p&gt;
&lt;p&gt;当假设满足 Ignorability 的时候, 能够做到以下式子成立. 这里 Ignorability 指的是, 可以忽视缺失的数据.&lt;/p&gt;
&lt;p&gt;$$
\begin{aligned}
\mathbb{E}[Y(1) - Y(0)] &amp;amp;= \mathbb{E}[Y(1)] - \mathbb{E}[Y(0)]
\newline
&amp;amp;=\mathbb{E}[Y(1) \mid T = 1 ] - \mathbb{E}[Y(0) \mid T = 0] \ (Ignorability)
\newline
&amp;amp;=\mathbb{E}[Y \mid T = 1 ] - \mathbb{E}[Y \mid T = 0] \ (之后讨论)
\end{aligned}
$$&lt;/p&gt;
&lt;p&gt;上式表明 &lt;code&gt;Y(1)&lt;/code&gt; 就只基于 &lt;code&gt;T = 1&lt;/code&gt; , 不受其他影响 , 即没有 confounder 的影响了. 如图:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/27/crp69WgbIqMOvX5.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;这个假设也叫 &lt;code&gt;Exchangeability&lt;/code&gt;, 表示说 $\mathbb{E}[Y(0) \mid T = 0] = \mathbb{E}[Y(0) \mid T = 1] = \mathbb{E}[Y(0) \mid t ]$ , 这其实就是说对于 group A 或者 group B, 把他们交换 treatment group 和 control group, 输出的结果只与 treatment 有关, 和 group A 或者 group B 没有关系 (尤其是 confounder). 也暗示着除了 treatment 的方式有区别, 不受其他影响.&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;Definition 2.1 Identifiability&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;:::note
causal quantity (e.g. $\mathbb{E}[Y(t)]$) is Identifiable if we can compute it from a purely statistical quantity (e.g. $\mathbb{E}[Y \mid t]$)
:::&lt;/p&gt;
&lt;p&gt;这个 Identifiability 是说, 我们可以用 $\mathbb{E}[Y \mid t]$ 代替 $\mathbb{E}[Y(t)]$.&lt;/p&gt;
&lt;h4&gt;2.2.2 Conditional Exchangeability and Unconfoundedness&lt;/h4&gt;
&lt;p&gt;实际中, 我们直接假设 group A 或者 group B 除了 treatment 的方式有区别, 不受其他影响. 但是这个不太现实, 明显是不合理的. 但是我们考虑, 如果可以控制一些条件, 让他们除了 treatment 方式有区别, 其他没有区别.&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;Assumption 2.2 Conditional Exchangeability / Unconfoundedness&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;:::note&lt;/p&gt;
&lt;p&gt;$$
(Y(1) , Y(0)) \amalg T \mid X
$$&lt;/p&gt;
&lt;p&gt;:::&lt;/p&gt;
&lt;p&gt;当假设满足 Conditional Exchangeability 的时候, 换句话说, 我们控制了 confounder X, 使得 group 基于同样的 confounder , 那这时去做 treatment, 就实现了 treatment 直接作用于 outcome, 不会因为潜在的 confounder 影响 outcome. 如图所示:&lt;/p&gt;
&lt;p&gt;&amp;lt;div style=&quot;display: flex;&quot;&amp;gt;
&amp;lt;img src=&quot;https://s2.loli.net/2024/04/27/HcwmUCh7oAPdjir.png&quot; alt=&quot;Image 1&quot; style=&quot;width: 100%;&quot;&amp;gt;
&amp;lt;img src=&quot;https://s2.loli.net/2024/04/27/2Gt8Ugf9d6WSomB.png&quot; alt=&quot;Image 2&quot; style=&quot;width: 100%;&quot;&amp;gt;&lt;/p&gt;
&lt;p&gt;于是有以下公式成立:&lt;/p&gt;
&lt;p&gt;$$
\begin{aligned}
\mathbb{E}[Y(1) - Y(0) \mid X] &amp;amp;= \mathbb{E}[Y(1) \mid X] - \mathbb{E}[Y(0) \mid X]
\newline
&amp;amp;=\mathbb{E}[Y(1) \mid T = 1, X] - \mathbb{E}[Y(0) \mid T = 0, X] \ (Ignorability)
\newline
&amp;amp;=\mathbb{E}[Y \mid T = 1, X ] - \mathbb{E}[Y \mid T = 0, X] \ (fix confounder)
\end{aligned}
$$&lt;/p&gt;
&lt;p&gt;则:&lt;/p&gt;
&lt;p&gt;$$
\begin{aligned}
\mathbb{E}[Y(1) - Y(0) ] &amp;amp;= \mathbb{E}_X[\mathbb{E}[Y(1) \mid X] - \mathbb{E}[Y(0) \mid X]]
\newline
&amp;amp;=\mathbb{E}_X[\mathbb{E}[Y(1) \mid T = 1, X] - \mathbb{E}[Y(0) \mid T = 0, X]] \ (Ignorability)
\newline
&amp;amp;=\mathbb{E}_X[\mathbb{E}[Y \mid T = 1, X ] - \mathbb{E}[Y \mid T = 0, X] ]\ (expect \ confounder)
\end{aligned}
$$&lt;/p&gt;
&lt;p&gt;Conditional exchangeability (Assumption 2.2) is a core assumption for
causal inference and goes by many names. For example, the following are reasonably commonly used to &lt;strong&gt;refer to the same assumption: unconfoundedness, conditional ignorability, no unobserved confounding,
selection on observables, no omitted variable bias&lt;/strong&gt;, etc.&lt;/p&gt;
&lt;p&gt;$\textit{We will use the name “unconfoundedness” a fair amount throughout this book.}$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;Theorem 2.1  (Adjustment Formula) Given the assumptions of &lt;strong&gt;unconfoundedness&lt;/strong&gt;, &lt;strong&gt;positivity&lt;/strong&gt;, &lt;strong&gt;consistency&lt;/strong&gt;, and &lt;strong&gt;no interference&lt;/strong&gt;, we can identify the
average treatment effect:&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;:::note&lt;/p&gt;
&lt;p&gt;$$
\mathbb{E}[Y(1) - Y(0) ] = \mathbb{E}_X[\mathbb{E}[Y \mid T = 1, X ] - \mathbb{E}[Y \mid T = 0, X] ]
$$&lt;/p&gt;
&lt;p&gt;:::&lt;/p&gt;
&lt;p&gt;不过上述的式子还是有缺陷, 我们只是理想的假设 fixed confounder 是全部的, 但很多 confounder 都是潜在未知的, 我们实际不能保证 fix 住的 confounder 就是全部的, 这就会导致还是会有从 treatment -&amp;gt; confounder -&amp;gt; outcome 这条链路的影响存在.&lt;/p&gt;
&lt;h2&gt;Reference&lt;/h2&gt;
</content:encoded></item><item><title>Lang Chain Series</title><link>https://xuchenhui.cc/posts/2024-04-24-lang-chain-series/</link><guid isPermaLink="true">https://xuchenhui.cc/posts/2024-04-24-lang-chain-series/</guid><description>对 LangChain 框架的核心组件（Prompts、Models、Indexes 等）进行学习记录与解读，加入个人理解与代码注释，方便快速上手。</description><pubDate>Wed, 24 Apr 2024 00:00:00 GMT</pubDate><content:encoded>&lt;h2&gt;0. 前言&lt;/h2&gt;
&lt;p&gt;本篇 Blog 不是教程, 官方教程&lt;a href=&quot;https://python.langchain.com/docs/modules/&quot;&gt;1&lt;/a&gt;、&lt;a href=&quot;https://python.langchain.com/docs/get_started/introduction&quot;&gt;2&lt;/a&gt; 有时候比较抽象, 这里只是在学习 LangChain 过程中做个记录, 并加入自己的理解和注释. 当然这里也不进行过多的介绍, 直接进入对组件的学习.&lt;/p&gt;
&lt;h2&gt;1. Prompts&lt;/h2&gt;
&lt;p&gt;Prompts 通常用来&quot;调教&quot;LLM, 比如用来指定 LLM 的输出格式, 也可以用来给 LMM 一些例子让他参考等等.LangChain 目前提供了 4 种 Prompts template , 方便用户构造 Prompt.&lt;/p&gt;
&lt;h3&gt;1.1 PromptTemplate&lt;/h3&gt;
&lt;p&gt;PromptTemplate 是最简单, 最基本的一种 Template. API Reference:&lt;a href=&quot;https://api.python.langchain.com/en/latest/prompts/langchain_core.prompts.prompt.PromptTemplate.html&quot;&gt;PromptTemplate&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;官方给了 2 种方法来用 PromptTemplate.&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;方法一(推荐)&lt;/li&gt;
&lt;/ul&gt;
&lt;pre&gt;&lt;code&gt;from langchain_core.prompts import PromptTemplate

# Instantiation using from_template (recommended)
prompt = PromptTemplate.from_template(&quot;Say {foo}&quot;)
prompt.format(foo=&quot;bar&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;结果 : &lt;code&gt;Say bar&lt;/code&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;方法二&lt;/li&gt;
&lt;/ul&gt;
&lt;pre&gt;&lt;code&gt;from langchain_core.prompts import PromptTemplate

# Instantiation using initializer
prompt = PromptTemplate(input_variables=[&quot;foo&quot;],  template=&quot;Say {foo}&quot;)
prompt.format(foo=&quot;bar&quot;)

&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;结果 : &lt;code&gt;Say bar&lt;/code&gt;&lt;/p&gt;
&lt;h3&gt;1.2 ChatPromptTemplate&lt;/h3&gt;
&lt;p&gt;ChatPromptTemplate 通常有 3 种规则: &quot;system&quot;, &quot;ai&quot; and &quot;human&quot;. API Reference:&lt;a href=&quot;https://api.python.langchain.com/en/latest/prompts/langchain_core.prompts.chat.ChatPromptTemplate.html&quot;&gt;ChatPromptTemplate&lt;/a&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;
from langchain_core.prompts import ChatPromptTemplate

chat_template = ChatPromptTemplate.from_messages(
    [
        (&quot;system&quot;,  &quot;You are a helpful AI bot. Your name is {name}.&quot;),
        (&quot;human&quot;,  &quot;Hello,  how are you doing?&quot;),
        (&quot;ai&quot;,  &quot;I&apos;m doing well,  thanks!&quot;),
        (&quot;human&quot;,  &quot;{user_input}&quot;),
    ]
)

messages = chat_template.format_messages(name=&quot;Bob&quot;,  user_input=&quot;What is your name?&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;
[SystemMessage(content=&apos;You are a helpful AI bot. Your name is Bob.&apos;),
 HumanMessage(content=&apos;Hello,  how are you doing?&apos;),
 AIMessage(content=&quot;I&apos;m doing well,  thanks!&quot;),
 HumanMessage(content=&apos;What is your name?&apos;)]
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;可以看到, 实际是产生了 &lt;code&gt;SystemMessage&lt;/code&gt;, &lt;code&gt;HumanMessage&lt;/code&gt; and &lt;code&gt;AIMessage&lt;/code&gt; 共 3 种 Message. 与之对应的, 在构造这些 Message 时, 可以使用相应的 Template : &lt;code&gt;AIMessagePromptTemplate&lt;/code&gt;, &lt;code&gt;SystemMessagePromptTemplate&lt;/code&gt; and &lt;code&gt;HumanMessagePromptTemplate&lt;/code&gt;&lt;/p&gt;
&lt;p&gt;例子:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;from langchain_core.prompts import HumanMessagePromptTemplate, SystemMessagePromptTemplate, AIMessagePromptTemplate

chat_template = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate.from_template(&quot;你是一个{llm_type}&quot;),  # 使用 template构造 系统的 message
        (&apos;ai&apos;, &quot;很高兴帮助您.&quot;),  # 直接使用 role 构造 AI的 message
        HumanMessagePromptTemplate.from_template(&quot;{text}&quot;),
    ]
)
messages = chat_template.format_messages(llm_type=&quot;AI助手&quot;, text = &quot;1 + 1 = ?&quot;)
print(messages)

&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;[SystemMessage(content=&apos;你是一个AI助手&apos;),
AIMessage(content=&apos;很高兴帮助您.&apos;),
 HumanMessage(content=&apos;1 + 1 = ?&apos;)]
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;1.3 Example selectors&lt;/h3&gt;
&lt;p&gt;假设有一些例子, 我们希望 LLM 能够根据输入挑一些合适的例子出来, 方便我们后续的操作, 比如将他们放到一个 prompt 中. &lt;a href=&quot;https://python.langchain.com/docs/modules/model_io/prompts/example_selectors/&quot;&gt;API Reference&lt;/a&gt;&lt;/p&gt;
&lt;h4&gt;Select by length&lt;/h4&gt;
&lt;p&gt;简单来说, 这个 selector 是根据用户输入的语句长短选择合适的 example, 你输入的越短,他给的例子越多, 你输出的越长,他给的例子越少.&lt;/p&gt;
&lt;p&gt;例子:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;from langchain_core.example_selectors import LengthBasedExampleSelector
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate

# Examples of a pretend task of creating antonyms.
examples = [
    {&quot;input&quot;: &quot;happy&quot;, &quot;output&quot;: &quot;sad&quot;},
    {&quot;input&quot;: &quot;tall&quot;, &quot;output&quot;: &quot;short&quot;},
    {&quot;input&quot;: &quot;energetic&quot;, &quot;output&quot;: &quot;lethargic&quot;},
    {&quot;input&quot;: &quot;sunny&quot;, &quot;output&quot;: &quot;gloomy&quot;},
    {&quot;input&quot;: &quot;windy&quot;, &quot;output&quot;: &quot;calm&quot;},
]

example_prompt = PromptTemplate(
    input_variables=[&quot;input&quot;, &quot;output&quot;],
    template=&quot;Input: {input}\nOutput: {output}&quot;,
)
example_selector = LengthBasedExampleSelector(
    # The examples it has available to choose from.
    examples=examples,

    # The PromptTemplate being used to format the examples.
    # 这个参数有些 selector 不需要, 有些是必须的, 请参考函数具体API
    example_prompt=example_prompt,

    # The maximum length that the formatted examples should be.
    # Length is measured by the get_text_length function below.
    max_length=25,
    # The function used to get the length of a string, which is used
    # to determine which examples to include. It is commented out because
    # it is provided as a default value if none is specified.
    # get_text_length: Callable[[str], int] = lambda x: len(re.split(&quot;\n| &quot;, x))
)
dynamic_prompt = FewShotPromptTemplate(
    # We provide an ExampleSelector instead of examples.
    example_selector=example_selector,
    example_prompt=example_prompt,
    prefix=&quot;Give the antonym of every input&quot;,
    suffix=&quot;Input: {adjective}\nOutput:&quot;,
    input_variables=[&quot;adjective&quot;],
)

&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输入:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;# An example with small input, so it selects all examples.
print(dynamic_prompt.format(adjective=&quot;big&quot;))
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;Give the antonym of every input

Input: happy
Output: sad

Input: tall
Output: short

Input: energetic
Output: lethargic

Input: sunny
Output: gloomy

Input: windy
Output: calm

Input: big
Output:
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输入:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;# An example with long input, so it selects only one example.
long_string = &quot;big and huge and massive and large and gigantic and tall and much much much much much bigger than everything else&quot;
print(dynamic_prompt.format(adjective=long_string))
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;# 可以看到只给了一个example
Give the antonym of every input

Input: happy
Output: sad

Input: big and huge and massive and large and gigantic and tall and much much much much much bigger than everything else
Output:
&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;Select by maximal marginal relevance (MMR)&lt;/h4&gt;
&lt;p&gt;这个 selector 的思想是, 选择与当前输入 p 相似的(cosine similarity) example $q_j$ , 但是这个 $q_j$ 还要尽量与 example pool 中 $q_i$ 不要太相似, 这是为了多样性. &lt;a href=&quot;https://arxiv.org/pdf/2211.13892.pdf&quot;&gt;原始 paper&lt;/a&gt; 中的 next example 选择公式为:&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/04/25/2bHXmLKFRyAkxl5.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;p&gt;可以看到, 如果下一个 $q_j$ 和 当前的 $p$ 很相似, 但是和其他的 $q_i$ 也非常相似, 那么这个分数也不会太高.&lt;/p&gt;
&lt;p&gt;例子&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;from langchain_community.vectorstores import FAISS
from langchain_core.example_selectors import (
    MaxMarginalRelevanceExampleSelector,
    SemanticSimilarityExampleSelector,
)
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate
from langchain_openai import OpenAIEmbeddings

example_prompt = PromptTemplate(
    input_variables=[&quot;input&quot;, &quot;output&quot;],
    template=&quot;Input: {input}\nOutput: {output}&quot;,
)

# Examples of a pretend task of creating antonyms.
examples = [
    {&quot;input&quot;: &quot;happy&quot;, &quot;output&quot;: &quot;sad&quot;},
    {&quot;input&quot;: &quot;tall&quot;, &quot;output&quot;: &quot;short&quot;},
    {&quot;input&quot;: &quot;energetic&quot;, &quot;output&quot;: &quot;lethargic&quot;},
    {&quot;input&quot;: &quot;sunny&quot;, &quot;output&quot;: &quot;gloomy&quot;},
    {&quot;input&quot;: &quot;windy&quot;, &quot;output&quot;: &quot;calm&quot;},
]
example_selector = MaxMarginalRelevanceExampleSelector.from_examples(
    # The list of examples available to select from.
    examples,
    # The embedding class used to produce embeddings which are used to measure semantic similarity.
    OpenAIEmbeddings(),
    # The VectorStore class that is used to store the embeddings and do a similarity search over.
    FAISS,
    # The number of examples to produce.
    k=2,
)
mmr_prompt = FewShotPromptTemplate(
    # We provide an ExampleSelector instead of examples.
    example_selector=example_selector,
    example_prompt=example_prompt,
    prefix=&quot;Give the antonym of every input&quot;,
    suffix=&quot;Input: {adjective}\nOutput:&quot;,
    input_variables=[&quot;adjective&quot;],
)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输入:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;# Input is a feeling, so should select the happy/sad example as the first one
print(mmr_prompt.format(adjective=&quot;worried&quot;))
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;Give the antonym of every input

Input: happy
Output: sad

Input: windy
Output: calm

Input: worried
Output:
&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;Select by similarity&lt;/h4&gt;
&lt;p&gt;这个就很单纯了, 直接使用 cos similarity 来选择最佳的 example, 典型 selector 是 SemanticSimilarityExampleSelector.&lt;/p&gt;
&lt;p&gt;例子&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;from langchain_chroma import Chroma
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate
from langchain_openai import OpenAIEmbeddings

example_prompt = PromptTemplate(
    input_variables=[&quot;input&quot;, &quot;output&quot;],
    template=&quot;Input: {input}\nOutput: {output}&quot;,
)

# Examples of a pretend task of creating antonyms.
examples = [
    {&quot;input&quot;: &quot;happy&quot;, &quot;output&quot;: &quot;sad&quot;},
    {&quot;input&quot;: &quot;tall&quot;, &quot;output&quot;: &quot;short&quot;},
    {&quot;input&quot;: &quot;energetic&quot;, &quot;output&quot;: &quot;lethargic&quot;},
    {&quot;input&quot;: &quot;sunny&quot;, &quot;output&quot;: &quot;gloomy&quot;},
    {&quot;input&quot;: &quot;windy&quot;, &quot;output&quot;: &quot;calm&quot;},
]
example_selector = SemanticSimilarityExampleSelector.from_examples(
    # The list of examples available to select from.
    examples,
    # The embedding class used to produce embeddings which are used to measure semantic similarity.
    OpenAIEmbeddings(),
    # The VectorStore class that is used to store the embeddings and do a similarity search over.
    Chroma,
    # The number of examples to produce.
    k=1,
)
similar_prompt = FewShotPromptTemplate(
    # We provide an ExampleSelector instead of examples.
    example_selector=example_selector,
    example_prompt=example_prompt,
    prefix=&quot;Give the antonym of every input&quot;,
    suffix=&quot;Input: {adjective}\nOutput:&quot;,
    input_variables=[&quot;adjective&quot;],
)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输入:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;# Input is a feeling, so should select the happy/sad example
print(similar_prompt.format(adjective=&quot;worried&quot;))
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;Give the antonym of every input

Input: happy
Output: sad

Input: worried
Output:
&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;Select by n-gram overlap&lt;/h4&gt;
&lt;p&gt;这个是计算输入的 query 和 example 的 similarity(0-1 之间), 然后根据给定的阈值, 给出满足条件的 example.&lt;/p&gt;
&lt;p&gt;阈值为 0.0 表示只排除不相关的 example. 阈值为 -1.0, 表示所有的 example 都会返回. 大于 1.0 表示不返回 example, 默认为 -1.0&lt;/p&gt;
&lt;p&gt;例子&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;from langchain_community.example_selector.ngram_overlap import (
    NGramOverlapExampleSelector,
)
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate

example_prompt = PromptTemplate(
    input_variables=[&quot;input&quot;, &quot;output&quot;],
    template=&quot;Input: {input}\nOutput: {output}&quot;,
)

# Examples of a fictional translation task.
examples = [
    {&quot;input&quot;: &quot;See Spot run.&quot;, &quot;output&quot;: &quot;Ver correr a Spot.&quot;},
    {&quot;input&quot;: &quot;My dog barks.&quot;, &quot;output&quot;: &quot;Mi perro ladra.&quot;},
    {&quot;input&quot;: &quot;Spot can run.&quot;, &quot;output&quot;: &quot;Spot puede correr.&quot;},
]
example_selector = NGramOverlapExampleSelector(
    # The examples it has available to choose from.
    examples=examples,
    # The PromptTemplate being used to format the examples.
    example_prompt=example_prompt,
    # The threshold, at which selector stops.
    # It is set to -1.0 by default.
    threshold=-1.0,
    # For negative threshold:
    # Selector sorts examples by ngram overlap score, and excludes none.
    # For threshold greater than 1.0:
    # Selector excludes all examples, and returns an empty list.
    # For threshold equal to 0.0:
    # Selector sorts examples by ngram overlap score,
    # and excludes those with no ngram overlap with input.
)
dynamic_prompt = FewShotPromptTemplate(
    # We provide an ExampleSelector instead of examples.
    example_selector=example_selector,
    example_prompt=example_prompt,
    prefix=&quot;Give the Spanish translation of every input&quot;,
    suffix=&quot;Input: {sentence}\nOutput:&quot;,
    input_variables=[&quot;sentence&quot;],
)

&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输入:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;# An example input with large ngram overlap with &quot;Spot can run.&quot;
# and no overlap with &quot;My dog barks.&quot;
print(dynamic_prompt.format(sentence=&quot;Spot can run fast.&quot;))
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;# 可以看到, 即使 &quot;My dog barks.&quot; 与输入不相关, 但还是输出了
Give the Spanish translation of every input

Input: Spot can run.
Output: Spot puede correr.

Input: See Spot run.
Output: Ver correr a Spot.

Input: My dog barks.
Output: Mi perro ladra.

Input: Spot can run fast.
Output:
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输入:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;# 不输出 不相关的
example_selector.threshold = 0.0
print(dynamic_prompt.format(sentence=&quot;Spot can run fast.&quot;))

&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;Give the Spanish translation of every input

Input: Spot can run.
Output: Spot puede correr.

Input: See Spot run.
Output: Ver correr a Spot.

Input: Spot plays fetch.
Output: Spot juega a buscar.

Input: Spot can run fast.
Output:
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;1.3 Few-shot prompt templates&lt;/h3&gt;
&lt;p&gt;上边介绍了一些 example selector , 现在介绍
FewShotPromptTemplate + example selector .
二者实现的功能就是, 首先我们有一堆 example (通常是成对儿的输入和输出) , 然后可以自适应的, 根据不同的输入, 能够自动的 select 合适的 example 与 输入合并, 一同变为 LLM 的 prompt.&lt;/p&gt;
&lt;p&gt;也就是说, 不同的输入, 会产生不同的 prompt , 我理解是和 Retrieval-augmented generation (RAG) 类似的效果.&lt;/p&gt;
&lt;p&gt;define example:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;from langchain_core.prompts.few_shot import FewShotPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate

examples = [
    {
        &quot;question&quot;: &quot;Who lived longer, Muhammad Ali or Alan Turing?&quot;,
        &quot;answer&quot;: &quot;&quot;&quot;
Are follow up questions needed here: Yes.
Follow up: How old was Muhammad Ali when he died?
Intermediate answer: Muhammad Ali was 74 years old when he died.
Follow up: How old was Alan Turing when he died?
Intermediate answer: Alan Turing was 41 years old when he died.
So the final answer is: Muhammad Ali
&quot;&quot;&quot;,
    },
    {
        &quot;question&quot;: &quot;When was the founder of craigslist born?&quot;,
        &quot;answer&quot;: &quot;&quot;&quot;
Are follow up questions needed here: Yes.
Follow up: Who was the founder of craigslist?
Intermediate answer: Craigslist was founded by Craig Newmark.
Follow up: When was Craig Newmark born?
Intermediate answer: Craig Newmark was born on December 6, 1952.
So the final answer is: December 6, 1952
&quot;&quot;&quot;,
    },
    {
        &quot;question&quot;: &quot;Who was the maternal grandfather of George Washington?&quot;,
        &quot;answer&quot;: &quot;&quot;&quot;
Are follow up questions needed here: Yes.
Follow up: Who was the mother of George Washington?
Intermediate answer: The mother of George Washington was Mary Ball Washington.
Follow up: Who was the father of Mary Ball Washington?
Intermediate answer: The father of Mary Ball Washington was Joseph Ball.
So the final answer is: Joseph Ball
&quot;&quot;&quot;,
    },
    {
        &quot;question&quot;: &quot;Are both the directors of Jaws and Casino Royale from the same country?&quot;,
        &quot;answer&quot;: &quot;&quot;&quot;
Are follow up questions needed here: Yes.
Follow up: Who is the director of Jaws?
Intermediate Answer: The director of Jaws is Steven Spielberg.
Follow up: Where is Steven Spielberg from?
Intermediate Answer: The United States.
Follow up: Who is the director of Casino Royale?
Intermediate Answer: The director of Casino Royale is Martin Campbell.
Follow up: Where is Martin Campbell from?
Intermediate Answer: New Zealand.
So the final answer is: No
&quot;&quot;&quot;,
    },
]

# format example pool
example_prompt = PromptTemplate(
    input_variables=[&quot;question&quot;, &quot;answer&quot;], template=&quot;Question: {question}\n{answer}&quot;
)

&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;define selector:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;from langchain_chroma import Chroma
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings

example_selector = SemanticSimilarityExampleSelector.from_examples(
    # This is the list of examples available to select from.
    examples,
    # This is the embedding class used to produce embeddings which are used to measure semantic similarity.
    OpenAIEmbeddings(),
    # This is the VectorStore class that is used to store the embeddings and do a similarity search over.
    Chroma,
    # This is the number of examples to produce.
    k=1,
)

&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;FewShotPromptTemplate + example + selector :&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    suffix=&quot;Question: {input}&quot;,
    input_variables=[&quot;input&quot;],
)

print(prompt.format(input=&quot;Who was the father of Mary Ball Washington?&quot;))
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;# 可以看到 prompt 最后只引入了一个 example

Question: Who was the maternal grandfather of George Washington?

Are follow up questions needed here: Yes.
Follow up: Who was the mother of George Washington?
Intermediate answer: The mother of George Washington was Mary Ball Washington.
Follow up: Who was the father of Mary Ball Washington?
Intermediate answer: The father of Mary Ball Washington was Joseph Ball.
So the final answer is: Joseph Ball


Question: Who was the father of Mary Ball Washington?
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;1.5 Partial prompt templates&lt;/h3&gt;
&lt;p&gt;这个功能就是类似函数的参数具有默认值.&lt;/p&gt;
&lt;p&gt;例子:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;from langchain_core.prompts import PromptTemplate

prompt = PromptTemplate.from_template(&quot;{foo} {bar}&quot;)
partial_prompt = prompt.partial(foo=&quot;666&quot;)
print(partial_prompt.format(bar=&quot;baz&quot;)) # 输出 666 baz
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这个也可以使用函数:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;from datetime import datetime

def _get_datetime():
    now = datetime.now()
    return now.strftime(&quot;%m/%d/%Y, %H:%M:%S&quot;)

prompt = PromptTemplate(
    template=&quot;Tell me a {adjective} joke about the day {date}&quot;,
    input_variables=[&quot;adjective&quot;, &quot;date&quot;],
)
# date 默认 调用当前时间
partial_prompt = prompt.partial(date=_get_datetime)
# 输出 : Tell me a funny joke about the day 04/25/2024, 23:22:18
print(partial_prompt.format(adjective=&quot;funny&quot;))

&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;1.6 PipelinePrompt&lt;/h3&gt;
&lt;p&gt;PipelinePrompt 能够把多个 prompt 整到一起.&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;from langchain_core.prompts.pipeline import PipelinePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate

# 最终的 prompt
full_template = &quot;&quot;&quot;{introduction}

{example}

{start}&quot;&quot;&quot;
full_prompt = PromptTemplate.from_template(full_template)

# 用于 Introduction
introduction_template = &quot;&quot;&quot;You are impersonating {person}.&quot;&quot;&quot;
introduction_prompt = PromptTemplate.from_template(introduction_template)

# example
example_template = &quot;&quot;&quot;Here&apos;s an example of an interaction:

Q: {example_q}
A: {example_a}&quot;&quot;&quot;
example_prompt = PromptTemplate.from_template(example_template)

# 用户的输入
start_template = &quot;&quot;&quot;Now, do this for real!

Q: {input}
A:&quot;&quot;&quot;
start_prompt = PromptTemplate.from_template(start_template)

# 将上边的 prompt 合并到一起
input_prompts = [
    (&quot;introduction&quot;, introduction_prompt),
    (&quot;example&quot;, example_prompt),
    (&quot;start&quot;, start_prompt),
]
# 指定谁是谁
pipeline_prompt = PipelinePromptTemplate(
    final_prompt=full_prompt, pipeline_prompts=input_prompts
)
print(
    pipeline_prompt.format(
        person=&quot;Elon Musk&quot;,
        example_q=&quot;What&apos;s your favorite car?&quot;,
        example_a=&quot;Tesla&quot;,
        input=&quot;What&apos;s your favorite social media site?&quot;,
    )
)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出 :&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;
You are impersonating Elon Musk.

Here&apos;s an example of an interaction:

Q: What&apos;s your favorite car?
A: Tesla

Now, do this for real!

Q: What&apos;s your favorite social media site?
A:

&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;2. Retrieval&lt;/h2&gt;
&lt;p&gt;Retrieval Augmented Generation (RAG) 可能是目前 LLM 发挥比较大作用的一个应用. 其核心思想是利用外挂的知识库赋予在不同的垂直领域应用能力.&lt;/p&gt;
&lt;p&gt;其核心流程如下:&lt;/p&gt;
&lt;p&gt;[1] 首先我们要有相应的资源库, source&lt;/p&gt;
&lt;p&gt;[2] 然后针对不同的资源, 我们使用相应的 dataloader 将资源数据读取&lt;/p&gt;
&lt;p&gt;[3] 由于资源文档比较长, 通常我们要进行分块, 称为 chunk&lt;/p&gt;
&lt;p&gt;[4] 将文档 chunk 后, 会对每个 chunk 进行 embedding&lt;/p&gt;
&lt;p&gt;[5] embedding 之后, 要进行 store, 这个组件一般称为 vector store&lt;/p&gt;
&lt;p&gt;[6] 当我们给定输入的时候, LLM 能够根据语义从 store 中抽取有用的资源,这个过程就是 retrieve&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/04/26/IiHL1MVWNc8QJtG.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;h3&gt;2.1 Dataloader&lt;/h3&gt;
&lt;p&gt;&lt;a href=&quot;https://python.langchain.com/docs/integrations/document_loaders/&quot;&gt;官方文档&lt;/a&gt;集成了很多第三方 dataloader,
甚至可以直接从 arxiv、GitHub 等直接获取数据. 但是常用的可能就是针对 文本 和 csv 的, 而且使用方法类似, 所以这里只学习 文本类型 的.&lt;/p&gt;
&lt;h4&gt;Document Loader&lt;/h4&gt;
&lt;p&gt;LangChain 给的例子是继承 BaseLoader, 然后将读到的文本初始化为 Document 对象.
内部有 4 个基本方法: 直接读取所有, 异步读取所有, lazy 读取, 异步 lazy 读取.&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/26/PC8fvdnsaHA9KB6.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;例子:&lt;/p&gt;
&lt;p&gt;&amp;lt;details markdown=&quot;1&quot;&amp;gt;
&amp;lt;summary&amp;gt; 详细代码 &amp;lt;/summary&amp;gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;from typing import AsyncIterator, Iterator
from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document

class CustomDocumentLoader(BaseLoader):
    &quot;&quot;&quot;An example document loader that reads a file line by line.&quot;&quot;&quot;

    def __init__(self, file_path: str) -&amp;gt; None:
        &quot;&quot;&quot;Initialize the loader with a file path.

        Args:
            file_path: The path to the file to load.
        &quot;&quot;&quot;
        self.file_path = file_path

    def lazy_load(self) -&amp;gt; Iterator[Document]:  # &amp;lt;-- Does not take any arguments
        &quot;&quot;&quot;A lazy loader that reads a file line by line.

        When you&apos;re implementing lazy load methods, you should use a generator
        to yield documents one by one.
        &quot;&quot;&quot;
        with open(self.file_path, encoding=&quot;utf-8&quot;) as f:
            line_number = 0
            for line in f:
                yield Document(
                    page_content=line,
                    metadata={&quot;line_number&quot;: line_number, &quot;source&quot;: self.file_path},
                )
                line_number += 1

    # alazy_load is OPTIONAL.
    # If you leave out the implementation, a default implementation which delegates to lazy_load will be used!
    async def alazy_load(
        self,
    ) -&amp;gt; AsyncIterator[Document]:  # &amp;lt;-- Does not take any arguments
        &quot;&quot;&quot;An async lazy loader that reads a file line by line.&quot;&quot;&quot;
        # Requires aiofiles
        # Install with `pip install aiofiles`
        # https://github.com/Tinche/aiofiles
        import aiofiles

        async with aiofiles.open(self.file_path, encoding=&quot;utf-8&quot;) as f:
            line_number = 0
            async for line in f:
                yield Document(
                    page_content=line,
                    metadata={&quot;line_number&quot;: line_number, &quot;source&quot;: self.file_path},
                )
                line_number += 1

with open(&quot;./meow.txt&quot;, &quot;w&quot;, encoding=&quot;utf-8&quot;) as f:
    quality_content = &quot;meow meow🐱 \n meow meow🐱 \n meow😻😻&quot;
    f.write(quality_content)

loader = CustomDocumentLoader(&quot;./meow.txt&quot;)

## Test out the lazy load interface
for doc in loader.lazy_load():
    print()
    print(type(doc))
    print(doc)

## Test out the async implementation
async for doc in loader.alazy_load():
    print()
    print(type(doc))
    print(doc)

&quot;&quot;&quot;
输出结果一样:

&amp;lt;class &apos;langchain_core.documents.base.Document&apos;&amp;gt;
page_content=&apos;meow meow🐱 \n&apos; metadata={&apos;line_number&apos;: 0, &apos;source&apos;: &apos;./meow.txt&apos;}

&amp;lt;class &apos;langchain_core.documents.base.Document&apos;&amp;gt;
page_content=&apos; meow meow🐱 \n&apos; metadata={&apos;line_number&apos;: 1, &apos;source&apos;: &apos;./meow.txt&apos;}

&amp;lt;class &apos;langchain_core.documents.base.Document&apos;&amp;gt;
page_content=&apos; meow😻😻&apos; metadata={&apos;line_number&apos;: 2, &apos;source&apos;: &apos;./meow.txt&apos;}
&quot;&quot;&quot;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&amp;lt;/details&amp;gt;&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;load() can be helpful in an interactive environment such as a jupyter notebook.
Avoid using it for production code since eager loading assumes that all the content can fit into memory, which is not always the case, especially for enterprise data.
:::note
:::&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h3&gt;2.2 Text Splitters&lt;/h3&gt;
&lt;p&gt;Once you&apos;ve loaded documents, you&apos;ll often want to transform them to better suit your application. The simplest example is you may want to split a long document into smaller chunks that can fit into your model&apos;s context window.&lt;/p&gt;
&lt;p&gt;&lt;a href=&quot;https://python.langchain.com/docs/modules/data_connection/document_transformers/&quot;&gt;官方文档&lt;/a&gt;给了多种 Splitter, 最常用的为以下 3 种.&lt;/p&gt;
&lt;h4&gt;Split by character&lt;/h4&gt;
&lt;p&gt;这个是最简单的 Splitter, 单纯就是使用指定的 character 去切分 text .&lt;/p&gt;
&lt;p&gt;例子:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;from langchain_text_splitters import CharacterTextSplitter

text_splitter = CharacterTextSplitter(
    separator=&quot; &quot;, # 指定 separator
    chunk_size=5,
    chunk_overlap=2,
    length_function=len,
    is_separator_regex=False,
)

text_splitter.split_text(&quot;我是练习时长 达两年半的坤坤&quot;)

&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;上述代码中, chunk_size 本意是指进行 split 之后, 后续进行 store 的时候, 最大以多大的 size 作为一个整体进行存储. 但是可以看到这个参数对于 CharacterTextSplitter 不生效, 实际上&lt;a href=&quot;https://api.python.langchain.com/en/latest/_modules/langchain_text_splitters/character.html#CharacterTextSplitter&quot;&gt;源码&lt;/a&gt;中, 就是简单的用 re.split() 对文档按照 separator 进行切割, 不管子句有多长, 直接返回.&lt;/p&gt;
&lt;p&gt;输出:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;[&apos;我是练习时长&apos;, &apos;达两年半的坤坤&apos;]
&lt;/code&gt;&lt;/pre&gt;
&lt;blockquote&gt;
&lt;p&gt;因为 LLM 通常对输入有长度限制, 因此 CharacterTextSplitter 不太适合, 可能会超出输入尺寸范围, 而下边的 RecursiveCharacterTextSplitter 可以递归切割子句, 直到每个子句都小于 chunk size.
:::note
:::&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h4&gt;Recursive Splitter By Character&lt;/h4&gt;
&lt;p&gt;这个看了&lt;a href=&quot;https://api.python.langchain.com/en/latest/_modules/langchain_text_splitters/character.html#RecursiveCharacterTextSplitter&quot;&gt;源码&lt;/a&gt;,它默认的 &lt;code&gt;separator = [&quot;\n\n&quot;, &quot;\n&quot;, &quot; &quot;, &quot;&quot;]&lt;/code&gt;, 我的理解是说, 首先会根据&lt;code&gt;\n\n&lt;/code&gt;进行切割. 一般来说, 2 个换行分割开的通常是 2 篇文章. 所以会先按照这个尺度进行切割. 如果切割之后, 某篇文章还是太长(大于 chunk size), 那么会继续使用 &lt;code&gt;\n&lt;/code&gt; 进行划分切割, 同理直到其长度小于 chunk size.&lt;/p&gt;
&lt;p&gt;例子:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;from langchain_text_splitters import RecursiveCharacterTextSplitter
text_splitter = RecursiveCharacterTextSplitter(
    # Set a really small chunk size, just to show.
    chunk_size=5,
    chunk_overlap=0,
    length_function=len,
    is_separator_regex=False,
    keep_separator = False
)
text_splitter.split_text(
    &quot;我是\n\n练习时长达两年\n半的坤坤&quot;
    )
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;[&apos;我是&apos;, &apos;练习时长达&apos;, &apos;两年&apos;, &apos;半的坤坤&apos;]
# 可以看到, 首先用`\n\n`进行分割, 因为&quot;我是&quot;的长度小于5, 所以直接存起来,
# 但是后边部分太长, 又基于`\n`进行切割, &quot;半的坤坤&quot;是满足要求的,所以一起存了起来.
# 但是&quot;练习时长达两年&quot;的长度还是大于5, 于是进行了继续的切割. 变为&quot;练习时长达&quot; 和 &quot;两年&quot;
&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;Split by tokens&lt;/h4&gt;
&lt;p&gt;这个就是使用 NLP 中 token 进行切割, 不同的 tokenizer 有不同的切割方式. 举个例子, 如果一个单词算一个 token, 那就按单词切割.
为什么需要这个呢, 就是有些 LLM 的输入具有 token 数目的限制, 因此最好分割存储的 tokenizer 和 LLM 使用一样的.&lt;/p&gt;
&lt;p&gt;这里使用 OpenAI BPE tokenizer : tiktoken, 是&lt;a href=&quot;https://huggingface.co/learn/nlp-course/chapter6/5&quot;&gt;BPE 算法&lt;/a&gt;的一个实现.&lt;/p&gt;
&lt;p&gt;Split by tokens 的使用方法基于上边 2 种 Splitter, 只是切割时调用方法不同:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;# pip install tiktoken
text_splitter = CharacterTextSplitter.from_tiktoken_encoder(
    encoding=&quot;cl100k_base&quot;, chunk_size=100, chunk_overlap=0
) # CharacterTextSplitter 实际不受 chunk_size 的约束
texts = text_splitter.split_text(&quot;text&quot;)

text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    model_name=&quot;gpt-4&quot;,
    chunk_size=100,
    chunk_overlap=0,
) # 能够保证子句全部小于chunk_size
texts = text_splitter.split_text(&quot;text&quot;)
# 此外 encoding参数 和 model_name参数 效果类似, 具体请参考api
&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;Semantic Chunking&lt;/h4&gt;
&lt;p&gt;这个就是字面意思, 基于 text 之间的语义进行切割, 使得语义相近的尽量在一个 chunk, 但是这个目前(2024 年 4 月 27 日)是个实验性功能. 参考&lt;a href=&quot;https://python.langchain.com/docs/modules/data_connection/document_transformers/semantic-chunker/&quot;&gt;官方文档&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;由于这个需要计算语义相似度, 所以需要进行 embedding.&lt;/p&gt;
&lt;p&gt;例子:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;from langchain_experimental.text_splitter import SemanticChunker
from langchain_openai.embeddings import OpenAIEmbeddings
text_splitter = SemanticChunker(OpenAIEmbeddings())
texts = text_splitter.split_text(&quot;text&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这个里边提供了一个 Breakpoints, 用于评估什么时候该切割, 语义相近多少才算相近?&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;Percentile&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;Percentile(百分位数) 是默认的评估标准, 他是计算所有两两句子之间的 difference, 如果大于阈值就给他切开.&lt;/p&gt;
&lt;p&gt;例子:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;text_splitter = SemanticChunker(
    OpenAIEmbeddings(), breakpoint_threshold_type=&quot;percentile&quot;
    # breakpoint_threshold_amount : 默认值
)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;阅读&lt;a href=&quot;https://api.python.langchain.com/en/latest/_modules/langchain_experimental/text_splitter.html#SemanticChunker&quot;&gt;源码&lt;/a&gt;可以看到,当&lt;code&gt;threshold_type = &quot;percentile&quot;&lt;/code&gt; 时, 默认使用 95% 分位数. &lt;code&gt;breakpoint_threshold_amount&lt;/code&gt; 参数控制分位数具体大小.&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;Standard Deviation&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;用法类似, 不再赘述. 源码中当&lt;code&gt;threshold_type = &quot;standard_deviation&quot;&lt;/code&gt; 时, 默认使用 &lt;code&gt;mean + 3 * std&lt;/code&gt; 作为阈值. &lt;code&gt;breakpoint_threshold_amount&lt;/code&gt; 参数控制标准差的倍数.&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;Interquartile&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;使用的箱线图方法, 默认使用 &lt;code&gt;mean + 1.5 * iqr&lt;/code&gt;, 其中 &lt;code&gt;iqr = q3 - q1&lt;/code&gt;, q3 为 75% 分位数, q1 为 25% 分位数. &lt;code&gt;breakpoint_threshold_amount&lt;/code&gt; 参数控制&lt;code&gt;q3 - q1&lt;/code&gt;的倍数.&lt;/p&gt;
&lt;h3&gt;2.3 Embedding&lt;/h3&gt;
&lt;p&gt;&lt;a href=&quot;https://python.langchain.com/docs/integrations/text_embedding/&quot;&gt;官方文档&lt;/a&gt;给了很多第三方 embedding  方法. 其实就是训练好的一个 Matrix. 这里使用 openAI 提供的 embedding.&lt;/p&gt;
&lt;p&gt;例子:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;from langchain_openai import OpenAIEmbeddings
embeddings_model = OpenAIEmbeddings(api_key=&quot;...&quot;)
embeddings = embeddings_model.embed_documents(
    [
        &quot;Hi there!&quot;,
        &quot;Oh, hello!&quot;,
        &quot;What&apos;s your name?&quot;,
        &quot;My friends call me World&quot;,
        &quot;Hello World!&quot;
    ]
)
len(embeddings), len(embeddings[0])
embedded_query = embeddings_model.embed_query(&quot;What was the name mentioned in the conversation?&quot;)
embedded_query[:5]
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;[0.0053587136790156364,
 -0.0004999046213924885,
 0.038883671164512634,
 -0.003001077566295862,
 -0.00900818221271038]
&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;Caching&lt;/h4&gt;
&lt;p&gt;在得到 embedding 之后, 我们可以已经 embedding 过的 token 给他缓存, 如果后续又来了同一个 token, 我们可以直接从 cache 调用, 而不去需要从 embedding matrix 获取.&lt;/p&gt;
&lt;p&gt;核心组件为 CacheBackedEmbeddings, 使用例子如下:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import CharacterTextSplitter

underlying_embeddings = OpenAIEmbeddings()

store = LocalFileStore(&quot;./cache/&quot;) # 表示缓存到本地

cached_embedder = CacheBackedEmbeddings.from_bytes_store(
    underlying_embeddings, store, namespace=underlying_embeddings.model
)
&quot;&quot;&quot;
underlying_embedder: The embedder to use for embedding.
document_embedding_cache: Any ByteStore for caching document embeddings.
batch_size: (optional, defaults to None) The number of documents to embed between store updates.
namespace: (optional, defaults to &quot;&quot;) The namespace to use for document cache. This namespace is used to avoid collisions with other caches. For example, you can set it to the name of the embedding model used.
&quot;&quot;&quot;
raw_documents = TextLoader(&quot;../../state_of_the_union.txt&quot;).load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
documents = text_splitter.split_documents(raw_documents)
%%time
db = FAISS.from_documents(documents, cached_embedder)
# 输出 CPU times: user 218 ms, sys: 29.7 ms, total: 248 ms
# Wall time: 1.02 s

%%time
db2 = FAISS.from_documents(documents, cached_embedder)
# 输出 CPU times: user 15.7 ms, sys: 2.22 ms, total: 18 ms
# Wall time: 17.2 ms

&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;最后, store 可以换, 比如使用 memory store:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import InMemoryByteStore
store = InMemoryByteStore()
cached_embedder = CacheBackedEmbeddings.from_bytes_store(
    underlying_embeddings, store, namespace=underlying_embeddings.model
)
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;2.4 Vector stores&lt;/h3&gt;
&lt;p&gt;&lt;a href=&quot;https://python.langchain.com/docs/integrations/vectorstores/&quot;&gt;官方文档&lt;/a&gt;提供了许多第三方的 Vector stores.&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/04/27/gBZulMS2hGEKJTi.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;p&gt;Facebook AI Similarity Search (FAISS) library, 例子:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;from langchain_community.document_loaders import TextLoader
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import CharacterTextSplitter
from langchain_community.vectorstores import FAISS

# Load the document, split it into chunks, embed each chunk and load it into the vector store.
raw_documents = TextLoader(&apos;../../../state_of_the_union.txt&apos;).load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
documents = text_splitter.split_documents(raw_documents)
db = FAISS.from_documents(documents, OpenAIEmbeddings())
query = &quot;What did the president say about Ketanji Brown Jackson&quot;
docs = db.similarity_search(query)
print(docs[0].page_content)
&quot;&quot;&quot;
也可以直接使用 vector 进行 search
embedding_vector = OpenAIEmbeddings().embed_query(query)
docs = db.similarity_search_by_vector(embedding_vector)
print(docs[0].page_content) # 输出结果是一样的
&quot;&quot;&quot;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;    Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections.

    Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service.

    One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court.

    And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.
&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;Asynchronous operations&lt;/h4&gt;
&lt;p&gt;Vector Store 也支持 异步操作, &lt;code&gt;Qdrant&lt;/code&gt; is a vector store, which supports all the async operations, thus it will be used in this walkthrough.&lt;/p&gt;
&lt;p&gt;例子:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;# pip install qdrant-client
from langchain_community.vectorstores import Qdrant
db = await Qdrant.afrom_documents(documents, embeddings, &quot;http://localhost:6333&quot;)
query = &quot;What did the president say about Ketanji Brown Jackson&quot;
docs = await db.asimilarity_search(query)
print(docs[0].page_content)
&quot;&quot;&quot;
# 同理支持 vector 查询
embedding_vector = embeddings.embed_query(query)
docs = await db.asimilarity_search_by_vector(embedding_vector)
&quot;&quot;&quot;

&quot;&quot;&quot;
# 此外计算 similarity  的时候, 支持 Maximum marginal relevance search (MMR)方法:
query = &quot;What did the president say about Ketanji Brown Jackson&quot;
found_docs = await qdrant.amax_marginal_relevance_search(query, k=2, fetch_k=10)
for i, doc in enumerate(found_docs):
    print(f&quot;{i + 1}.&quot;, doc.page_content, &quot;\n&quot;)
&quot;&quot;&quot;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;    Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections.

    Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service.

    One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court.

    And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;2.5 Retrievers&lt;/h3&gt;
&lt;p&gt;Retrievers 接受用户的 query, 然后从 vector store 中根据规则(不同种类相似度)去搜索得到合适的上下文, 用于后续回答输出.&lt;/p&gt;
&lt;p&gt;同样的, &lt;a href=&quot;https://python.langchain.com/docs/modules/data_connection/retrievers/&quot;&gt;官方文档&lt;/a&gt;有多种类型的 Retrievers, 下面简要学习.&lt;/p&gt;
&lt;h4&gt;Vector store-backed retriever&lt;/h4&gt;
&lt;p&gt;这个 retriever 是最简单的, 他使用的 search 方法有 similarity search and MMR.&lt;/p&gt;
&lt;p&gt;:::note
后续其他的高级 retriever 都是基于这个 retriever进行的包装. 都有一个参数 : base_retriever = retriever
:::&lt;/p&gt;
&lt;p&gt;例子:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import CharacterTextSplitter
loader = TextLoader(&quot;../../state_of_the_union.txt&quot;)
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
texts = text_splitter.split_documents(documents)
embeddings = OpenAIEmbeddings()
db = FAISS.from_documents(texts, embeddings)
retriever = db.as_retriever()
# retriever = db.as_retriever(search_type=&quot;mmr&quot;)
# retriever = db.as_retriever(
#     search_type=&quot;similarity_score_threshold&quot;, search_kwargs={&quot;score_threshold&quot;: 0.5}
# )
# retriever = db.as_retriever(search_kwargs={&quot;k&quot;: 1})

docs = retriever.invoke(&quot;what did he say about ketanji brown jackson&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;MultiQueryRetriever&lt;/h4&gt;
&lt;p&gt;前边提到的最简单的 retriever, 将用户输入的 query 对于 sotre 中的文本进行相似度计算, 但是有时候输入的 query 可能并不太明确, 导致搜索到的文本不够清晰. 这时可以使用 MultiQueryRetriever, 这个 retriever 内部使用一个 LLM 基于用户输入的 query 进行分析, 输出逻辑性的过渡问题, 这样每个问题都会分别去与 sotre 中的资源计算 similarity. 通过对同一问题生成多个视角，MultiQueryRetriever 或许能够克服基于距离的检索的一些限制，并获得更丰富的结果集。&lt;/p&gt;
&lt;p&gt;例子:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;# Build a sample vectorDB
from langchain_chroma import Chroma
from langchain_community.document_loaders import WebBaseLoader
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_openai import ChatOpenAI

# Load blog post
loader = WebBaseLoader(&quot;https://lilianweng.github.io/posts/2023-06-23-agent/&quot;)
data = loader.load()
# Split
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
splits = text_splitter.split_documents(data)
# VectorDB
embedding = OpenAIEmbeddings()
vectordb = Chroma.from_documents(documents=splits, embedding=embedding)

question = &quot;What are the approaches to Task Decomposition?&quot;
llm = ChatOpenAI(temperature=0) # 指定一个 LLM 基于 query 生成多角度的 query
retriever_from_llm = MultiQueryRetriever.from_llm(
    retriever=vectordb.as_retriever(), llm=llm
)
unique_docs = retriever_from_llm.invoke(question)
len(unique_docs)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;# 可以看到日志默认生成了3个问题(注意这不是最终输出哈~)
[&apos;1. How can Task Decomposition be approached?&apos;,
 &apos;2. What are the different methods for Task Decomposition?&apos;,
 &apos;3. What are the various approaches to decomposing tasks?&apos;]
&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;Contextual Compression Retriever&lt;/h4&gt;
&lt;p&gt;这个实际上算是一个 wrapper, 本意是用来解决:因为我们不知道用户到低想搜索什么, 所以以会直接放大量的文档给 store 中, 但是这就会导致一个问题, 当我们输入 query 的时候, 有用的信息可能会被淹没在大量的文档中, 这就需要我们对文档信息进行压缩, 把没用的信息过滤掉.&lt;/p&gt;
&lt;p&gt;例子:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import CharacterTextSplitter
def pretty_print_docs(docs):
    print(
        f&quot;\n{&apos;-&apos; * 100}\n&quot;.join(
            [f&quot;Document {i+1}:\n\n&quot; + d.page_content for i, d in enumerate(docs)]
        )
    )
documents = TextLoader(&quot;../../state_of_the_union.txt&quot;).load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
texts = text_splitter.split_documents(documents)
retriever = FAISS.from_documents(texts, OpenAIEmbeddings()).as_retriever()

docs = retriever.invoke(&quot;What did the president say about Ketanji Brown Jackson&quot;)
pretty_print_docs(docs)

&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;Document 1:

Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections.

Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service.

One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court.

And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.
----------------------------------------------------------------------------------------------------
Document 2:

A former top litigator in private practice. A former federal public defender. And from a family of public school educators and police officers. A consensus builder. Since she’s been nominated, she’s received a broad range of support—from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.

And if we are to advance liberty and justice, we need to secure the Border and fix the immigration system.

We can do both. At our border, we’ve installed new technology like cutting-edge scanners to better detect drug smuggling.

We’ve set up joint patrols with Mexico and Guatemala to catch more human traffickers.

We’re putting in place dedicated immigration judges so families fleeing persecution and violence can have their cases heard faster.

We’re securing commitments and supporting partners in South and Central America to host more refugees and secure their own borders.
----------------------------------------------------------------------------------------------------
Document 3:

And for our LGBTQ+ Americans, let’s finally get the bipartisan Equality Act to my desk. The onslaught of state laws targeting transgender Americans and their families is wrong.

As I said last year, especially to our younger transgender Americans, I will always have your back as your President, so you can be yourself and reach your God-given potential.

While it often appears that we never agree, that isn’t true. I signed 80 bipartisan bills into law last year. From preventing government shutdowns to protecting Asian-Americans from still-too-common hate crimes to reforming military justice.

And soon, we’ll strengthen the Violence Against Women Act that I first wrote three decades ago. It is important for us to show the nation that we can come together and do big things.

So tonight I’m offering a Unity Agenda for the Nation. Four big things we can do together.

First, beat the opioid epidemic.
----------------------------------------------------------------------------------------------------
Document 4:

Tonight, I’m announcing a crackdown on these companies overcharging American businesses and consumers.

And as Wall Street firms take over more nursing homes, quality in those homes has gone down and costs have gone up.

That ends on my watch.

Medicare is going to set higher standards for nursing homes and make sure your loved ones get the care they deserve and expect.

We’ll also cut costs and keep the economy going strong by giving workers a fair shot, provide more training and apprenticeships, hire them based on their skills not degrees.

Let’s pass the Paycheck Fairness Act and paid leave.

Raise the minimum wage to $15 an hour and extend the Child Tax Credit, so no one has to raise a family in poverty.

Let’s increase Pell Grants and increase our historic support of HBCUs, and invest in what Jill—our First Lady who teaches full-time—calls America’s best-kept secret: community colleges.
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;可以看到由于存储的时候, chunk size 比较大, 并且我们要找的信息就仅仅为一句话(淹没在文档中), 所以简单的使用 retriever 会直接将相关的文档全部返回了. 在上边的基础上, 我们对基础的 retriever 进行 warpper :&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain_openai import OpenAI

llm = OpenAI(temperature=0)
compressor = LLMChainExtractor.from_llm(llm)
# 用于将抽到的 doc 进行 compression, 并从每个文档中仅提取与查询相关的内容。
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=retriever
)

compressed_docs = compression_retriever.invoke(
    &quot;What did the president say about Ketanji Jackson Brown&quot;
)
pretty_print_docs(compressed_docs)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;但是这个 compression 会使用 LLM 对抽回的文档进行处理压缩(万一他处理的不好呢?). 官方提供了一种可以不改变原始文档, 但能保留核心信息的 : filters.&lt;/p&gt;
&lt;p&gt;例子:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;from langchain.retrievers.document_compressors import LLMChainFilter

_filter = LLMChainFilter.from_llm(llm)
compression_retriever = ContextualCompressionRetriever(
    base_compressor=_filter, base_retriever=retriever
)

compressed_docs = compression_retriever.invoke(
    &quot;What did the president say about Ketanji Jackson Brown&quot;
)
pretty_print_docs(compressed_docs)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;Document 1:

Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections.

Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service.

One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court.

And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;不过这个 filter 在过滤的时候, 是把整个文本再吃进去操作, 可能带来更多的 token 计算量. EmbeddingsFilter 可以直接使用 embedding 进行操作.&lt;/p&gt;
&lt;p&gt;例子:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain_openai import OpenAIEmbeddings

embeddings = OpenAIEmbeddings()
embeddings_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.76)
compression_retriever = ContextualCompressionRetriever(
    base_compressor=embeddings_filter, base_retriever=retriever
)

compressed_docs = compression_retriever.invoke(
    &quot;What did the president say about Ketanji Jackson Brown&quot;
)
pretty_print_docs(compressed_docs)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;最后, 官方提供一个 Pipeline 能够把 splitter, embedding, filter, retriever 整到一起.&lt;/p&gt;
&lt;p&gt;例子&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;from langchain.retrievers.document_compressors import DocumentCompressorPipeline
from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain_text_splitters import CharacterTextSplitter

splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=0, separator=&quot;. &quot;)
# EmbeddingsRedundantFilter 内部实现对文本的 embedding 和 去重冗余
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
# filter
relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.76)
# 组合
pipeline_compressor = DocumentCompressorPipeline(
    transformers=[splitter, redundant_filter, relevant_filter]
)
# creat retriever
compression_retriever = ContextualCompressionRetriever(
    base_compressor=pipeline_compressor, base_retriever=retriever # 当然要基于基本的 retriever
)

compressed_docs = compression_retriever.invoke(
    &quot;What did the president say about Ketanji Jackson Brown&quot;
)
pretty_print_docs(compressed_docs)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;Document 1:

One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court.

And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson
----------------------------------------------------------------------------------------------------
Document 2:

As I said last year, especially to our younger transgender Americans, I will always have your back as your President, so you can be yourself and reach your God-given potential.

While it often appears that we never agree, that isn’t true. I signed 80 bipartisan bills into law last year
----------------------------------------------------------------------------------------------------
Document 3:

A former top litigator in private practice. A former federal public defender. And from a family of public school educators and police officers. A consensus builder
----------------------------------------------------------------------------------------------------
Document 4:

Since she’s been nominated, she’s received a broad range of support—from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.

And if we are to advance liberty and justice, we need to secure the Border and fix the immigration system.

We can do both
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;3. Tools&lt;/h2&gt;
&lt;p&gt;&lt;a href=&quot;https://python.langchain.com/v0.1/docs/modules/tools/&quot;&gt;Tools&lt;/a&gt; 通常和 Agent 搭配使用. 面对复杂的任务, Agent(通常是一个LLM) 通过上下文信息去使用合适的 tool 以完成任务. 一个 tool 通常有以下几个组件:&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;The name of the tool&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;A description of what the tool is&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;JSON schema of what the inputs to the tool are&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;The function to call&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;Whether the result of a tool should be returned directly to the user&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;The name, description, and JSON schema can be used to prompt the LLM so it knows how to specify what action to take. The function to call is equivalent to taking that action.&lt;/p&gt;
&lt;p&gt;:::note
Importantly, the name, description, and JSON schema (if used) are &lt;strong&gt;all used in the prompt&lt;/strong&gt;.
:::&lt;/p&gt;
&lt;h3&gt;Basic Usage Tutorial&lt;/h3&gt;
&lt;p&gt;这里使用官方的 Wikipedia tool.&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;from langchain_community.tools import WikipediaQueryRun
from langchain_community.utilities import WikipediaAPIWrapper

api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=100)
tool = WikipediaQueryRun(api_wrapper=api_wrapper)
pritn(tool.name,tool.description,tool.args,tool.return_direct)
tool.run({&quot;query&quot;: &quot;langchain&quot;}) # 或者 tool.run(&quot;langchain&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;&apos;Page: LangChain\nSummary: LangChain is a framework designed to simplify the creation of applications &apos;
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;Custom Tools&lt;/h3&gt;
&lt;p&gt;官方提供了一些&lt;a href=&quot;https://python.langchain.com/v0.1/docs/integrations/tools/&quot;&gt;第三方 tool&lt;/a&gt; 可供使用. 不过通常我们会定义自己的 tool.&lt;/p&gt;
&lt;p&gt;定义一个 tool 的时候, 需要指定以下信息:&lt;/p&gt;
&lt;p&gt;name(required) : 一个 agent 可调用的 tools 的名字需要是 unique 的.&lt;/p&gt;
&lt;p&gt;description(recommended) : 告知 agent 这个 tool 的用途是什么.&lt;/p&gt;
&lt;p&gt;args_schema(recommended) : &lt;a href=&quot;https://docs.pydantic.dev/latest/&quot;&gt;Pydantic&lt;/a&gt; 类型, 可以用于类型检测, 也可以用于添加一些额外的信息, 最后都会作为 prompt 的一部分.&lt;/p&gt;
&lt;h4&gt;decorator&lt;/h4&gt;
&lt;pre&gt;&lt;code&gt;# Import things that are needed generically
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import BaseTool, StructuredTool, tool

@tool
def search(query: str) -&amp;gt; str:
    &quot;&quot;&quot;Look up things online.&quot;&quot;&quot;
    return &quot;LangChain&quot;
print(search.name)
print(search.description)
print(search.args)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;search
search(query: str) -&amp;gt; str - Look up things online.
{&apos;query&apos;: {&apos;title&apos;: &apos;Query&apos;, &apos;type&apos;: &apos;string&apos;}}

&lt;/code&gt;&lt;/pre&gt;
&lt;pre&gt;&lt;code&gt;
@tool
def multiply(a: int, b: int) -&amp;gt; int:
    &quot;&quot;&quot;Multiply two numbers.&quot;&quot;&quot;
    return a * b

print(multiply.name)
print(multiply.description)
print(multiply.args)

&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;multiply
multiply(a: int, b: int) -&amp;gt; int - Multiply two numbers.
{&apos;a&apos;: {&apos;title&apos;: &apos;A&apos;, &apos;type&apos;: &apos;integer&apos;}, &apos;b&apos;: {&apos;title&apos;: &apos;B&apos;, &apos;type&apos;: &apos;integer&apos;}}
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;可以使用 @tool 的参数指定 tool 的名字, 或者对参数进行描述, 提供额外的信息.&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;class SearchInput(BaseModel):
    query: str = Field(description=&quot;should be a search query&quot;)


@tool(&quot;search-tool&quot;, args_schema=SearchInput, return_direct=True)
def search(query: str) -&amp;gt; str:
    &quot;&quot;&quot;Look up things online.&quot;&quot;&quot;
    return &quot;LangChain&quot;
print(search.name)
print(search.description)
print(search.args)
print(search.return_direct)

&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;
search-tool
search-tool(query: str) -&amp;gt; str - Look up things online.
{&apos;query&apos;: {&apos;title&apos;: &apos;Query&apos;, &apos;description&apos;: &apos;should be a search query&apos;, &apos;type&apos;: &apos;string&apos;}}
True

&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;Subclassing  BaseTool (&lt;strong&gt;Recommend&lt;/strong&gt;)&lt;/h4&gt;
&lt;p&gt;更加的自定义化, 但是稍微麻烦一些.&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;
from typing import Optional, Type
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import BaseTool, StructuredTool, tool
from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)

class SearchInput(BaseModel):
    query: str = Field(description=&quot;should be a search query&quot;)

class CustomSearchTool(BaseTool):
    name = &quot;custom_search&quot;
    description = &quot;useful for when you need to answer questions about current events&quot;
    args_schema: Type[BaseModel] = SearchInput # 指定 schema

    def _run(
        self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -&amp;gt; str:
        &quot;&quot;&quot;Use the tool.&quot;&quot;&quot;
        return &quot;LangChain&quot;

    async def _arun(
        self, query: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None
    ) -&amp;gt; str:
        &quot;&quot;&quot;Use the tool asynchronously.&quot;&quot;&quot;
        raise NotImplementedError(&quot;custom_search does not support async&quot;)


class CalculatorInput(BaseModel):
    a: int = Field(description=&quot;first number&quot;)
    b: int = Field(description=&quot;second number&quot;)

class CustomCalculatorTool(BaseTool):
    name = &quot;Calculator&quot;
    description = &quot;useful for when you need to answer questions about math&quot;
    args_schema: Type[BaseModel] = CalculatorInput
    return_direct: bool = True

    def _run(
        self, a: int, b: int, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -&amp;gt; str:
        &quot;&quot;&quot;Use the tool.&quot;&quot;&quot;
        return a * b

    async def _arun(
        self,
        a: int,
        b: int,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -&amp;gt; str:
        &quot;&quot;&quot;Use the tool asynchronously.&quot;&quot;&quot;
        raise NotImplementedError(&quot;Calculator does not support async&quot;)

&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;search = CustomSearchTool()
print(search.name)
print(search.description)
print(search.args)
&apos;&apos;&apos;
custom_search
useful for when you need to answer questions about current events
{&apos;query&apos;: {&apos;title&apos;: &apos;Query&apos;, &apos;description&apos;: &apos;should be a search query&apos;, &apos;type&apos;: &apos;string&apos;}}
&apos;&apos;&apos;

multiply = CustomCalculatorTool()
print(multiply.name)
print(multiply.description)
print(multiply.args)
print(multiply.return_direct)
&apos;&apos;&apos;
Calculator
useful for when you need to answer questions about math
{&apos;a&apos;: {&apos;title&apos;: &apos;A&apos;, &apos;description&apos;: &apos;first number&apos;, &apos;type&apos;: &apos;integer&apos;}, &apos;b&apos;: {&apos;title&apos;: &apos;B&apos;, &apos;description&apos;: &apos;second number&apos;, &apos;type&apos;: &apos;integer&apos;}}
True
&apos;&apos;&apos;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;第三方的 tool 基本都是上述的方法, 这里以 &lt;a href=&quot;https://python.langchain.com/v0.1/docs/integrations/tools/arxiv/&quot;&gt;ArXiv tool&lt;/a&gt; 为例. 其源代码见: &lt;a href=&quot;https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/tools/arxiv/tool.py&quot;&gt;arxiv_tool.py&lt;/a&gt;, &lt;a href=&quot;https://github.com/langchain-ai/langchain/blob/480c02bf553de894cedc60504b126807dd6dea00/libs/community/langchain_community/utilities/arxiv.py#L13&quot;&gt;utilities_tool.py&lt;/a&gt;. 核心代码如下:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;# 在 arxiv_tool.py 文件中

&quot;&quot;&quot;Tool for the Arxiv API.&quot;&quot;&quot;
from typing import Optional, Type
from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import BaseTool
from langchain_community.utilities.arxiv import ArxivAPIWrapper

class ArxivInput(BaseModel):
    &quot;&quot;&quot;Input for the Arxiv tool.&quot;&quot;&quot;
    query: str = Field(description=&quot;search query to look up&quot;) # 用于指定 args_schema

class ArxivQueryRun(BaseTool):
    &quot;&quot;&quot;Tool that searches the Arxiv API.&quot;&quot;&quot;
    name: str = &quot;arxiv&quot;
    description: str = (
        &quot;A wrapper around Arxiv.org &quot;
        &quot;Useful for when you need to answer questions about Physics, Mathematics, &quot;
        &quot;Computer Science, Quantitative Biology, Quantitative Finance, Statistics, &quot;
        &quot;Electrical Engineering, and Economics &quot;
        &quot;from scientific articles on arxiv.org. &quot;
        &quot;Input should be a search query.&quot;
    )
    api_wrapper: ArxivAPIWrapper = Field(default_factory=ArxivAPIWrapper) # 这里做了一个 Wrapper
    args_schema: Type[BaseModel] = ArxivInput

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -&amp;gt; str:
        &quot;&quot;&quot;Use the Arxiv tool.&quot;&quot;&quot;
        return self.api_wrapper.run(query) # 调用的是 Wrapper 的 run

# ===================================
# 在 utilities_tool.py 中

&quot;&quot;&quot;Util that calls Arxiv.&quot;&quot;&quot;
&quot;&quot;&quot;Util that calls Arxiv.&quot;&quot;&quot;
import logging
import os
import re
from typing import Any, Dict, Iterator, List, Optional

from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel, root_validator

logger = logging.getLogger(__name__)


class ArxivAPIWrapper(BaseModel):
    # 其他函数省略, 就是一些 import 导入检测之类的.
    def run(self, query: str) -&amp;gt; str:

        try:
            if self.is_arxiv_identifier(query):
                results = self.arxiv_search(
                    id_list=query.split(),
                    max_results=self.top_k_results,
                ).results()
            else:
                results = self.arxiv_search(  # type: ignore
                    query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results
                ).results()
        except self.arxiv_exceptions as ex:
            return f&quot;Arxiv exception: {ex}&quot;
        docs = [
            f&quot;Published: {result.updated.date()}\n&quot;
            f&quot;Title: {result.title}\n&quot;
            f&quot;Authors: {&apos;, &apos;.join(a.name for a in result.authors)}\n&quot;
            f&quot;Summary: {result.summary}&quot;
            for result in results
        ]
        if docs:
            return &quot;\n\n&quot;.join(docs)[: self.doc_content_chars_max]
        else:
            return &quot;No good Arxiv Result was found&quot;

&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;StructuredTool dataclass&lt;/h4&gt;
&lt;p&gt;StructuredTool  方法将上述 &quot;直接定义方法&quot; 与 &quot;SubClass&quot; 的方法结合起来. 阅读&lt;a href=&quot;https://github.com/langchain-ai/langchain/blob/b53548dcda8fe1dd820f7db31db6b1f3bff6c360/libs/core/langchain_core/tools.py#L702&quot;&gt;其源代码&lt;/a&gt;, StructuredTool 内部也是继承的 Base tool.&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;
class CalculatorInput(BaseModel):
    a: int = Field(description=&quot;first number&quot;)
    b: int = Field(description=&quot;second number&quot;)

def multiply(a: int, b: int) -&amp;gt; int:
    &quot;&quot;&quot;Multiply two numbers.&quot;&quot;&quot;
    return a * b


calculator = StructuredTool.from_function(
    func=multiply,
    name=&quot;Calculator&quot;,
    description=&quot;multiply numbers&quot;,
    args_schema=CalculatorInput,
    return_direct=True,
    # coroutine= ... &amp;lt;- you can specify an async method if desired as well
)

print(calculator.name)
print(calculator.description)
print(calculator.args)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;
Calculator
Calculator(a: int, b: int) -&amp;gt; int - multiply numbers
{&apos;a&apos;: {&apos;title&apos;: &apos;A&apos;, &apos;description&apos;: &apos;first number&apos;, &apos;type&apos;: &apos;integer&apos;}, &apos;b&apos;: {&apos;title&apos;: &apos;B&apos;, &apos;description&apos;: &apos;second number&apos;, &apos;type&apos;: &apos;integer&apos;}}

&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;Handling Tool Errors&lt;/h4&gt;
&lt;pre&gt;&lt;code&gt;from langchain_core.tools import ToolException

def search_tool1(s: str):
    raise ToolException(&quot;The search tool1 is not available.&quot;)

search = StructuredTool.from_function(
    func=search_tool1,
    name=&quot;Search_tool1&quot;,
    description=&quot;A bad tool&quot;,
)

search.run(&quot;test&quot;)

&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/05/13/u2lFsqzKwxGUrnX.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;p&gt;设置 handle_tool_error = True, 可以将 ToolException 的字符串输出:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;search = StructuredTool.from_function(
    func=search_tool1,
    name=&quot;Search_tool1&quot;,
    description=&quot;A bad tool&quot;,
    handle_tool_error=True,
)

search.run(&quot;test&quot;) # 输出 &apos;The search tool1 is not available.&apos;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;也可以将 handle_tool_error 设置为一个函数, 这个函数必须接受一个 &lt;code&gt;ToolException&lt;/code&gt;, 然后给一个字符串输出 &lt;code&gt;str&lt;/code&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;def _handle_error(error: ToolException) -&amp;gt; str:
    return (
        &quot;The following errors occurred during tool execution:&quot;
        + error.args[0]
        + &quot;Please try another tool.&quot;
    )

search = StructuredTool.from_function(
    func=search_tool1,
    name=&quot;Search_tool1&quot;,
    description=&quot;A bad tool&quot;,
    handle_tool_error=_handle_error,
)

search.run(&quot;test&quot;)
# 输出 : &apos;The following errors occurred during tool execution:The search tool1 is not available.Please try another tool.&apos;

&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;4. Agent&lt;/h2&gt;
&lt;h2&gt;Reference&lt;/h2&gt;
</content:encoded></item><item><title>L1 and L2 Regularization</title><link>https://xuchenhui.cc/posts/2024-04-20-l1-and-l2-regularization/</link><guid isPermaLink="true">https://xuchenhui.cc/posts/2024-04-20-l1-and-l2-regularization/</guid><description>从多个角度探讨 L1 和 L2 正则化的原理，解释其为何能有效防止模型过拟合，涵盖公式推导、几何解释和贝叶斯视角。</description><pubDate>Sat, 20 Apr 2024 00:00:00 GMT</pubDate><content:encoded>&lt;h2&gt;0. 前言&lt;/h2&gt;
&lt;p&gt;在机器学习或深度学习中，无论是分类、回归还是其他场景，通常都是利用模型去拟合一个函数。在这个过程中，正则化是一种常用的手段，用来防止过拟合。本篇博客主要从几个角度探讨正则化的理解，并解释它为何能够防止过拟合。&lt;/p&gt;
&lt;p&gt;:::note
阅读前, 需要你 : 有高数基础知识, 线代基础知识, 统计学习基础知识, 当然还要有 ML和 DL 的知识背景.
:::&lt;/p&gt;
&lt;h2&gt;1. 公式&lt;/h2&gt;
&lt;p&gt;给定输入 $x_1,x_2...x_n$ 和输出 $y_1,y_2...y_n$，我们通过一个模型 $f(w,x)$ 来映射输入输出之间的关系，其中 $w$ 表示模型参数。参数的求解通过优化以下损失函数：&lt;/p&gt;
&lt;p&gt;$$
L = \sum_{i} L(x_i,y_i)  + R(w)
$$&lt;/p&gt;
&lt;p&gt;这里 $R(w)$ 是关于参数 $w$ 的一个函数.对于 L1 Regularization&lt;/p&gt;
&lt;p&gt;$$
R(w) = \lambda {|w|_1}^2
$$&lt;/p&gt;
&lt;p&gt;对于 L2 Regularization&lt;/p&gt;
&lt;p&gt;$$
R(w) = \lambda {|w|_2}^2
$$&lt;/p&gt;
&lt;h2&gt;2. 理解&lt;/h2&gt;
&lt;blockquote&gt;
&lt;p&gt;从式子上看, Regularization 看起来就是想让参数 $w$ 的范数小一点 , 下面来看为什么 $w$ 的范数小一点, 就能减缓过拟合.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;首先我们来看过拟合是什么? 定义这里就不说了, 直观看个图吧.&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/20/csCq1bnfWRQ7mg4.png&quot; alt=&quot;image.png&quot; width=&quot;300&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;上图中,我们有蓝色和红色,2组类别的数据点, 想训练一个分类器f(w,x)去将蓝色点和红色点分开.&lt;/p&gt;
&lt;p&gt;可以看到, 绿色的线($f_1$)近乎完美的对数据进行了拟合, 黑色($f_2$)的看起来差一些.&lt;/p&gt;
&lt;p&gt;:::note
但是啊, 我是说有没有一种可能, 这个数据集他有异常点(比如加粗的那几个), 如果你拟合的太好, 反而会把噪声也拟合了, 导致你的模型泛化性能不好. 反观黑色的线, 就看起来更加不错.
:::&lt;/p&gt;
&lt;p&gt;那么如何才能让模型从绿色变成黑色的线呢? 即怎么把函数的&quot;弯弯绕绕&quot;给他拿走.&lt;/p&gt;
&lt;p&gt;我们对函数 $f(x)$ 在某个点进行泰勒展开:&lt;/p&gt;
&lt;p&gt;$$
f(w,x) = f(w,a) + f&apos;(w,a)(x - a) + \frac{f&apos;&apos;(w,a)}{2!}(x - a)^2 + \cdots
$$&lt;/p&gt;
&lt;p&gt;可以看到, 一个函数的复杂度(就是&quot;弯弯绕绕&quot;), 其实来自于它的高阶项 $ f^n(w,a)(x - a)^n$ . 比如 二次函数就1个弯, 三次函数就2个弯了, 同理次幂越高,&quot;弯弯绕绕&quot;越多. 因此想把高阶项拿掉, 其实可以让其系数 : $f^n(w,a) -&amp;gt; 0$ , 而系数正好就是 $w$ 的函数.&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;我们有理由相信,如果 $w$ 不是很大的情况下, $f(w)^n$ 应该不会大到哪里去.于是就把 $w$ 的范数加到loss中, 去让 $w$ 小一点.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h2&gt;3. 等价形式&lt;/h2&gt;
&lt;h3&gt;3.1 给权重 $w$ 加约束&lt;/h3&gt;
&lt;blockquote&gt;
&lt;p&gt;让 $w$ 小一点等价于让 $w$ 不太大 - 鲁迅&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;所以优化目标可以变为:&lt;/p&gt;
&lt;p&gt;$$
minimize \ L(w,x) , \ s.t. {|w|_2}^2 \leq C
$$&lt;/p&gt;
&lt;p&gt;使用拉格朗日乘数法, 上述问题变为:&lt;/p&gt;
&lt;p&gt;$$
\mathop{minimize}\limits_{w} \  \mathop{maximize}\limits_{\lambda} \ L(w,\lambda,x) = L(w) + \lambda ( {|w|_2}^2 -  C)
$$&lt;/p&gt;
&lt;p&gt;剩下过程就是,求导等于0, 然后计算相应的 $w$ 和 $\lambda$ 即可. 不过这里想说的是, 在对 $w$ 求导的时候, 你会发现其实并没有 $C$ 的事情 :&lt;/p&gt;
&lt;p&gt;$$
\frac{\partial J}{\partial w} = \frac{\partial L}{\partial w} + 2 * \lambda w
$$&lt;/p&gt;
&lt;p&gt;于是不妨直接 $minimize$ 下式:&lt;/p&gt;
&lt;p&gt;$$
minimize \ L(w,x) + \lambda {|w|_2}^2
$$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;1范数同理, 不再赘述.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h3&gt;3.2 让权重 $w$ 衰减&lt;/h3&gt;
&lt;p&gt;$$
minimize \ J = \ L(w,x) + \lambda {|w|_2}^2
$$&lt;/p&gt;
&lt;p&gt;梯度下降:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
w &amp;amp;= w - \eta ( \frac{\partial L}{\partial w} - 2 * \lambda w) \
&amp;amp;= (1 - 2 * \lambda *  \eta ) w - \frac{\partial L}{\partial w} \
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;当 $2 \lambda \eta \in (0,1)$ 时，每次更新权重都是在上一次权重衰减后的基础上进行的。&lt;/p&gt;
&lt;h3&gt;3.3 给权重 $w$ 限定分布&lt;/h3&gt;
&lt;p&gt;从统计学上来看, $f(w,x)$ 输出的是一个分布去拟合 y 的分布 , 使用贝叶斯公式:&lt;/p&gt;
&lt;p&gt;$$
p(w|x,y) = \frac{p(w) * p(x,y|w)}{p(x,y)}
$$&lt;/p&gt;
&lt;p&gt;$p(x,y)$ 是死的，$maximize \ p(w|x,y)$ 就是 $maximize$ 分子&lt;/p&gt;
&lt;p&gt;极大似然估计核心公式为:&lt;/p&gt;
&lt;p&gt;$$
\mathop{arg \ max}\limits_{w}\ p(w|x,y) = \mathop{arg \ max}\limits_{w} \ p(x,y|w)
$$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;极大似然估计不关心 w 的原始分布. 它的核心思想是，假设数据是由参数 w 生成的，那么反过来，能让根据这些数据计算出的 w 的条件分布, 最大的那个 w 就是我们要找的 w.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;最大后验估计核心公式为:&lt;/p&gt;
&lt;p&gt;$$
\mathop{arg \ max}\limits_{w} \ p(w|x,y) = \mathop{arg \ max}\limits_{w} \ p(x,y|w) * p(w)
$$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;最大后验估计对极大似然估计说: 老弟你这不对, 分子最大化的时候 , 你得考虑 p(w) .&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;OK , 基于最大后验估计, 取 log 得到:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
\mathop{arg \ max}\limits_{w} \ p(x,y|w) * p(w) &amp;amp;=  log \ p(x,y|w) +  log \ p(w)
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;我们不看前半部分,只看后半部分.&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;假设 $w \sim N(0 , \sigma ^ 2)$&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;$$
f(w) = \frac {1} {\sqrt {2 \pi \sigma}} exp(- \frac{w^2}{2 \sigma ^2})
$$&lt;/p&gt;
&lt;p&gt;则&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
maximize \   log \ p(w) &amp;amp;= \
&amp;amp;=  maximize - \frac {1} {2 \sigma ^2}  {|w|_2}^2 + C \
&amp;amp;= minimize  \ \frac {1} {2 \sigma ^2}  {|w|_2}^2 + C \
&amp;amp;\equiv minimize \  {|w|_2}^2 \ (\sigma = 1)
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;:::note
从这个角度可以看到, 如果加 L2 Regularization , 其实就是对 model 的权重参数 $w$ 假定了先验分布为&lt;strong&gt;标准正态分布&lt;/strong&gt;.
:::&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;假设 $w \sim Laplace(0 , b)$&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;$$
f(w) = \frac {1} {2b} exp(- \frac{|w|}{b})
$$&lt;/p&gt;
&lt;p&gt;则&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
maximize \   log \ p(w) &amp;amp;= \
&amp;amp;=  maximize - \frac {1} {2 b}  {|w|_1}^2 + C \
&amp;amp;= minimize  \frac {1} {2 b}  {|w|_1}^2 + C \
&amp;amp;&amp;lt;=&amp;gt; minimize  \   {|w|_1}^2   \ (b = 1)
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;:::note
从这个角度可以看到, 如果加 L1 Regularization, 其实就是对 model的权重参数 $w$ 假定了先验分布为&lt;strong&gt;拉普拉斯分布&lt;/strong&gt;.
:::&lt;/p&gt;
&lt;h2&gt;4. 区别&lt;/h2&gt;
&lt;h3&gt;4.1 函数性质&lt;/h3&gt;
&lt;p&gt;我们可以从标准正态分布和拉普拉斯分布的函数性质,来窥探L1 Regularization 和 L2 Regularization 的区别.&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/04/20/nhvpas6JESRMAUf.png&quot; alt=&quot;untitled.png&quot; /&gt;&lt;/p&gt;
&lt;p&gt;根据上图可以看到, L1 Regularization (拉普拉斯分布) 在 0 附近形状更尖锐, 将 w 推向0的时候更加强硬. 而  L2 Regularization (标准正态分布) 显得更加柔和.&lt;/p&gt;
&lt;h3&gt;4.2 几何性质&lt;/h3&gt;
&lt;p&gt;此外也可以从几何性质上对 L1 Regularization 和 L2 Regularization 进行分析.&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://miro.medium.com/v2/resize:fit:1600/format:webp/1*_e8BLNA749W_7yxi7hz-DA.gif&quot; alt=&quot;image.gif&quot; /&gt;&lt;/p&gt;
&lt;p&gt;1范数在几何上表现为一个高维的四方体,2范数则是一个高维的球体. 可以从上图看到,在做minimize时候,L1 Regularization 的 &quot;尖儿&quot; 更容易触到靠内的等高线,即 &quot;尖儿&quot;的位置具有更低的值, 而 &quot;尖儿&quot;的位置,就意味着 w 的某个分量就是0. 而2范数因为整个表面都是外凸出的弧,在哪个地方都有可能取得最小值.&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://miro.medium.com/v2/resize:fit:1400/format:webp/1*GdOo-X5Mq2CYLzci6reoZw.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;p&gt;这也就是为什么说, L1 Regularization 能够比 L2 Regularization 更加的 &quot;Sparsity&quot;.所以 L1 正则项的另外一个应用就是能够进行特征选择: &lt;a href=&quot;https://en.wikipedia.org/wiki/Lasso_(statistics)&quot;&gt;LASSO回归&lt;/a&gt;通过在原始损失函数上添加 L1 Regularization,导致特征 $i$ 对应的权重 $w_i$ 为 0, 我们认为, 权重 $w_i=0$ 的特征就是可以去除的.&lt;/p&gt;
&lt;h2&gt;Reference&lt;/h2&gt;
&lt;p&gt;[1] &lt;a href=&quot;https://satishkumarmoparthi.medium.com/why-l1-norm-creates-sparsity-compared-with-l2-norm-3c6fa9c607f4&quot;&gt;Why L1 norm creates Sparsity compared with L2 norm&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;[2] &lt;a href=&quot;https://en.wikipedia.org/wiki/Regularization_(mathematics)&quot;&gt;Regularization Wiki&lt;/a&gt;&lt;/p&gt;
</content:encoded></item><item><title>AdmaW(part I) Weight Decay == L2 Regularization?</title><link>https://xuchenhui.cc/posts/2024-04-20-weight-decay-and-l2-regularization/</link><guid isPermaLink="true">https://xuchenhui.cc/posts/2024-04-20-weight-decay-and-l2-regularization/</guid><description>探讨 SGD 与 Adam 优化器下 Weight Decay 和 L2 正则化的等价性差异，引入 AdamW 优化器的设计动机与原理。</description><pubDate>Sat, 20 Apr 2024 00:00:00 GMT</pubDate><content:encoded>&lt;h2&gt;0. 前言&lt;/h2&gt;
&lt;p&gt;在 &lt;a href=&quot;https://chenhui-x.github.io/posts/L1-and-L2-Regularization/#%E8%AE%A9%E6%9D%83%E9%87%8D-w-%E8%A1%B0%E5%87%8F&quot;&gt;上一篇 Blog&lt;/a&gt; 中探讨了 L1 Regularization 和 L2 Regularization. 我们说到: 对损失函数添加 L2 Regularization , 最后对 w 使用梯度下降的时候, 实际是对 w 做了权重衰减.&lt;/p&gt;
&lt;p&gt;然而, 上述等价性只在优化器为随机梯度下降（SGD）时成立(下边我们会证明). 在其他情况下, 特别是在训练深度学习模型时, 经常使用&lt;a href=&quot;https://arxiv.org/abs/1412.6980&quot;&gt;Adam&lt;/a&gt;优化器 , 上述结论不成立.&lt;/p&gt;
&lt;p&gt;本篇 Blog 主要探讨在使用 Adam 的时候 Weight Decay 和 L2 Regularization 的关系, 以及当更新参数引入 momentum之后他们之间的关系 , 最后介绍 AdamW 优化器. 文中符号都尽量与 &lt;a href=&quot;https://arxiv.org/abs/1711.05101&quot;&gt;AdamW paper&lt;/a&gt; 中的一致.&lt;/p&gt;
&lt;p&gt;:::note
阅读前, 需要你 : 有高数基础知识, 线代基础知识, 当然还要有 ML和 DL 的知识背景.
:::&lt;/p&gt;
&lt;h2&gt;1. SGD场景下&lt;/h2&gt;
&lt;h3&gt;1.1 无 momentum&lt;/h3&gt;
&lt;p&gt;weight decay 的公式:&lt;/p&gt;
&lt;p&gt;$$
\theta_{t+1} = (1 - \lambda ) \theta_{t} - \alpha \nabla f_t(\theta_{t})
$$&lt;/p&gt;
&lt;p&gt;这里 $\alpha$ 是学习率 , $\lambda$ 是 weight decay 的系数. 如果对损失函数施加 L2 Regularization :&lt;/p&gt;
&lt;p&gt;$$
f_t^{reg}(\theta) =   f_t(\theta) + \frac {\lambda &apos;} {2} {|\theta|_2}^2
$$&lt;/p&gt;
&lt;p&gt;使用梯度下降:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
\theta_{t+1} &amp;amp;=   \theta_{t}  - \alpha \nabla f_t^{reg}(\theta_{t}) \
&amp;amp;=   \theta_{t}  - \alpha \nabla f_t(\theta_{t}) - \alpha \lambda &apos; \theta_{t}\
&amp;amp;= (1 - \alpha \lambda &apos; )  \theta_{t}  - \alpha \nabla f_t(\theta_{t})
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;如果想让 weight decay 和 带L2 Regularization 等价 , 则应有
$\alpha \lambda&apos; = \lambda$
, 显然对于SGD我们可以做到这个事情. 也就是说 &lt;strong&gt;在SGD优化器下, weight decay 和 带L2 Regularization 等价.&lt;/strong&gt; 不过有个问题, 假设我们存在一个最优的weight decay系数 $\lambda$ , 并且置了 L2 的系数
$\lambda&apos;$
, 这样就会把系统的学习率给固定了. 换句话说, 这时 weight decay 的系数 和 L2 Regularization 的系数是耦合的. 二者会相互影响.&lt;/p&gt;
&lt;h3&gt;1.2 添加 momentum&lt;/h3&gt;
&lt;p&gt;如果在 L2 Regularization 的基础上添加 momentum 项&lt;/p&gt;
&lt;p&gt;$$
g_t = \nabla f_{t-1}(\theta_{t-1}) + \lambda &apos; \theta_{t-1}
$$&lt;/p&gt;
&lt;p&gt;$$
m_t = \beta_{1}m_{t-1} + g_t
$$&lt;/p&gt;
&lt;p&gt;SGD with momentum and weight decay (L2 Regularization) 式子将会变为:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
\theta_{t} &amp;amp;=   \theta_{t-1}  - \alpha  m_t \
&amp;amp;=   \theta_{t-1}  -  \alpha (\beta_{1}m_{t-1} -  \nabla f_{t-1}(\theta_{t-1}) - \lambda &apos; \theta_{t-1}) \
&amp;amp;= \underbrace{(1 - \alpha \lambda &apos; )  \theta_{t-1}}&lt;em&gt;{weight \ decay}  - \underbrace{\alpha \nabla f&lt;/em&gt;{t-1}(\theta_{t-1})}&lt;em&gt;{gradient \ descent} -  \underbrace{\alpha \beta&lt;/em&gt;{1}m_{t-1}}_{momentum}
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;这里, 学习率 $\alpha$ 和 L2 Regularization 的系数还是耦合, 并且还和 momentum 的系数也耦合上了.&lt;/p&gt;
&lt;p&gt;:::warning
耦合归耦合, 但是该说不说, 在SGD场景下, Weight Decay == L2 Regularization 是可以成立的. 无论加不加 momentum
:::&lt;/p&gt;
&lt;h2&gt;2. Adam场景下&lt;/h2&gt;
&lt;p&gt;这里就不敲公式了,给出 &lt;a href=&quot;https://arxiv.org/abs/1711.05101&quot;&gt;AdamW paper&lt;/a&gt; 附录的证明.&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/04/21/afDMybYdESVpQoB.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;p&gt;我们知道, 在 Adam 优化器中, 学习率是自适应变化的, 上图中 $M_t$ 就表示给学习率乘的自适应系数矩阵. 要想&lt;/p&gt;
&lt;p&gt;$$
\lambda \theta_{t}  = \alpha \lambda &apos; M_t \theta_{t}
$$&lt;/p&gt;
&lt;p&gt;就必须让&lt;/p&gt;
&lt;p&gt;$$
\lambda   = \alpha \lambda &apos; M_t
$$&lt;/p&gt;
&lt;p&gt;其中 $\lambda \ , \alpha \ ,\lambda&apos; $ 三兄弟都是常数, $M_t$  又是自适应系数, 显然是不能实现上边的目标的,&lt;/p&gt;
&lt;p&gt;:::warning
因此对于类似 Adam 这种自适应学习率的算法,  Weight Decay $\neq$ L2 Regularization . 无论加不加 momentum
:::&lt;/p&gt;
&lt;h2&gt;Reference&lt;/h2&gt;
</content:encoded></item><item><title>Kullback–Leibler divergence</title><link>https://xuchenhui.cc/posts/2024-04-19-kullback-leibler-divergence/</link><guid isPermaLink="true">https://xuchenhui.cc/posts/2024-04-19-kullback-leibler-divergence/</guid><description>介绍 KL 散度（Kullback-Leibler divergence）的定义、离散与连续版本的公式，以及其在衡量两个概率分布差异性中的核心作用。</description><pubDate>Fri, 19 Apr 2024 00:00:00 GMT</pubDate><content:encoded>&lt;h2&gt;0. 前言&lt;/h2&gt;
&lt;p&gt;我们有2个分布 $P$ 和 $Q$, 如何比较二者之间的差异性? 在数理统计上, K-L 散度是一个常用的方法.&lt;/p&gt;
&lt;h2&gt;1. 定义&lt;/h2&gt;
&lt;h3&gt;1.1 离散版本&lt;/h3&gt;
&lt;p&gt;For discrete probability distributions $P$ and $Q$ defined on the same sample space $\mathcal {X}$ .&lt;/p&gt;
&lt;p&gt;$$
D_{KL}(P \ ||\  Q)  = \sum_{x \in \mathcal {X}} P(x) \ log(\frac{P(x)} {Q(x)})
$$&lt;/p&gt;
&lt;p&gt;which is equivalent to&lt;/p&gt;
&lt;p&gt;$$
D_{KL}(P \ ||\  Q)  = - \ \sum_{x \in \mathcal {X}} P(x) \ log(\frac{Q(x)} {P(x)})
$$&lt;/p&gt;
&lt;h3&gt;1.2 连续版本&lt;/h3&gt;
&lt;p&gt;$$
D_{KL}(P \ ||\  Q)  = \int_{x \in \mathcal {X}} p(x) \ log(\frac{p(x)} {q(x)}) \ dx
$$&lt;/p&gt;
&lt;h2&gt;2. 理解&lt;/h2&gt;
&lt;h3&gt;2.1 公式上&lt;/h3&gt;
&lt;p&gt;有人会把KL散度理解为一种&quot;距离&quot;,不过&quot;距离&quot;需要满足以下几个性质&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;非负性 : 满足&lt;/li&gt;
&lt;/ul&gt;
&lt;blockquote&gt;
&lt;p&gt;证明&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;$$
\begin{align*}
D_{KL}(P \ ||\  Q)  &amp;amp;= - \ \sum_{x \in \mathcal {X}} P(x) \ log(\frac{Q(x)} {P(x)}) \
&amp;amp;&amp;gt;= - \ \sum_{x \in \mathcal {X}} log(P(x) \  * \ \frac{Q(x)} {P(x)})  \  (凸函数:E(f(x)) &amp;gt;= f(E(x)))\
&amp;amp;= - \ \sum_{x \in \mathcal {X}} log(Q(x)) \
&amp;amp;&amp;gt;= - \ log (\sum_{x \in \mathcal {X}}  Q(x) ) \  (凸函数:Jensen不等式)\
&amp;amp;= 0
\end{align*}
$$&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;同一性 : 满足&lt;/li&gt;
&lt;/ul&gt;
&lt;blockquote&gt;
&lt;p&gt;证明&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;$$
D_{KL}(P \ ||\  P)  =  \ \sum_{x \in \mathcal {X}} P(x) \ log(\frac{P(x)} {P(x)}) = 0
$$&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;对称性 : 不满足&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;$$
\begin{align*}
D_{KL}(P \ ||\  Q)  &amp;amp;=  \ \sum_{x \in \mathcal {X}} P(x) \ log(\frac{P(x)} {Q(x)}) \
&amp;amp;\neq\
D_{KL}(Q \ ||\  P)  &amp;amp;=  \ \sum_{x \in \mathcal {X}} Q(x) \ log(\frac{Q(x)} {P(x)}) \
\end{align*}
$$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;一般来说不等, 所以对称性不满足&lt;/p&gt;
&lt;/blockquote&gt;
&lt;ul&gt;
&lt;li&gt;三角不等式 : 不满足&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;假设有 &lt;code&gt;P Q R&lt;/code&gt; 三个分布, 探究
$D_{KL}(P \ ||\  R)$ 与 $D_{KL}(P \ ||\  Q)$ 、$D_{KL}(Q \ ||\  R)$ 的关系。&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
D_{KL}(P \ ||\  R)  &amp;amp;=  \ \sum_{x \in \mathcal {X}} P(x) \ log(\frac{P(x)} {R(x)}) \
&amp;amp;=  \ \sum_{x \in \mathcal {X}} P(x) \ log(\frac{P(x)} {R(x)} * \frac{Q(x)} {Q(x)} ) \
&amp;amp;=  \ \sum_{x \in \mathcal {X}} P(x) \ log(\frac{P(x)} {Q(x)} * \frac{Q(x)} {R(x)} ) \
&amp;amp;=  D_{KL}(P \ ||\  Q) + \ \sum_{x \in \mathcal {X}} P(x) \ log(\frac{Q(x)} {R(x)} ) \
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;那就需要看后者&lt;/p&gt;
&lt;p&gt;$$
\sum_{x \in \mathcal {X}} P(x) \ log(\frac{Q(x)} {R(x)} )
$$&lt;/p&gt;
&lt;p&gt;与&lt;/p&gt;
&lt;p&gt;$$
D_{KL}(Q \ ||\  R)
$$&lt;/p&gt;
&lt;p&gt;之间的大小关系, 但是很遗憾, 二者大小无法判定. 因此有可能出现以下情况, 所以三角不等式不满足.&lt;/p&gt;
&lt;p&gt;$$
D_{KL}(P \ ||\  R) &amp;gt; D_{KL}(P \ ||\  Q) + D_{KL}(Q \ ||\  R)
$$&lt;/p&gt;
&lt;p&gt;:::note
因此称其为“距离”是不合适的， 充其量只能说其可以度量两个分布之间的差异性。
:::&lt;/p&gt;
&lt;h3&gt;2.2 从熵的角度&lt;/h3&gt;
&lt;p&gt;&quot;熵&quot;通常指的是&lt;a href=&quot;https://zh.wikipedia.org/zh-hans/%E7%86%B5_(%E4%BF%A1%E6%81%AF%E8%AE%BA)&quot;&gt;香农熵(Shannon entropy)&lt;/a&gt;. 原来是信息论里边的东西, 其公式大家很熟悉:&lt;/p&gt;
&lt;p&gt;$$
H(X) = \sum_{x \in \mathcal {X}} P(x) \ log\ \frac{1} {P(x)}
$$&lt;/p&gt;
&lt;p&gt;:::note
如果你不知道这个公式为什么是这样, 而不是那样, 可以看我在 B站 发的视频 : &lt;a href=&quot;https://www.bilibili.com/video/BV1kg411u7RP&quot;&gt;https://www.bilibili.com/video/BV1kg411u7RP&lt;/a&gt;
:::&lt;/p&gt;
&lt;p&gt;然后来看以下推导:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
I(X,Y) &amp;amp;= D_{KL}(p(x,y) \ ||\  p(x)\ \times \ p(y))  (定义)\
&amp;amp;=  \ \sum_{x , y} p_{xy} \ log(\frac{p_{xy}} {p_x \times p_y})  \
&amp;amp; = \ \sum_{x , y} p_{xy} \ log(\frac{p_{xy}} {p_x}) - \ \sum_{x , y} p_{xy} \ log \ p_{y} \
&amp;amp; = \ \sum_{x , y} p_{xy} \ log \ p(y|x) - \ \sum_{y}  \ (\sum_{x} p_{xy}) \ log \ p_{y} \
&amp;amp; = \ \sum_{x , y} p_{x}p_{y|x} \ log \ p(y|x) - \ \sum_{y}  \ p_{y} \ log \ p_{y} \
&amp;amp; = \ \sum_{x} p_{x} \ \sum_{y } p_{y|x} \ log \ p(y|x) - \ \sum_{y}  \ p_{y} \ log \ p_{y} \
&amp;amp; = - \ \sum_{x } p_{x} \ H(y|X=x) - \ \sum_{y}  \ p_{y} \ log \ p_{y} \
&amp;amp; = -H(Y|X) +  H(Y) \
&amp;amp; = H(Y) - H(Y|X)  \
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;其中, $I(X,Y)$ 称为随机变量 X 与 Y 的互信息量,
$H(Y\ |\ X)$
表示在已知随机变量X的情况下,Y的熵,也称条件熵.&lt;/p&gt;
&lt;p&gt;假设 $P = p(x,y), Q =  p(x) * p(y)$
我们从 $I(X,Y) = D_{KL}(P\ || \ Q) = H(Y) - H(Y|X)$
的角度来看, 当
$I(X,Y) = 0$
就是想说
$H(Y) - H(Y|X) = 0$&lt;/p&gt;
&lt;p&gt;:::note
表明已知信息 X , 仍然有 $H(Y|X) = H(Y)$, 即 X 对 Y 的熵降低无任何作用 &amp;lt;=&amp;gt; X 和 Y 独立 &amp;lt;=&amp;gt; P = Q
:::&lt;/p&gt;
&lt;p&gt;这样看来, $minimize \ KL(P,Q)$ 就是想让二者从&lt;code&gt;信息量上&lt;/code&gt;尽量的相近&lt;/p&gt;
&lt;h2&gt;3. 应用&lt;/h2&gt;
&lt;p&gt;$$
\begin{align*}
D_{KL}(P \ ||\  Q)  &amp;amp;= \sum_{x \in \mathcal {X}} P(x) \ log \ \frac{P(x)} {Q(x)} \
&amp;amp;= \sum_{x \in \mathcal {X}} P(x) \ log \ P(x) - \sum_{x \in \mathcal {X}} P(x) \ log \ Q(x) \
&amp;amp;=  - H(P) + \sum_{x \in \mathcal {X}} P(x) \ log \ \frac{1}{Q(x)} \
&amp;amp;=  - H(P) + H(P,Q)
\end{align*}
$$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;这里 H(P,Q) 称为 分布 P 和 分布 Q 的交叉熵, 通常我们在机器学习或者深度学习中, 可以把 P 分布理解为真实的概率分布(未知,但是固定) , 因此 - H(P) 就是个常数 ; Q 为我们模型输出的概率分布, 所以可以通过
$minimize \ H(P,Q)$
去等价
$minimize \ D_{KL}(P \ ||\  Q) $
. 即 交叉熵 损失函数.
:::note
:::&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h2&gt;4. 补充&lt;/h2&gt;
&lt;p&gt;KL divergence between two multivariate Gaussian distributions.&lt;/p&gt;
&lt;p&gt;Probabilty density function of multivariate Normal distribution is given by:&lt;/p&gt;
&lt;p&gt;$$
p(\mathbf{x}) = \frac{1}{(2\pi)^{k/2}|\Sigma|^{1/2}} \exp\left(-\frac{1}{2}(\mathbf{x}-\boldsymbol{\mu})^T\Sigma^{-1}(\mathbf{x}-\boldsymbol{\mu})\right)
$$&lt;/p&gt;
&lt;p&gt;假设2个分布分别为 $\mathcal{N}(\boldsymbol{\mu_p},,\Sigma_p) $ 和 $\mathcal{N}(\boldsymbol{\mu_q},,\Sigma_q)$ , 其中 $\mu$ 为 $k$ 维 列向量.&lt;/p&gt;
&lt;p&gt;$$
\begin{aligned}
D_{KL}(p||q) &amp;amp; = \mathbb{E}_p\left[\log(p) - \log(q)\right]
\newline
&amp;amp; = \mathbb{E}_p\left[\frac{1}{2}\log\frac{|\Sigma_q|}{|\Sigma_p|} - \frac{1}{2}(\mathbf{x}-\boldsymbol{\mu_p})^T\Sigma_p^{-1}(\mathbf{x}-\boldsymbol{\mu_p}) + \frac{1}{2}(\mathbf{x}-\boldsymbol{\mu_q})^T\Sigma_q^{-1}(\mathbf{x}-\boldsymbol{\mu_q})\right]
\newline
&amp;amp; = \frac{1}{2}\mathbb{E}_p\left[\log\frac{|\Sigma_q|}{|\Sigma_p|}\right] - \frac{1}{2}\mathbb{E}_p\left[(\mathbf{x}-\boldsymbol{\mu_p})^T\Sigma_p^{-1}(\mathbf{x}-\boldsymbol{\mu_p})\right] + \frac{1}{2}\mathbb{E}_p\left[(\mathbf{x}-\boldsymbol{\mu_q})^T\Sigma_q^{-1}(\mathbf{x}-\boldsymbol{\mu_q})\right]
\newline
&amp;amp; = \frac{1}{2}\log\frac{|\Sigma_q|}{|\Sigma_p|} - \frac{1}{2}\mathbb{E}_p\left[(\mathbf{x}-\boldsymbol{\mu_p})^T\Sigma_p^{-1}(\mathbf{x}-\boldsymbol{\mu_p})\right] + \frac{1}{2}\mathbb{E}_p\left[(\mathbf{x}-\boldsymbol{\mu_q})^T\Sigma_q^{-1}(\mathbf{x}-\boldsymbol{\mu_q})\right]
\end{aligned}
$$&lt;/p&gt;
&lt;p&gt;其中
$(\mathbf{x}-\boldsymbol{\mu_p})^T\Sigma_p^{-1}(\mathbf{x}-\boldsymbol{\mu_p})$
是一个实数. 所以可以重新写为 :&lt;/p&gt;
&lt;p&gt;$$
tr \left{(\mathbf{x}-\boldsymbol{\mu_p})^T\Sigma_p^{-1}(\mathbf{x}-\boldsymbol{\mu_p})\right}
$$&lt;/p&gt;
&lt;p&gt;其中 tr{} 表示  trace operator , 利用 trace trick (轮换性) , 可以将上式修改为:&lt;/p&gt;
&lt;p&gt;$$
tr \left{(\mathbf{x}-\boldsymbol{\mu_p})(\mathbf{x}-\boldsymbol{\mu_p})^T\Sigma_p^{-1}\right}
$$&lt;/p&gt;
&lt;p&gt;于是第2项可修改为:&lt;/p&gt;
&lt;p&gt;$$
\frac{1}{2}\mathbb{E}_p\left[tr\left{(\mathbf{x}-\boldsymbol{\mu_p})(\mathbf{x}-\boldsymbol{\mu_p})^T\Sigma_p^{-1}\right}\right]
$$&lt;/p&gt;
&lt;p&gt;然后, 将 expectation 和 trace 交换位置, 且 $\Sigma_p^{-1}$ 是常数矩阵:&lt;/p&gt;
&lt;p&gt;$$
\begin{aligned}
&amp;amp; = \frac{1}{2}tr\left{\mathbb{E}_p\left[(\mathbf{x}-\boldsymbol{\mu_p})(\mathbf{x}-\boldsymbol{\mu_p})^T\Sigma_p^{-1}\right]\right}
\newline
&amp;amp; = \frac{1}{2}tr\left{\mathbb{E}_p\left[(\mathbf{x}-\boldsymbol{\mu_p})(\mathbf{x}-\boldsymbol{\mu_p})^T\right]\Sigma_p^{-1}\right}
\newline
&amp;amp; = \frac{1}{2}tr\left{\Sigma_p\Sigma_p^{-1}\right}
\newline
&amp;amp; = \frac{1}{2}tr\left{I_k\right}
\newline
&amp;amp; = \frac{k}{2}
\end{aligned}
$$&lt;/p&gt;
&lt;p&gt;而第3项(证明在最后) :&lt;/p&gt;
&lt;p&gt;$$
\mathbb{E}_p\left[(\mathbf{x}-\boldsymbol{\mu_q})^T\Sigma_q^{-1}(\mathbf{x}-\boldsymbol{\mu_q})\right] = (\boldsymbol{\mu_p}-\boldsymbol{\mu_q})^T\Sigma_q^{-1}(\boldsymbol{\mu_p}-\boldsymbol{\mu_q}) + tr\left{\Sigma_q^{-1}\Sigma_p\right}
$$&lt;/p&gt;
&lt;p&gt;于是:&lt;/p&gt;
&lt;p&gt;$$
D_{KL}(p||q) = \frac{1}{2}\left[\log\frac{|\Sigma_q|}{|\Sigma_p|} - k + (\boldsymbol{\mu_p}-\boldsymbol{\mu_q})^T\Sigma_q^{-1}(\boldsymbol{\mu_p}-\boldsymbol{\mu_q}) + tr\left{\Sigma_q^{-1}\Sigma_p\right}\right]
$$&lt;/p&gt;
&lt;p&gt;当 $q \sim \mathcal{N}(0,,I)$ :&lt;/p&gt;
&lt;p&gt;$$
D_{KL}(p||q) = \frac{1}{2}\left[\boldsymbol{\mu_p}^T\boldsymbol{\mu_p} + tr\left{\Sigma_p\right} - k - \log|\Sigma_p|\right]
$$&lt;/p&gt;
&lt;hr /&gt;
&lt;p&gt;关于第3项的证明:&lt;/p&gt;
&lt;p&gt;$$
\begin{equation}\begin{aligned}
\mathbb{E}_{\boldsymbol{x}\sim p(\boldsymbol{x})}\left[(\boldsymbol{x}-\boldsymbol{\mu}_q)^{\top}\boldsymbol{\Sigma}_q^{-1}(\boldsymbol{x}-\boldsymbol{\mu}&lt;em&gt;q)\right]=&amp;amp;,\mathbb{E}&lt;/em&gt;{\boldsymbol{x}\sim p(\boldsymbol{x})}\left[\text{Tr}\left((\boldsymbol{x}-\boldsymbol{\mu}_q)^{\top}\boldsymbol{\Sigma}_q^{-1}(\boldsymbol{x}-\boldsymbol{\mu}&lt;em&gt;q)\right)\right]\
=&amp;amp;,\mathbb{E}&lt;/em&gt;{\boldsymbol{x}\sim p(\boldsymbol{x})}\left[\text{Tr}\left(\boldsymbol{\Sigma}_q^{-1}(\boldsymbol{x}-\boldsymbol{\mu}_q)(\boldsymbol{x}-\boldsymbol{\mu}_q)^{\top}\right)\right]\
=&amp;amp;,\text{Tr}\left(\boldsymbol{\Sigma}&lt;em&gt;q^{-1}\mathbb{E}&lt;/em&gt;{\boldsymbol{x}\sim p(\boldsymbol{x})}\left[(\boldsymbol{x}-\boldsymbol{\mu}_q)(\boldsymbol{x}-\boldsymbol{\mu}_q)^{\top}\right]\right)\
=&amp;amp;,\text{Tr}\left(\boldsymbol{\Sigma}&lt;em&gt;q^{-1}\mathbb{E}&lt;/em&gt;{\boldsymbol{x}\sim p(\boldsymbol{x})}\left[\boldsymbol{x}\boldsymbol{x}^{\top}-\boldsymbol{\mu}_q\boldsymbol{x}^{\top} - \boldsymbol{x}\boldsymbol{\mu}_q^{\top} +  \boldsymbol{\mu}_q\boldsymbol{\mu}_q^{\top}\right]\right)\
=&amp;amp;,\text{Tr}\left(\boldsymbol{\Sigma}_q^{-1}\left(\boldsymbol{\Sigma}_p + \boldsymbol{\mu}_p\boldsymbol{\mu}_p^{\top}-\boldsymbol{\mu}_q\boldsymbol{\mu}_p^{\top} - \boldsymbol{\mu}_p\boldsymbol{\mu}_q^{\top} +  \boldsymbol{\mu}_q\boldsymbol{\mu}_q^{\top}\right)\right)\
=&amp;amp;,\text{Tr}\left(\boldsymbol{\Sigma}_q^{-1}\boldsymbol{\Sigma}_p + \boldsymbol{\Sigma}_q^{-1}(\boldsymbol{\mu}_p-\boldsymbol{\mu}_q)(\boldsymbol{\mu}_p-\boldsymbol{\mu}_q)^{\top}\right)\
=&amp;amp;,\text{Tr}\left(\boldsymbol{\Sigma}_q^{-1}\boldsymbol{\Sigma}_p\right) + (\boldsymbol{\mu}_p-\boldsymbol{\mu}_q)^{\top}\boldsymbol{\Sigma}_q^{-1}(\boldsymbol{\mu}_p-\boldsymbol{\mu}_q)\
\end{aligned}\end{equation}
$$&lt;/p&gt;
&lt;p&gt;注意到当 $\boldsymbol{\mu}_q=\boldsymbol{\mu}_p,\boldsymbol{\Sigma}_q=\boldsymbol{\Sigma}_p$, 上式就是 $n$ , 对应正态分布的熵.&lt;/p&gt;
&lt;h2&gt;Reference&lt;/h2&gt;
&lt;p&gt;[1] &lt;a href=&quot;https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence&quot;&gt;https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;[2] &lt;a href=&quot;https://zh.wikipedia.org/zh-hans/%E4%BA%92%E4%BF%A1%E6%81%AF&quot;&gt;https://zh.wikipedia.org/zh-hans/%E4%BA%92%E4%BF%A1%E6%81%AF&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;[3] &lt;a href=&quot;https://en.wikipedia.org/wiki/Entropy_(information_theory)&quot;&gt;https://en.wikipedia.org/wiki/Entropy_(information_theory)&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;[4] &lt;a href=&quot;https://mr-easy.github.io/2020-04-16-kl-divergence-between-2-gaussian-distributions/&quot;&gt;https://mr-easy.github.io/2020-04-16-kl-divergence-between-2-gaussian-distributions/&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;[5] &lt;a href=&quot;https://kexue.fm/archives/8512&quot;&gt;https://kexue.fm/archives/8512&lt;/a&gt;&lt;/p&gt;
</content:encoded></item><item><title>Text Generator With Transformer Decoder</title><link>https://xuchenhui.cc/posts/2024-04-13-text-generator-with-transformer-decoder/</link><guid isPermaLink="true">https://xuchenhui.cc/posts/2024-04-13-text-generator-with-transformer-decoder/</guid><description>利用 Transformer Decoder 从零实现一个简单的文本生成器，涵盖数据构造、Mask Attention、Positional Encoding 等核心组件的代码实现。</description><pubDate>Sat, 13 Apr 2024 00:00:00 GMT</pubDate><content:encoded>&lt;h2&gt;0. 前言&lt;/h2&gt;
&lt;p&gt;这篇 Blog 主要聚焦于利用 Transformer 的 Decoder &lt;strong&gt;实现&lt;/strong&gt;一个简单的 text generator. 虽然代码相对简单, 但是核心思想类似, 做个记录, 方便后续学习理解. 主要参考 : &lt;a href=&quot;https://wingedsheep.com/building-a-language-model/&quot;&gt;https://wingedsheep.com/building-a-language-model/&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;:::note
阅读前, 需要你 : 了解 Transformer结构, Attention机制,  Mask Attention机制, 了解 Pytorch 以及 NLP 相关的基础知识, 比如 token, embedding, 序列之类的.
:::&lt;/p&gt;
&lt;h2&gt;1. 准备&lt;/h2&gt;
&lt;p&gt;首先需要搞明白真正在编程实现一个 &lt;code&gt;text generator&lt;/code&gt; 的时候, 代码核心是什么? 我们来列出任务的基本组成:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;数据构造
&lt;ul&gt;
&lt;li&gt;输入 : 一个序列&lt;/li&gt;
&lt;li&gt;输出 : 一个单词 或者 一个序列&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;模型实现
&lt;ul&gt;
&lt;li&gt;positional encoding&lt;/li&gt;
&lt;li&gt;token embedding&lt;/li&gt;
&lt;li&gt;decoder layer
&lt;ul&gt;
&lt;li&gt;mask attention (multi-head)&lt;/li&gt;
&lt;li&gt;mlp, layer normalization, active function...等常规组件&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;文本生成
&lt;ul&gt;
&lt;li&gt;让输出持续下去!&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;其中&lt;/p&gt;
&lt;p&gt;[1] &lt;code&gt;数据构造&lt;/code&gt; 步骤是需要针对我们的任务单独构造实现 (核心)&lt;/p&gt;
&lt;p&gt;[2] &lt;code&gt;mask attention&lt;/code&gt; 是 decoder 结构的核心, 这个没的说&lt;/p&gt;
&lt;p&gt;[3] 目标 : &lt;code&gt;文本生成&lt;/code&gt; , 核心没的说.&lt;/p&gt;
&lt;p&gt;其他的, positional encoding, token embedding 以及 mlp 等组件都是常规操作, 不足为虑. 因此我们将从上述3个方面来入手.&lt;/p&gt;
&lt;h2&gt;2. 数据构造&lt;/h2&gt;
&lt;h3&gt;2.1 让计算机处理文字&lt;/h3&gt;
&lt;p&gt;文本生成任务是说, 我们想让模型根据一个简单提示词, 然后接着提示词不断的写下去. 比如, 给模型输入: &quot;我爱您,&quot;, 那么模型也许能够输出: &quot;母亲, 感谢您的养育之恩.&quot;  最后我们将输入和输出连起来得到完整的语句: &quot;我爱您, 母亲, 感谢您的养育之恩.&quot;&lt;/p&gt;
&lt;p&gt;不过计算机没法处理汉字, 英文也不认识. 所以我们首先需要把英文啊, 中文转成数字.&lt;/p&gt;
&lt;p&gt;怎么转呢? 其实非常简单, 假设我们汉字有10w个, 我们就把每个汉字和一个数字一一对应即可. 比如 :&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/04/14/6nshFYQyPIBq8N1.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;p&gt;注意这里, 标点符号也是要进行转换. 好的, 现在我们可以把&lt;code&gt;汉字&lt;/code&gt;或者&lt;code&gt;单词&lt;/code&gt;输入到模型中了.&lt;/p&gt;
&lt;p&gt;不过汉字太多了, 为方便叙述, 后续我们使用英文来举例子. 本文我们就用到的 &lt;code&gt;a-z&lt;/code&gt; 26个字母 + &lt;code&gt;0-9&lt;/code&gt; 10个数字 + &lt;code&gt;&apos;\ &apos;&lt;/code&gt; + &lt;code&gt;&apos;,&apos;&lt;/code&gt; + &lt;code&gt;&apos;.&apos;&lt;/code&gt; + &lt;code&gt;&apos;&amp;lt;pad&amp;gt;&apos;&lt;/code&gt;共40个字符, 我们称之为 &apos;vocabulary&apos;. 其中, &lt;code&gt;&apos;&amp;lt;pad&amp;gt;&apos;&lt;/code&gt; 用于对句子进行填充, 使得训练的时候, 输入的句子一样长. 以下代码实现将这些字符映射到数字.&lt;/p&gt;
&lt;p&gt;&amp;lt;details markdown=&quot;1&quot;&amp;gt;
&amp;lt;summary&amp;gt; 详细信息 &amp;lt;/summary&amp;gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;
class Tokenizer:
    r&apos;&apos;&apos;
        0-9 (10 个 token) , a-z (26 个 token) , &apos; &apos; , &apos;,&apos; &apos;.&apos; ,  &apos;&amp;lt;pad&amp;gt;&apos; 共40个token
    &apos;&apos;&apos;
    def __init__(self):
        self.dictionary = {}
        self.reverse_dictionary = {}

        # Add the padding token
        self.__add_to_dict(&apos;&amp;lt;pad&amp;gt;&apos;)

        # Add characters and numbers to the dictionary
        for i in range(10):
            self.__add_to_dict(str(i))
        for i in range(26):
            self.__add_to_dict(chr(ord(&apos;a&apos;) + i))

        # Add space and punctuation to the dictionary
        self.__add_to_dict(&apos;,&apos;)
        self.__add_to_dict(&apos;.&apos;)
        self.__add_to_dict(&apos; &apos;)

    def __add_to_dict(self, character):
        if character not in self.dictionary:
            self.dictionary[character] = len(self.dictionary)
            self.reverse_dictionary[self.dictionary[character]] = character

    def tokenize(self, text):
        return [self.dictionary[c] for c in text]

    def character_to_token(self, character):
        return self.dictionary[character]

    def token_to_character(self, token):
        return self.reverse_dictionary[token]

    def size(self):
        return len(self.dictionary)

&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&amp;lt;/details&amp;gt;&lt;/p&gt;
&lt;h3&gt;2.2 输入输出构造&lt;/h3&gt;
&lt;p&gt;我们的训练集是一句话 : &quot;cats rule the world. dogs are the best. elephants have long trunks. monkeys like bananas. pandas eat bamboo. tigers are dangerous. zebras have stripes. lions are the kings of the savannah. giraffes have long necks. hippos are big and scary. rhinos have horns. penguins live in the arctic. polar bears are white&quot;&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;以下列出了几组输入和输出样例 (假设每句话最大token数目限制为3) :&lt;/p&gt;
&lt;p&gt;[1] &lt;code&gt;&apos;cat&apos; -&amp;gt; &apos;ats&apos;&lt;/code&gt;&lt;/p&gt;
&lt;p&gt;[2] &lt;code&gt;&apos;ats&apos; -&amp;gt; &apos;ts &apos;&lt;/code&gt;&lt;/p&gt;
&lt;p&gt;[3] &lt;code&gt;&apos;ts &apos; -&amp;gt; &apos;s r&apos;&lt;/code&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;使用如下代码将字符转换为数字.&lt;/p&gt;
&lt;p&gt;&amp;lt;details markdown=&quot;1&quot;&amp;gt;
&amp;lt;summary&amp;gt; 详细信息 &amp;lt;/summary&amp;gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;        # Create the training data
        training_data = &apos;. &apos;.join([
            &apos;cats rule the world&apos;,
            &apos;dogs are the best&apos;,
            &apos;elephants have long trunks&apos;,
            &apos;monkeys like bananas&apos;,
            &apos;pandas eat bamboo&apos;,
            &apos;tigers are dangerous&apos;,
            &apos;zebras have stripes&apos;,
            &apos;lions are the kings of the savannah&apos;,
            &apos;giraffes have long necks&apos;,
            &apos;hippos are big and scary&apos;,
            &apos;rhinos have horns&apos;,
            &apos;penguins live in the arctic&apos;,
            &apos;polar bears are white&apos;
        ])

        tokenized_and_padded_training_data = tokenize_and_pad_training_data(max_sequence_length, tokenizer, training_data)
        def tokenize_and_pad_training_data(max_sequence_length, tokenizer, training_data):
            # Tokenize the training data
            tokenized_training_data = tokenizer.tokenize(training_data)
            for _ in range(max_sequence_length):
                # Prepend padding tokens
                tokenized_training_data.insert(0, tokenizer.character_to_token(&apos;&amp;lt;pad&amp;gt;&apos;))
            return tokenized_training_data


&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&amp;lt;/details&amp;gt;&lt;/p&gt;
&lt;p&gt;字符转数字处理后的结果 &lt;code&gt;tokenized_and_padded_training_data&lt;/code&gt;&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/04/14/RNGpnr92LbqPW6U.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;p&gt;ok , 经过上边的映射, &lt;code&gt;a&lt;/code&gt;就是 11 , &lt;code&gt;b&lt;/code&gt; 就是 12, 假设我们想输入 &lt;code&gt;abc&lt;/code&gt;, 希望模型预测的输出是 &lt;code&gt;d&lt;/code&gt; . 那输入就是 &lt;code&gt;11 12 13&lt;/code&gt;,  输出就是 &lt;code&gt;14&lt;/code&gt;, 即 &lt;code&gt;11 12 13 -&amp;gt; 14&lt;/code&gt;.&lt;/p&gt;
&lt;p&gt;本篇 blog 使用 &lt;code&gt;11 12 13 -&amp;gt; 12 13 14&lt;/code&gt; 的输出格式, 本质是一样的. 这样保证了输入输出的序列长度是一样的. 此外代码中实现的时候, 我们假设序列的长度为 $20$ . 一句话不足 $20$ 个 token, 用&lt;code&gt;&amp;lt;pad&amp;gt;&lt;/code&gt;字符填充(就是用0填充). 举个例子:&lt;/p&gt;
&lt;p&gt;用  &lt;code&gt;1 2 3 -&amp;gt; 2 3 4&lt;/code&gt; , 经过填充后, 最后实际给模型的输入输出为 &lt;code&gt;0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 2 3 -&amp;gt; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 2 3 4 &lt;/code&gt;&lt;/p&gt;
&lt;p&gt;你可能注意到, 我们在整个句子前边添加了 20 个&lt;code&gt;&amp;lt;pad&amp;gt;&lt;/code&gt;字符. 这是为了下一步对输入输出构造, 对于第一组输入和输出.&lt;/p&gt;
&lt;p&gt;输入 &apos;空白符&apos; , 输出&apos;c&apos;:&lt;/p&gt;
&lt;p&gt;&lt;code&gt;0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 &lt;/code&gt; -&amp;gt; &lt;code&gt;0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 13 &lt;/code&gt;&lt;/p&gt;
&lt;p&gt;其他组的输入和输出, 其构造过程就是一个滑动窗口:&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/04/14/LJzI78GgbM5uYdP.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;p&gt;构造过程代码:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;        # ...
        sequences = create_training_sequences(max_sequence_length, tokenized_and_padded_training_data)
        def create_training_sequences(max_sequence_length, tokenized_training_data):
            sequences = []
            for i in range(0, len(tokenized_training_data) - max_sequence_length - 1):
                sequences.append(tokenized_training_data[i: i + max_sequence_length + 1])
            return sequences

&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;最后每一组的输入和输出构造结果&lt;code&gt;sequences&lt;/code&gt;如下:&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;这里为了表示方便, 我们把输入和输出放到一个list里边, 因为输入和输出就差一个字符&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/04/15/Wb1vRMw7xJ4LBPq.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;h2&gt;3. Mask Attention&lt;/h2&gt;
&lt;p&gt;假设输入是 &apos;cat&apos;, 经过填充就是 &lt;code&gt;0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 11, 30&lt;/code&gt;&lt;/p&gt;
&lt;p&gt;假设这个序列的Attention矩阵如下:&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/04/15/jJybvtZuLKr8QC9.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;p&gt;Attention不说了, 这里 Mask 主要有2个缘由:&lt;/p&gt;
&lt;p&gt;[1] 由于我们的输入数据中含有填充字符(这个过程叫padding), 而这些填充字符是没有实际意义的。因此，在进行注意力计算时，我们希望有效的token不会与这些填充字符attention, 因此需要对attention权重矩阵应用padding mask&lt;/p&gt;
&lt;p&gt;[2] 此外，由于我们的任务是文本生成，字符是按顺序一个接一个地生成的，从左到右逐步产生。因此，我们希望当前的token只能注意到其左侧的token, 不允许其注意到未来的token.&lt;/p&gt;
&lt;h3&gt;3.1 Padding Mask&lt;/h3&gt;
&lt;p&gt;Padding Mask 中, 1 表示当前输入允许注意的位置. 举个例子:&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/04/15/YQ8PRlKq2cTi4A7.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;当然如果没有 padding 字符, 那自然 padding mask 就全是 1 了.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;:::note
注意这里, Padding Mask 不负责 &quot;当前 token 不允许与右侧 token 做attention&quot;, 这部分由 Causal Mask 负责.
:::&lt;/p&gt;
&lt;p&gt;这样经过 Padding Mask 后, Attention 矩阵应该如下:&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/04/15/OZTfaF13c2hYz6u.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;h3&gt;3.2 Causal Mask&lt;/h3&gt;
&lt;p&gt;Causal Mask 就常规了, 一个 shape with  $序列长度 \times 序列长度$ 的下三角矩阵.&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/04/15/qTOa2cKxWPeizgF.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;h3&gt;3.3 武魂融合, 启动!&lt;/h3&gt;
&lt;p&gt;那实际中我们是 &quot;两手都要抓,两手都要硬&quot;, 即需要 &quot;Padding Mask&quot; 也需要 &quot;Causal Mask&quot;.&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/04/15/ZP7sNoK2UcOg3Gj.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;p&gt;最后, 我们可以看到, 只有 &quot;2&quot; 的位置 Attention score 是允许的, 其余位置都不可以.&lt;/p&gt;
&lt;h2&gt;4. 文本生成&lt;/h2&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/04/15/kqFK1vbJdLS37Yz.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;p&gt;这里最后我们模型输出的是下一个 token 在 vocabulary 上的概率, 因此具体下个 token 具体是什么, 需要采样, 采样思路有很多, 可以参考: &lt;a href=&quot;https://huggingface.co/blog/how-to-generate&quot;&gt;how-to-generate&lt;/a&gt;. 不过这篇 Blog 就简单的输出概率最大的那个token.&lt;/p&gt;
&lt;h2&gt;5. 完整代码&lt;/h2&gt;
&lt;p&gt;完整代码见: &lt;a href=&quot;https://github.com/CHENHUI-X/Text-Generator-With-Decoder/tree/main&quot;&gt;Text-Generator-With-Decoder&lt;/a&gt;&lt;/p&gt;
&lt;h2&gt;Reference&lt;/h2&gt;
&lt;p&gt;[1] &lt;a href=&quot;https://wingedsheep.com/building-a-language-model/&quot;&gt;https://wingedsheep.com/building-a-language-model/&lt;/a&gt;&lt;/p&gt;
</content:encoded></item><item><title>AUC &amp; GAUC</title><link>https://xuchenhui.cc/posts/2024-04-12-auc-gauc/</link><guid isPermaLink="true">https://xuchenhui.cc/posts/2024-04-12-auc-gauc/</guid><description>深入讲解 AUC 和 GAUC 的概念与计算方法，以及 AUC 作为衡量正负样本排序能力的统计含义，从 ROC 曲线到实际应用。</description><pubDate>Fri, 12 Apr 2024 00:00:00 GMT</pubDate><content:encoded>&lt;h2&gt;0. 前言&lt;/h2&gt;
&lt;p&gt;AUC 经常被用来评估一个机器学习模型的综合性能,  我们通常听到的版本,  AUC 指的是 ROC 曲线下的面积,  不过在实际中他是如何计算的? GAUC 又是什么?  此外,  AUC 还有另外一种含义,  描述的是任意取一对儿正负样本,  模型能够把 &quot;正样本&quot;  排序到 &quot;负样本&quot; 前边的能力. 这又是什么?&lt;/p&gt;
&lt;h2&gt;1. 基本知识&lt;/h2&gt;
&lt;p&gt;说 AUC 不得不说 &lt;a href=&quot;https://en.wikipedia.org/wiki/Receiver_operating_characteristic&quot;&gt;ROC (Receiver operating characteristic) &lt;/a&gt;曲线,  说 ROC 曲线又不得不说 &lt;a href=&quot;https://en.wikipedia.org/wiki/Confusion_matrix&quot;&gt;混淆矩阵 &lt;/a&gt;&lt;/p&gt;
&lt;p&gt;混淆矩阵用来可视化模型的分类结果,  帮助我们清晰的看到&lt;strong&gt;模型在给定某个阈值下&lt;/strong&gt;对各个类别的覆盖能力如何. 这里简单放个图&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/04/12/19ZvCgjRszOKPnw.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;p&gt;绘制ROC曲线主要会用到 2 个指标:&lt;/p&gt;
&lt;p&gt;&lt;code&gt;True positive rate (TPR)&lt;/code&gt;: TPR (真正率),  也叫 Recall (召回),  Sensitivity (灵敏度). 它描述的是本身就是正样本,  模型也预测为正样本占所有正样本的比例.&lt;/p&gt;
&lt;p&gt;$$
TPR = \frac {TP}  {P}  = \frac {TP}  {TP + FN}
$$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;很明显,  阈值越小,  TPR 越大,  当阈值为 0,  所有的样本全部预测为正类,  那么 $TPR = 1$ . 反之,  当阈值升高,  TPR 下降,  当阈值为 1,  所有的样本全部预测为负类 ,  此时 $TPR = 0$&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;&lt;code&gt;False positive rate (TPR)&lt;/code&gt;: FPR (假正率). 它描述的是本身就是负样本,  却被模型预测为正样本占所有负样本的比例.&lt;/p&gt;
&lt;p&gt;$$
FPR = \frac {FP}  {N}   = \frac {FP}  {FP + TN}
$$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;很明显,  阈值越小,  FPR 越大,  当阈值为 0,  所有的样本全部预测为正类,  那么 $FPR = 1$ . 反之,  当阈值升高,  FPR 下降,  当阈值为 1,  所有的样本全部预测为负类 ,  此时 $FPR = 0$ .&lt;/p&gt;
&lt;/blockquote&gt;
&lt;blockquote&gt;
&lt;p&gt;可以看到 TPR 和 FPR 的趋势一致. 那么给定一个阈值,  就得到一对儿对应的 TPR 和 FPR ,  我们令阈值 从 0 - 1 ,  这样就会有很多对儿 TPR 和 FPR . 将其以 FPR 作为横轴,  TPR作为纵轴,  就得到了 ROC曲线.
:::tip
:::
ROC曲线如图 :&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/12/EZLk7pKxzshoqej.png&quot; alt=&quot;image.png&quot; width=&quot;300&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;[1] ROC 曲线越靠左上方越好,  这表明在给定的阈值下,  $TPR &amp;gt; FPR$,  从含义上讲,  就是模型 预测正确的能力(TPR) 比 预测错误的能力(FPR) 要更强.&lt;/p&gt;
&lt;p&gt;[2] 中间那条红色的虚线,  表示随机猜测,  此时无论什么阈值,  $ TPR == FPR$ ,  就是模型 预测正确的能力(TPR) == 预测错误的能力(FPR),  换句话说这个模型没有任何预估能力. 一眼丁真,  鉴定为就是在抛硬币.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h2&gt;2. AUC 的含义&lt;/h2&gt;
&lt;p&gt;上边说道,  ROC 曲线越靠左上方越好,  但是这个可能比较主观,  我们需要用一个定量的指标来描述. 其实 &quot;越靠左上方&quot; 可以用 ROC曲线下的面积(Area Under the Curve,  AUC) 来描述,  如果下边面积越大,  就说明 &quot;越靠左上方&quot;,  当然如果面积最大到 1,  那就是完美的分类器. 因为此时对任意的阈值 $TPR == 1$ ,  即无论什么阈值,  所有的正样本都能够正确识别,  那就只有一种情况:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/12/IZ8K9YldJj7xHyG.png&quot; alt=&quot;image.png&quot; width=&quot;300&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;下面我们从另外的角度看一下 AUC. 首先回顾 $TPR$ 和 $FPR$.&lt;/p&gt;
&lt;p&gt;TPR : 预测为正样本,  且本身是正样本的占所有本身是正样本的比例,  即给一个正样本 $X$,  模型预测为正样本的概率 $P(X)$.&lt;/p&gt;
&lt;p&gt;FPR : 预测为正样本,  且本身是负样本的占所有本身是负样本的比例,  即给一个负样本 $Y$,  模型预测为正样本的概率 $P(Y)$.&lt;/p&gt;
&lt;p&gt;那对于 ROC 的一个点 (FPR,  TPR),   假设 TPR &amp;gt; FPR 时(这是我们希望的),  表明给定任意一对儿 (正, 负) 样本 (X, Y),  模型预测结果 P(X) &amp;gt; P(Y),  即本身为正样本的预测输出值 &amp;gt; 本身为负样本的预测输出值.&lt;/p&gt;
&lt;p&gt;:::tip
再换句话说,  &lt;strong&gt;假设利用模型的输出对所有样本进行降序排序,  那么排序后的结果,  本身是正样本能排在本身是负样本的前边(以一定概率)&lt;/strong&gt;.  而AUC，作为ROC曲线下的面积，是在所有决策阈值下的概率积分，从而代表了模型在任意阈值下,  对随机选择的 (正, 负) 样本对的排序能力。
:::&lt;/p&gt;
&lt;h2&gt;3. AUC的计算&lt;/h2&gt;
&lt;p&gt;显然通过计算曲线下面积的方式要用到积分,  这个可能比较棘手,  我们可以利用另外一种含义的性质来计算.&lt;/p&gt;
&lt;h3&gt;3.1 算法1&lt;/h3&gt;
&lt;p&gt;思想 :  我们想评估 &lt;code&gt;模型对任意一对儿 (正, 负) 样本 (X, Y),  模型预测结果 P(X) &amp;gt; P(Y),  即本身为正样本的预测输出值 &amp;gt; 本身为负样本的预测输出值 的能力(即概率)&lt;/code&gt; ,  将这个进行拆解: 对每一个 正样本遍历,  观测当前正样本 排在 多少个负样本前边,  然后累计,  最后除以总的可排列组合数,  即可得到 &quot;对随机选择的 (正, 负) 样本对 (X, Y) 的 P(X) &amp;gt; P(Y) 排序能力(即概率).&quot;&lt;/p&gt;
&lt;p&gt;举个例子:&lt;/p&gt;
&lt;p&gt;&amp;lt;div style=&quot;text-align:center;&quot;&amp;gt;
&amp;lt;table&amp;gt;
&amp;lt;thead&amp;gt;
&amp;lt;tr&amp;gt;
&amp;lt;th&amp;gt;class&amp;lt;/th&amp;gt;
&amp;lt;th&amp;gt;label&amp;lt;/th&amp;gt;
&amp;lt;th&amp;gt;pre&amp;lt;/th&amp;gt;
&amp;lt;/tr&amp;gt;
&amp;lt;/thead&amp;gt;
&amp;lt;tbody&amp;gt;
&amp;lt;tr&amp;gt;
&amp;lt;td&amp;gt;A&amp;lt;/td&amp;gt;
&amp;lt;td&amp;gt;0&amp;lt;/td&amp;gt;
&amp;lt;td&amp;gt;0.1&amp;lt;/td&amp;gt;
&amp;lt;/tr&amp;gt;
&amp;lt;tr&amp;gt;
&amp;lt;td&amp;gt;B&amp;lt;/td&amp;gt;
&amp;lt;td&amp;gt;0&amp;lt;/td&amp;gt;
&amp;lt;td&amp;gt;0.4&amp;lt;/td&amp;gt;
&amp;lt;/tr&amp;gt;
&amp;lt;tr&amp;gt;
&amp;lt;td&amp;gt;C&amp;lt;/td&amp;gt;
&amp;lt;td&amp;gt;1&amp;lt;/td&amp;gt;
&amp;lt;td&amp;gt;0.3&amp;lt;/td&amp;gt;
&amp;lt;/tr&amp;gt;
&amp;lt;tr&amp;gt;
&amp;lt;td&amp;gt;D&amp;lt;/td&amp;gt;
&amp;lt;td&amp;gt;1&amp;lt;/td&amp;gt;
&amp;lt;td&amp;gt;0.8&amp;lt;/td&amp;gt;
&amp;lt;/tr&amp;gt;
&amp;lt;/tbody&amp;gt;
&amp;lt;/table&amp;gt;&lt;/p&gt;
&lt;p&gt;总共 2个正样本,  2个负样本,  共 2 * 2 种排列组合&lt;/p&gt;
&lt;p&gt;对于正样本C,  其在 1 个负样本前边.&lt;/p&gt;
&lt;p&gt;对于正样本D,  其在 2 个负样本前边.&lt;/p&gt;
&lt;p&gt;故该模型的AUC为:&lt;/p&gt;
&lt;p&gt;$$
AUC = \frac {1 + 2} {4} = 0.75
$$&lt;/p&gt;
&lt;p&gt;如果遇见正负样本输出得分一样的呢？将一样的认为是0.5个&lt;/p&gt;
&lt;p&gt;&amp;lt;div style=&quot;text-align:center;&quot;&amp;gt;
&amp;lt;table&amp;gt;
&amp;lt;tr&amp;gt;
&amp;lt;th&amp;gt;class&amp;lt;/th&amp;gt;
&amp;lt;th&amp;gt;label&amp;lt;/th&amp;gt;
&amp;lt;th&amp;gt;pre&amp;lt;/th&amp;gt;
&amp;lt;/tr&amp;gt;
&amp;lt;tr&amp;gt;
&amp;lt;td&amp;gt;A&amp;lt;/td&amp;gt;
&amp;lt;td&amp;gt;0&amp;lt;/td&amp;gt;
&amp;lt;td&amp;gt;0.1&amp;lt;/td&amp;gt;
&amp;lt;/tr&amp;gt;
&amp;lt;tr&amp;gt;
&amp;lt;td&amp;gt;B&amp;lt;/td&amp;gt;
&amp;lt;td&amp;gt;0&amp;lt;/td&amp;gt;
&amp;lt;td&amp;gt;0.4&amp;lt;/td&amp;gt;
&amp;lt;/tr&amp;gt;
&amp;lt;tr&amp;gt;
&amp;lt;td&amp;gt;C&amp;lt;/td&amp;gt;
&amp;lt;td&amp;gt;1&amp;lt;/td&amp;gt;
&amp;lt;td&amp;gt;0.4&amp;lt;/td&amp;gt;
&amp;lt;/tr&amp;gt;
&amp;lt;tr&amp;gt;
&amp;lt;td&amp;gt;D&amp;lt;/td&amp;gt;
&amp;lt;td&amp;gt;1&amp;lt;/td&amp;gt;
&amp;lt;td&amp;gt;0.8&amp;lt;/td&amp;gt;
&amp;lt;/tr&amp;gt;
&amp;lt;/table&amp;gt;&lt;/p&gt;
&lt;p&gt;总共 2个正样本,  2个负样本,  共 2 * 2 种排列组合&lt;/p&gt;
&lt;p&gt;对于正样本C,  ABC 和 ACB 顺序都可以, 所以理解为在 1.5 个负样本前边.&lt;/p&gt;
&lt;p&gt;对于正样本D,  其在 2 个负样本前边.&lt;/p&gt;
&lt;p&gt;故该模型的AUC为:&lt;/p&gt;
&lt;p&gt;$$
AUC = \frac {1.5 + 2} {4} = 0.875
$$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;因为这个算法要遍历正样本,  然后与负样本比较计数,  因此复杂度属于 $O(N^2)$ .&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h3&gt;3.2 算法2&lt;/h3&gt;
&lt;p&gt;既然我们需要衡量模型的排序能力,  那不妨先对样本按照模型预测值排个序, 如下表&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;假设 $M$ 个正样本,  $N$ 个负样本&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;&amp;lt;div style=&quot;text-align:center;&quot;&amp;gt;
&amp;lt;table&amp;gt;
&amp;lt;tr&amp;gt;
&amp;lt;th&amp;gt;class&amp;lt;/th&amp;gt;
&amp;lt;th&amp;gt;label&amp;lt;/th&amp;gt;
&amp;lt;th&amp;gt;pre&amp;lt;/th&amp;gt;
&amp;lt;th&amp;gt;rank&amp;lt;/th&amp;gt;
&amp;lt;/tr&amp;gt;
&amp;lt;tr&amp;gt;
&amp;lt;td&amp;gt;A&amp;lt;/td&amp;gt;
&amp;lt;td&amp;gt;0&amp;lt;/td&amp;gt;
&amp;lt;td&amp;gt;0.1&amp;lt;/td&amp;gt;
&amp;lt;td&amp;gt;1&amp;lt;/td&amp;gt;
&amp;lt;/tr&amp;gt;
&amp;lt;tr&amp;gt;
&amp;lt;td&amp;gt;B&amp;lt;/td&amp;gt;
&amp;lt;td&amp;gt;0&amp;lt;/td&amp;gt;
&amp;lt;td&amp;gt;0.4&amp;lt;/td&amp;gt;
&amp;lt;td&amp;gt;2&amp;lt;/td&amp;gt;
&amp;lt;/tr&amp;gt;
&amp;lt;tr&amp;gt;
&amp;lt;td&amp;gt;C&amp;lt;/td&amp;gt;
&amp;lt;td&amp;gt;1&amp;lt;/td&amp;gt;
&amp;lt;td&amp;gt;0.4&amp;lt;/td&amp;gt;
&amp;lt;td&amp;gt;3&amp;lt;/td&amp;gt;
&amp;lt;/tr&amp;gt;
&amp;lt;tr&amp;gt;
&amp;lt;td&amp;gt;D&amp;lt;/td&amp;gt;
&amp;lt;td&amp;gt;1&amp;lt;/td&amp;gt;
&amp;lt;td&amp;gt;0.8&amp;lt;/td&amp;gt;
&amp;lt;td&amp;gt;4&amp;lt;/td&amp;gt;
&amp;lt;/tr&amp;gt;
&amp;lt;/table&amp;gt;&lt;/p&gt;
&lt;p&gt;根据上表的 $RANK$ 可以很容易知道以下成立:&lt;/p&gt;
&lt;p&gt;第 1 个正样本 C 的 $rank = 3$ ,  C 的前边有 2 个负样本 : (不算自己: $rank - 1 = 2$)&lt;/p&gt;
&lt;p&gt;第 2 个正样本 D 的 $rank = 4$ ,  D 的前边有 2 个负样本 : (不算自己和前一个正样本: $rank - 2 = 2$)&lt;/p&gt;
&lt;p&gt;同理,  假设对于第 M 个正样本 E ,  其 $rank = K$,  则 E 的前边有 $K - M$ 个负样本 .&lt;/p&gt;
&lt;p&gt;这样我们就有简单的计算方式,  去计算每个正样本盖过多少个负样本,  从而 AUC 如下:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
AUC &amp;amp;= \frac {(rank_{x_1} - 1) + (rank_{x_2} - 2) + ... +  (rank_{x_M} - M)} {M \times N} \
&amp;amp;= \frac {(rank_{x_1} + rank_{x_2} + ... + rank_{x_M}) - (1 + 2 + ... + M)} {M \times N} \
&amp;amp;= \frac {\sum_{i=1}^{M} rank_{x_i} + \frac {M(M+1)} {2}} {M \times N}
\end{align*}
$$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;Note : 如果出现预测值相等的情况,  这个时候的 rank 是不确定的,  比如下表结果,  对与 B 样本,  其 $pre = 0.5$ ,  和他一样的有 $4$ 个,  这样对于 B 样本,  其可能的 rank 可以是 ${2, 3, 4, 5} $,  所以其实际所发挥的 rank 作用 为 : $\frac {2+3+4+5} {4}$ ,  正样本 C 同理 .&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/12/KinkoBEXAQvdCcg.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;400&quot; /&amp;gt;&lt;/p&gt;
&lt;h2&gt;4. Group AUC (GAUC) 的含义&lt;/h2&gt;
&lt;p&gt;AUC 在传统的机器学习二分类中还是很能打的，但是有一种场景，虽然是分类模型，但是却不适用 AUC，即广告推荐领域.&lt;/p&gt;
&lt;p&gt;当商品库有多个商品要推荐给你的时候，其实算法并不关心每个商品值得推荐的概率是否够高，具体的业务中，我们只关心要推荐给你的商品的排序是否有效. 即更加关注排序.&lt;/p&gt;
&lt;p&gt;这个时候就有一个问题,  如下表:&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/04/12/mo6cW8Vq14gj7kH.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;p&gt;对于用户A和B分别来看, 模型对每个item给出的推荐顺序(或者概率)都是符合的 都是可以能够正确分类结果 ( 当然这里可能分类正确与否不是很重要 ),  能够在每个用户身上区分开的 .&lt;/p&gt;
&lt;p&gt;每个用户的AUC都是1 ,  但是如果把用户A和用户B一起来看,  当成一个用户,   这时候模型对 item 的预测, 给出了不一样的顺序,  这是混合的 AUC = (4 + 4 + 2 + 2) / 16 = 3/4 = 0.75&lt;/p&gt;
&lt;p&gt;Group AUC (GAUC)  就是用来解决这个问题,  即通过将不同的用户分组 然后加权计算,  实际中,  权重可以是 不同用户的 click次数 ,  基于时间的加权,  基于位置的加权等等.&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/04/12/H5BEwOa3VplQM2g.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;h2&gt;Reference&lt;/h2&gt;
&lt;p&gt;[1] &lt;a href=&quot;https://www.jianshu.com/p/f9f8e29abbe0&quot;&gt;https://www.jianshu.com/p/f9f8e29abbe0&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;[2] &lt;a href=&quot;https://medium.com/@j.zh/from-auc-to-gauc-928e1c4f1fc9&quot;&gt;https://medium.com/@j.zh/from-auc-to-gauc-928e1c4f1fc9&lt;/a&gt;&lt;/p&gt;
</content:encoded></item><item><title>GAN Loss Derivation</title><link>https://xuchenhui.cc/posts/2024-04-11-gan-loss-derivation/</link><guid isPermaLink="true">https://xuchenhui.cc/posts/2024-04-11-gan-loss-derivation/</guid><description>从 Generator 与 Discriminator 的目标出发，逐步推导 GAN 损失函数，并说明 min-max 形式背后的直觉。</description><pubDate>Thu, 11 Apr 2024 00:00:00 GMT</pubDate><content:encoded>&lt;h2&gt;0. 前言&lt;/h2&gt;
&lt;p&gt;&lt;a href=&quot;https://arxiv.org/abs/1406.2661&quot;&gt;GAN  原始  paper&lt;/a&gt;  中的损失很优美:&lt;/p&gt;
&lt;p&gt;$$
\mathcal{L}&lt;em&gt;{\text{GAN}} = min&lt;/em&gt;{G} \ max_{D} \  \mathbb{E}&lt;em&gt;{x \sim p&lt;/em&gt;{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]
$$&lt;/p&gt;
&lt;p&gt;不过有的同学可能看的一头雾水,  我们来推导一下怎么来的.&lt;/p&gt;
&lt;h2&gt;1. 推导&lt;/h2&gt;
&lt;p&gt;为方便推导 ,  记 &lt;code&gt;Generator&lt;/code&gt; 为 &lt;code&gt;G&lt;/code&gt; ,  &lt;code&gt;Discriminator&lt;/code&gt; 为 &lt;code&gt;D&lt;/code&gt;.&lt;/p&gt;
&lt;h3&gt;1.1 Generator&lt;/h3&gt;
&lt;p&gt;Generator 要做的事情呢 ,  可以划分为以下几步:&lt;/p&gt;
&lt;p&gt;[1] 首先,  从一个 noise 分布 sample 一笔数据 ,  不妨假设 $z \sim p_z(z)$&lt;/p&gt;
&lt;p&gt;[2] 然后 Generator 一顿操作,  输出 $G(z)$&lt;/p&gt;
&lt;p&gt;[3] 目标: 尽可能的欺骗 Discriminator ,  让其认为  $G(Z)$  是真的 ,  具体表现为 $D(G(Z))$ 越接近 $1$ 越好&lt;/p&gt;
&lt;p&gt;因此,  用交叉熵表示 Generator 要优化的目标是:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
L(G) &amp;amp;=  minimize \ \sum 1 * \frac {1} {log(D(G(z)))} + 0 * \frac {1} {log(1 - D(G(z)))} \
&amp;amp;= minimize \ \sum 1 * \frac {1} {log(D(G(z)))} \
&amp;amp;= minimize \ - \sum log(D(G(z))) \
&amp;amp;= minimize \ \sum log(1 - D(G(z))) \
\end{align*}
$$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;这里有个小trick ,  当我们 update Generator 时 ,  Discriminator 是固定的 ,  而 $x \sim p_{\text{data}}(x)$ 也是固定的 (就是我们真实样本训练集),  于是 Generator 有以下&lt;strong&gt;等价优化目标&lt;/strong&gt;(当然可以加个 Expectation ~ ).&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;$$
\begin{align*}
\mathcal{L}&lt;em&gt;{\text{G}} = min&lt;/em&gt;{G}\  \mathbb{E}&lt;em&gt;{x \sim p&lt;/em&gt;{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]
\end{align*}
$$&lt;/p&gt;
&lt;h3&gt;1.2 Discriminator&lt;/h3&gt;
&lt;p&gt;Discriminator 要做的事情呢 ,  可以划分为以下几步:&lt;/p&gt;
&lt;p&gt;[1] 首先,  从一个 真实 分布 sample 一笔数据 ,  不妨假设 $x \sim p_x(x)$&lt;/p&gt;
&lt;p&gt;[2] 然后,  接受来自 Generator 的输出 $G(Z)$&lt;/p&gt;
&lt;p&gt;[3] 将 $x$ 和 $G(Z)$ 都扔给 Discriminator&lt;/p&gt;
&lt;p&gt;[4] 目标: 尽力分辨出 $x$ 为真,  $G(Z)$ 为假.&lt;/p&gt;
&lt;p&gt;因此,  用交叉熵表示 Discriminator 要优化的目标是:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
L(D) &amp;amp;=  minimize \ \sum { 1 * \frac {1} {log(D(x))} + 0 * \frac {1} {log(1 - D(x))} }&lt;em&gt;{x\  for\  true} \ \
&amp;amp;+ \ \sum  {0 * \frac {1} {log(D(G(z)))} + 1 * \frac {1} {log(1 - D(G(z)))} }&lt;/em&gt;{G(z)\  for \ false} \
&amp;amp;= minimize \ \sum 1 * \frac {1} {log(D(x))} \ + \ \sum 1 * \frac {1} {log(1 - D(G(z)))}  \
&amp;amp;= minimize \ - \sum log(D(x)) - \sum log(1 - D(G(z))  \
&amp;amp;= maximize \ \sum log(D(x)) \ + \ \sum log(1 - D(G(z))  \
\end{align*}
$$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;啊,  美化一下,  加个 Expectation~ , 美滋滋&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;$$
\begin{align*}
\mathcal{L}&lt;em&gt;{\text{D}} = max&lt;/em&gt;{D}\  \mathbb{E}&lt;em&gt;{x \sim p&lt;/em&gt;{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]
\end{align*}
$$&lt;/p&gt;
&lt;h3&gt;1.3 大一统&lt;/h3&gt;
&lt;p&gt;$Generator$ 要 $minimize$ 下边这个式子,  $Discriminator$ 要 $maximize$ 下边这个式子 . 叮~ 任务完成~&lt;/p&gt;
&lt;p&gt;$$
\mathcal{L}&lt;em&gt;{\text{GAN}} = min&lt;/em&gt;{G} \ max_{D} \  \mathbb{E}&lt;em&gt;{x \sim p&lt;/em&gt;{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]
$$&lt;/p&gt;
&lt;h2&gt;2. 算法步骤&lt;/h2&gt;
&lt;p&gt;贴一个原始paper中的算法步骤,  不过可以看到 ,  上边式子那个只是为了美观 ,  实际更新的时候,  还是用原始的,&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/04/12/yS5QvjER17fJ3z4.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;h2&gt;Reference&lt;/h2&gt;
</content:encoded></item><item><title>Sampling Method</title><link>https://xuchenhui.cc/posts/2024-04-11-sampling-method/</link><guid isPermaLink="true">https://xuchenhui.cc/posts/2024-04-11-sampling-method/</guid><description>介绍蒙特卡洛采样方法及其在参数估计中的应用，涵盖逆变换采样、拒绝采样等核心技术的原理与多臂老虎机场景下的实践。</description><pubDate>Thu, 11 Apr 2024 00:00:00 GMT</pubDate><content:encoded>&lt;h2&gt;0. 前言&lt;/h2&gt;
&lt;p&gt;比如在老虎机场景,  我们想知道哪一台老虎机的赢面更大,  通常是给定所有老虎机 &quot;赢&quot; 的参数分布 ,  比如 Dirichlet distribution,  初始化 $\alpha1 \ \alpha2 \ …$  ,  然后根据实际数据采样,  更新 Dirichlet distribution 的参数即可.&lt;/p&gt;
&lt;p&gt;具体采样流程(通常使用在类似多臂老虎机场景) :&lt;/p&gt;
&lt;p&gt;[1] 首先假设 参数p的先验分布 (比如 beta 分布 $B(m, n)$,  Dirichlet 分布 $D(a, b, c, ..., z)$)&lt;/p&gt;
&lt;p&gt;[2] 然后 &lt;strong&gt;基于该分布 ,   采样一组参数(就是各个机器的成功概率)&lt;/strong&gt; ,  然后基于当前的参数抽卡,  并选择最大的p对应的老虎机作为成功case ,  然后观察其结果,  并更新对应参数(比如实际是另外一个老虎机赢了). 重复此步骤.&lt;/p&gt;
&lt;p&gt;:::note
&lt;strong&gt;这里就会涉及到一个问题,  对参数采样,  怎么采才能尽可能的符合、或者接近参数本身的分布&lt;/strong&gt;？
:::&lt;/p&gt;
&lt;h2&gt;1. 基于Monte-Carlo的方法&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;引理1&lt;/li&gt;
&lt;/ul&gt;
&lt;blockquote&gt;
&lt;p&gt;设 X 是一个随机变量，其分布函数$f(x)$,  累积分布函数 (CDF,  Cumulative distribution function) 为 F(x) ,  该函数是一个单调递增的函数,  其值域为[ 0 ,   1 ]. 现在定义一个新的随机变量$Y = F(X) $ ,  则 随机变量 $Y$ 的分布是均匀分布.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;ul&gt;
&lt;li&gt;证明&lt;/li&gt;
&lt;/ul&gt;
&lt;blockquote&gt;
&lt;p&gt;对于任意实数$y$ ,  我们有:&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;$$
P(Y&amp;lt;=y)  = P(F(X) &amp;lt;= y) = P(X &amp;lt;= F^{-1}(y) = F( F^{-1}(y))
$$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;由于F(x)是单调递增函数, 因此$F^{-1}(y)$具有唯一解 $x$ , 令$x = F^{-1}(y)$ , 则有 $F(x) = y$ .&lt;/p&gt;
&lt;p&gt;因此&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;$$
P( Y &amp;lt;= y) = F(F^{-1}(y)) = F(x) = y
$$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;即有&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;$$
P( Y &amp;lt;= y)  = y
$$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;即 Y是均匀分布&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h3&gt;1.1 逆变换采样法&lt;/h3&gt;
&lt;p&gt;设 X 是一个随机变量，其分布函数$f(x)$,  累积分布函数 (CDF,  Cumulative distribution function) 为 F(x). 则依据如下采样过程,  得到的x是服从分布$f(x)$的.&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;从均匀分布 U(0,  1) 中生成一个随机数 u&lt;/li&gt;
&lt;li&gt;计算 F(x) = u 的解 x&lt;/li&gt;
&lt;li&gt;输出 x 作为采样结果&lt;/li&gt;
&lt;/ol&gt;
&lt;ul&gt;
&lt;li&gt;证明&lt;/li&gt;
&lt;/ul&gt;
&lt;blockquote&gt;
&lt;p&gt;根据引理1容易知道,  如果从均匀分布 $U(0,  1) $ 中生成一个随机数 u，并令 $x = F^{-1} (u)$，则 $x$ 服从原分布$ F(x)$。(理解为本身这个$F$就是我们想采样的 $f$ 对应的 $F$,  那反函数求解出来的 $x$ 自然就是 满足 $f(x)$ 和 $F(x)$ ) ,  即为 逆变换方法 ,  几个具体实现: &lt;a href=&quot;https://lwz322.github.io/2019/06/02/ITM.html&quot;&gt;https://lwz322.github.io/2019/06/02/ITM.html&lt;/a&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h3&gt;1.2 拒绝采样法&lt;/h3&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;准备工作&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;已知 概率密度函数$f(y)$,  我们需要依据这个分布进行抽样&lt;/li&gt;
&lt;li&gt;找一个&lt;strong&gt;任意能够直接进行采样的分布$g(y)$ (比如均匀分布)&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;找一个常数 $c$,  满足对 $\forall y$ ,  均有 $c \times g(y) &amp;gt;= f(y)$,  即 $c$ 是函数 $\frac {f(y)} {g(y)}$ 的上界 或者 $c \times g(y)$ 能够覆盖 $f(y)$&lt;/li&gt;
&lt;/ol&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;抽样流程&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;从 $g(y)$ 中中随机采样一个样本 $y_i$&lt;/li&gt;
&lt;li&gt;从均匀分布 $U(0, 1)$ 中采样一个随机数 $u_i$&lt;/li&gt;
&lt;li&gt;如果 $u_i &amp;lt;= \frac {f(y_i)} {c * g(y_i)}$  成立,  则保留该样本 $y_i$,  否则返回 step1重复. 可以证明,  这样从 $g(y)$ 抽出的样本 $y_i$ 是满足概率密度函数 $f(y)$ 及其对应的CDF函数&lt;/li&gt;
&lt;/ol&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;证明&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;blockquote&gt;
&lt;p&gt;&lt;strong&gt;证明上述采样方法生成的样本服从 $f(y)$ ,  等价于证明以下内容&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/04/11/zBmECrMeK82fnAI.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;p&gt;其中 ,  U 为 $[0 , 1]$ 的随机数,  $y$ 是从 $g(y)$ 采样得到的 ,  $F$ 和 $G$ 分别是 $f$ 和 $g$ 对应的累积分布函数.&lt;/p&gt;
&lt;p&gt;根据贝叶斯公式&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;$$
P(A|B) = \frac {P(B|A)P(A)} {P(B)}
$$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;将 $P(Y&amp;lt;=y \mid U &amp;lt;= \frac {f(Y)} {c * g(Y)})$ 用贝叶斯公式转化为:&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/04/11/e6lD9zZm7pocfIu.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;p&gt;现在分别来看 右边的 3个式子&lt;/p&gt;
&lt;p&gt;(1) 分母&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;$$
P(U &amp;lt;= \frac {f(Y)} {c&lt;em&gt;g(Y)}) =  \int P(U &amp;lt;= \frac {f(Y)} {c&lt;/em&gt;g(Y)}| Y = y)p( Y = y)
$$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;由于 y 是从 g 中抽样得到的 ,  那么 $p( Y = y) = g(y)$ ,  不妨假设此时 y 的抽样结果 : $Y = y$ ,  又因为 U 是均匀的 0 ,  1 分布 , 按定义 我们有&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;$$
P(U &amp;lt;= \frac {f(Y)} {c&lt;em&gt;g(Y)}| Y = y) = \frac {f(y)} {c&lt;/em&gt;g(y)}
$$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;此外,  由于$\int f(y)=1$ ,  我们有&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/04/11/mdCt9oxAZKaMNsW.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;p&gt;(2) 分子 $p( Y &amp;lt;= y)$&lt;/p&gt;
&lt;p&gt;按照定义&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;$$
p( Y &amp;lt;= y) = G(y)
$$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;(3) 分子 $P(U &amp;lt;= \frac {f(Y)} {c*g(Y)} \ Y &amp;lt;= y)$&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;$$
\begin{align*}
P(U &amp;lt;= \frac {f(Y)} {c&lt;em&gt;g(Y)} \mid Y &amp;lt;= y) &amp;amp;=
\frac {P(U &amp;lt;= \frac {f(Y)} {c&lt;/em&gt;g(Y)},  Y &amp;lt;= y)} {P(Y &amp;lt;= y)} \&amp;amp;=
\frac { \int_{-\infty}^{y} P(U &amp;lt;= \frac {f(w)} {c&lt;em&gt;g(w)}  ,  Y = w &amp;lt;= y)\ dw}{G(y)} \ &amp;amp;=
\frac { \int_{-\infty}^{y} \frac {f(w)} {c&lt;/em&gt;g(w)} &lt;em&gt;g(w) \ dw}{G(y)} \ &amp;amp;=
\frac { \frac {F(y)} {c&lt;/em&gt;G(y)} * G(y) }{G(y)} \ &amp;amp;=
\frac {F(y)} {c&lt;em&gt;G(y)}
\end{align&lt;/em&gt;}
$$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;于是,  原始公式可进行转化,  从而证明完毕:&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;$$
P\big(Y&amp;lt;=y | U &amp;lt;= \frac {f(Y)} {c&lt;em&gt;g(Y)}\big) = \frac { \frac {F(y)} {c&lt;/em&gt;G(y)} * G(y) } {\frac {1} { c }} \ = F(y)
$$&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;直觉理解&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;假设复杂分布 $P(z)$ ,  存在常数 $k$ 与 任意分布 $q(z)$ ,  以 $z_0$ 点为例,  画直线,  任意从均匀分布抽取一个点 $u_i$,  可以理解为在 $x = z_0$ 这条直线上取一点: 就是 $u_i  * k * q(z_0)$,  其处于阴影即拒绝 (即 $U * k * q(z_0) &amp;gt; p(z_0)$) , 处于白色区域即接受( $U * k * q(z_0) &amp;lt;= p(z_0)$ ) ,  这样从 $z_0$ 出来的点对应的最大概率就是$ f(z_0) $ , 等价于是从 $f(x)$ 抽样出来的&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/04/11/XO3GehobsnckrNQ.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;hr /&gt;
&lt;p&gt;上述2个方法都属于Monte-Carlo 方法,  并且是已知 $P(\theta)$ 的情况下 ,  然后在某些特殊场景下,  已知了 参数的后验分布 和 先验分布 的关系(比如之前提到的共轭) , 才能得到一个比较简易的形式 ,  直接对后验分布更新. (当我们面临无法得到具体形式的非共轭后验分布时，我们无法采用这种算法。)&lt;/p&gt;
&lt;p&gt;然而,  面对一些复杂的分布,  即使我们已知了 $P(\theta)$  ,  再利用贝叶斯公式的时候 ,  其分母涉及到积分, 往往也是很难求解的&lt;/p&gt;
&lt;p&gt;$$
P(\theta|X) = \frac {P(X|\theta)P(\theta)} {\int P(X|\theta)P(\theta) d \theta}
$$&lt;/p&gt;
&lt;p&gt;上述提到分母有时候很难进行积分，对于这个问题，一个直观的想法就是 ，能不能通过某个手段把 分母去掉？&lt;/p&gt;
&lt;p&gt;$$
P(\theta_a|X) = \frac {P(X|\theta_a)P(\theta_a)} {P(X)}
$$&lt;/p&gt;
&lt;p&gt;$$
P(\theta_b|X) = \frac {P(X|\theta_b)P(\theta_b)} {P(X)}
$$&lt;/p&gt;
&lt;p&gt;二者做比值&lt;/p&gt;
&lt;p&gt;$$
\gamma = \frac {P(\theta_a|X)}{P(\theta_b|X)} = \frac {P(X|\theta_a)P(\theta_a)}{P(X|\theta_b)P(\theta_b)}
$$&lt;/p&gt;
&lt;p&gt;这样避免了分母的积分，这里 $P(\theta_a)$ 可以参考 Dirichlet Distribution （多维）或者 Beta Distribution （二维）. 思想是这样的,  不过需要一点点其他知识.&lt;/p&gt;
&lt;p&gt;:::note
未完待续...
:::&lt;/p&gt;
&lt;h2&gt;Reference&lt;/h2&gt;
</content:encoded></item><item><title>Gama&amp;Beta&amp;Dirichlet</title><link>https://xuchenhui.cc/posts/2024-04-10-gama-beta-dirichlet/</link><guid isPermaLink="true">https://xuchenhui.cc/posts/2024-04-10-gama-beta-dirichlet/</guid><description>总结 Gamma 分布、Beta 分布和 Dirichlet 分布的定义、性质及其之间的内在联系，帮助梳理概率分布之间的脉络关系。</description><pubDate>Wed, 10 Apr 2024 00:00:00 GMT</pubDate><content:encoded>&lt;h2&gt;0. 前言&lt;/h2&gt;
&lt;p&gt;这篇 Blog 主要对几个分布进行总结,  以及对他们之间的关系进行梳理&lt;/p&gt;
&lt;h2&gt;1. Gamma Function&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;定义&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/10/CP5UfaszbRTnBi8.png&quot; alt=&quot;Gamma函数定义.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;记忆方法: 理解为用一个伽马刀, 对 $t$ 动了一刀, 于是指数为 $\alpha-1$,  动完刀需要扶着梯子 $-t$ 才能走下来。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;性质&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/10/KRSeYOvgoc73aQ5.png&quot; alt=&quot;Gamma函数性质.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;h2&gt;2. Gamma Distribution&lt;/h2&gt;
&lt;p&gt;对Gamma函数等式左右两端, 同时除以$\Gamma(\alpha)$, 则有&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/10/1OSXgUjPNF92b5C.png&quot; alt=&quot;Gamma分布.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;于是取积分中的函数作为概率密度, 就得到一个简单的Gamma分布的密度函数：&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/10/36U2R9JErO7vaiA.png&quot; alt=&quot;Gamma分布密度函数.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;如果做一个变换 $t=\beta x$, 就得到Gamma分布的更一般形式：&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/10/7So1LMAUml2OEkr.png&quot; alt=&quot;Gamma分布.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;其中 $\alpha$ 称为shape parameter, 主要决定了分布曲线的形状, 而 $\beta$ 称为rate parameter或inverse scale parameter（ $\frac {1} {\beta}$ 称为scale parameter）, 主要决定曲线有多陡。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/10/zTtJFMWeSVPYlGr.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;h2&gt;3. Binomial Distribution&lt;/h2&gt;
&lt;p&gt;二项分布是 n 次独立的是/非试验中成功的次数的离散概率分布, 其中每次试验的成功概率为p.&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;p 是已知的 ,  给定 p , 求解目标成功次数对应的概率,  通常使用 $C(m, n, p)$ 来计算.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;n次试验中正好得到k次成功的概率由概率质量函数给出：
&amp;lt;img src=&quot;https://s2.loli.net/2024/04/10/HVzXuTD9Ka5UYlM.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;:::note
参数p 本身就有一个分布,  如何预估 参数p 本身的分布 ?  参考下边的 Beta分布
:::&lt;/p&gt;
&lt;h2&gt;4. Binomial Distribution 的另一个理解&lt;/h2&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/10/tq31VBQ9alucOEf.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;假设向长度为1的桌子上扔一个红球（如上图）, 它会落在0到1这个范围内, 设这个长度值为 x (就是上边定义中的p), 再向桌上扔一个白球, 那么这个白球落在红球左边的概率即为 x (或者p). 若总共扔了n个白球, 每次都是独立的, 假设落在红球左边的白球个数为k,  那么次数 k 在给定参数 x (或者p)的分布为:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/10/B6849qkCnEarUwd.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;可以看到, 结果就是二项分布. 在这个例子的基础上,  进一步的我们来看 , 如果&lt;strong&gt;不关注&lt;/strong&gt; 概率 p ,  我们来求解泛化下的 k次成功概率,  即 需要对 p 积分 (这里是对x积分)&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/10/CQsWx1A2lDktw8V.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;这个比较难计算,  我们换个思路,  P(K=k)就是想说&lt;strong&gt;总共n个白球, 1个红球随便放, 然后红球左边的白球个数为 k 的概率&lt;/strong&gt;. 而 红球的位置是未知的&lt;/p&gt;
&lt;p&gt;ok, 现在假设k=1, 换句说,  理解为总共n+1个球,   把第2个球涂为红色(左边就1个白色): $P = \frac {1} {n+1}$&lt;/p&gt;
&lt;p&gt;ok, 现在假设k=2, 换句说,  理解为总共n+1个球,   把第3个球涂为红色(左边就2个白色): $P = \frac {1} {n+1}$&lt;/p&gt;
&lt;p&gt;...&lt;/p&gt;
&lt;p&gt;同理,  即得&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/10/CGDNJqa5zn3S1AT.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;h2&gt;5. Beta Function&lt;/h2&gt;
&lt;p&gt;在上边的式子基础上,  令 $k = \alpha - 1$ ,  $n - k = \beta - 1$ ,  则 $n = \alpha + \beta - 2$ ,  变量换为 $t$&lt;/p&gt;
&lt;p&gt;则有&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/10/STYmzEcR9wQs4pV.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;定义 $Beta$ 函数&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/10/b8duz7lpE5vGVCs.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;根据之前 $Gama(\alpha)$ 函数的定义 ,  $Beta(\alpha, \beta)$ 可以表示为 :&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/10/pkcCZHqhrFNwxmV.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;h2&gt;6. Beta Distribution&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;定义 $Beta$ 分布
&amp;lt;img src=&quot;https://s2.loli.net/2024/04/10/5PQtVqLlzJxsRCD.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;其中 $Beta(\alpha, \beta)$ 起到归一化的作用&lt;/p&gt;
&lt;p&gt;:::note
The Beta distribution is the conjugate prior for the Bernoulli, binomial, negative binomial and geometric distributions (seems like those are the distributions that involve success and failure) in Bayesian inference. see what&apos;s means that &quot;&lt;a href=&quot;https://stats.stackexchange.com/questions/58564/help-me-understand-bayesian-prior-and-posterior-distributions/58792#58792&quot;&gt;conjugate prior&lt;/a&gt;&quot; ?
:::&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;性质&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;$Beta$ 分布 与 之前的 &lt;code&gt;Bernoulli&lt;/code&gt; 分布 (0 -1 分布) , &lt;code&gt;Binomial&lt;/code&gt; 分布 (即二项分布 , $C(m, n, p)$ ) 构成共轭分布.&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;[1] &lt;code&gt;Binomial&lt;/code&gt; 分布(二项分布 ) 理解为给定成功概率参数 $p$ 和实验次数 $n$ 的情况下,  成功 $k$ 次的概率分布.&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;[2] Beta分布是在给定成功次数 $\alpha$ 和失败次数 $\beta$ (通常来自实验观察,  $p$ 未知) 后 ,  探究成功概率参数 $p$ 的分布 (即上述公式中的 x 的分布) . 看起来像是一种对偶的关系,  不过更多人叫他是共轭关系&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;h2&gt;7. Multinomial Distribution&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;定义&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/10/ZX3Y9TQ8Sx2OIF4.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;[1] 这里 p 是已知的 ,  给定 p ,  求解目标成功次数对应的概率. 对于此时参数p本身的分布 , 参考下方 Dirichlet Distribution&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;:::note
[2] 这里可以看到,  Multinomial Distribution 其实可以理解为多维的 “二项分布” ,  即 “多项分布”.
:::&lt;/p&gt;
&lt;p&gt;上述函数, 如果使用 $\Gamma$ 函数表示 :&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/10/beKZFcOQVUYS5Dp.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;h2&gt;8. Multinomial Distribution的另一种理解&lt;/h2&gt;
&lt;p&gt;可以结合之前的小球案例,  多项分布 可以理解为如下案例&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/10/e4uyMEb7FdQNKUg.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;将 $n$ 个球 放到不同的箱子中,  每个箱子分到的小球的个数分别是 $x_1, x_2\ ...\ x_n$ 的概率,  进一步转换 ,  可以理解为把小球一字排开, 然后在中间放隔板(在哪里放隔板是根据一定概率$p_i$),  使得间隔内的球数目为 $x_1, x_2\ ...\ x_n$ ,  然后得到&lt;/p&gt;
&lt;p&gt;$$
f(x_1 ,  \ ... \ , x_k,  n ,  p_1, p_2 \ ... \ p_n)\ = \ C*p_1^{x_1} * \ ... \ p_n^{x_n}
$$&lt;/p&gt;
&lt;p&gt;前边的系数C,  我们可以这样理解,  因为我们关注的是每个箱子中小球的具体数量,  不关注箱子内小球的顺序 . 所以我们可以先假设是有顺序的, 然后再把顺序删除就能算得相应的小球放置情况的数量&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;假设有顺序,  那么首先 n 个小球全排列 : $n!$ (有顺序)&lt;/li&gt;
&lt;li&gt;然后删除顺序, 只不过将这个过程分配给每个箱子内部实现 : 每个箱子内部除以 $x_i{!}$ 即可, 便得到上述结果&lt;/li&gt;
&lt;/ul&gt;
&lt;h2&gt;9. Dirichlet Distribution&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;定义&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/10/JLZRe6PzWrCmAS5.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;其中 $Beta(\alpha)$ 还是起到归一化的作用,  这里的 $x$ 就是参数 $p$ ,  其中&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/10/jRQwx6fGzbdcmT1.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;同理,  使用 $\Gamma$ 函数 可得到 $Beta(\alpha)$ 的另一种表示&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/10/1mqPgNeivSyKR8r.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;:::note
可以看到 &lt;code&gt;Multinomial Distribution&lt;/code&gt; 和 &lt;code&gt;Dirichlet Distribution&lt;/code&gt; 的关系 类似 &lt;code&gt;Binomial Distribution&lt;/code&gt; 和 &lt;code&gt;Beta Distribution&lt;/code&gt; 的关系 . 这里 &lt;code&gt;Multinomial Distribution&lt;/code&gt; 是给定各个|箱子|板子|老虎机|成功概率 $p$ , 然后去求解不同成功次数对
应的概率.而 &lt;code&gt;Dirichlet Distribution&lt;/code&gt; 要做的是,  根据成功次数(或者已知成功次数) $\alpha$ 去探讨每个|箱子|板子|老虎机|成功概率,  或者可以说 把 &lt;code&gt;Beta Distribution&lt;/code&gt; 和 &lt;code&gt;Dirichlet Distribution&lt;/code&gt; 作为了成功概率的先验分布 .
:::&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;性质&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;在老虎机场景下. 假设离散随机变量 $X$ (可以理解为每个Bandit成功的次数)&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/11/DUZz4WqcwAkCMIg.png&quot; alt=&quot;image.png&quot; width=&quot;100&quot; height=&quot;130&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;令各Bandit的成功概率为&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/11/d8biLJZpeyh6N9r.png&quot; alt=&quot;image.png&quot; width=&quot;100&quot; height=&quot;130&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;那随机变量 $X$ 的分布为 (就是多项分布)&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/11/aiI1wQcuFTCPGDZ.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;那参数 $p$ 的先验分布就是 Dirichlet Distribution&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;应用&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;如下案例,  类似Bandit,  只不过使用筛子.&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/11/NSyKDk29Fu8Vql6.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;现在假设不知道每个 bandit 的成功概率,  就是说不知道骰子每个面朝上的概率, 我们要通过试验来得到这些概率, 那我们就会有以下内容:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/11/wlBVrdCth17y6Go.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;其中,&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/11/OhzY7IS8T3LK6mE.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;相应的发生次数概率计算结果&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/11/bpCENXRI2jQkr4Z.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;那么根据实验结果来预估潜在的 $p$ (就是后验分布),  可以使用Bayes rule:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/11/fK29OLBqdgm4vR5.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;其中,  分母 $P(X=m)$&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/04/11/g32j7IYrqBasc9F.png&quot; alt=&quot;image.png&quot; /&gt;&lt;/p&gt;
&lt;p&gt;然后分子中的 $f_p(P)$ 就是上边的先验分布 Dirichlet Distribution ,  整体带入即得&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/11/Omz684tcdPBZixp.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;这里 c 取到归一化的作用:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/11/ebsgrm3vEq57QYO.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;可以看到, &lt;strong&gt;潜在参数 $p$ 的后验分布仍然是 Dirichlet distribution&lt;/strong&gt;:&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/11/Gl8dEhwBxR3TKHg.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;既有&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/11/DcgUu9rbPnyZIAL.png&quot; alt=&quot;image.png&quot; width=&quot;400&quot; height=&quot;300&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;那么对于 每个Bandit或者筛子,  先初始化 Dirichlet distribution ,  然后根据Bandit或者筛子的结果,  更新其相应的概率分布即可.
然后再根据当前的概率分布进行下一步采样(就是基于当前概率分布,  选择认为赢得概率更高的Bandit,  然后看结果,  循环更新分布).&lt;/p&gt;
&lt;p&gt;:::note
当然这里就会涉及到采样,  即怎么快速高效的根据已有的分布采样? 这是另外的问题了.... 可以移步至&lt;a href=&quot;/posts/Sampling-Method/&quot;&gt;Sampling-Method&lt;/a&gt;
:::&lt;/p&gt;
&lt;h2&gt;Reference&lt;/h2&gt;
&lt;p&gt;[1] &lt;a href=&quot;https://zhuanlan.zhihu.com/p/37976562&quot;&gt;https://zhuanlan.zhihu.com/p/37976562&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;[2] &lt;a href=&quot;https://zhuanlan.zhihu.com/p/69606875&quot;&gt;https://zhuanlan.zhihu.com/p/69606875&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;[3] &lt;a href=&quot;https://readmedium.com/en/https:/towardsdatascience.com/dirichlet-distribution-the-underlying-intuition-and-python-implementation-59af3c5d3ca2&quot;&gt;https://readmedium.com/en/https:/towardsdatascience.com/dirichlet-distribution-the-underlying-intuition-and-python-implementation-59af3c5d3ca2&lt;/a&gt;&lt;/p&gt;
</content:encoded></item><item><title>Probability Calibration</title><link>https://xuchenhui.cc/posts/2024-04-10-probability-calibration/</link><guid isPermaLink="true">https://xuchenhui.cc/posts/2024-04-10-probability-calibration/</guid><description>介绍分类模型中的概率校准方法，包括校准曲线的绘制以及逻辑回归、贝叶斯、随机森林等模型在校准表现上的差异与原因分析。</description><pubDate>Wed, 10 Apr 2024 00:00:00 GMT</pubDate><content:encoded>&lt;h2&gt;0. 前言&lt;/h2&gt;
&lt;p&gt;在分类的时候, 我们不仅希望预测类别, 还希望输出概率, 但是有些模型是不能直接输出概率的, 或者输出的概率只是一个相对的, 这时就需要校准&lt;/p&gt;
&lt;p&gt;一个良好的、校准过的分类器, 输出的 prob=0.8,  就是可以理解为当前样本有80%的概率是正样本&lt;/p&gt;
&lt;h2&gt;1. 校准曲线&lt;/h2&gt;
&lt;p&gt;将预测值升序排序, 然后划分bin, “ x-axis represents the average predicted probability in each bin” . 而y轴则是对应bin中, 相应样本是正样本的比例, 然后绘制相应曲线&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://s2.loli.net/2024/04/10/APQRTWiVeMFEYgU.png&quot; alt=&quot;校准曲线.png&quot; /&gt;&lt;/p&gt;
&lt;p&gt;对于条形图的解释&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;逻辑回归的结果很不错, 几乎是可以开箱即用, 是因为本身其loss就是交叉熵, 使用的概率.或者从另一个角度就是最大似然估计.&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;贝叶斯看起来更加倾向于将输出close to 0 or 1,  主要可能是因为（&lt;strong&gt;存疑&lt;/strong&gt;）其假设特征是独立的.&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;随机森林有一个明显的特点就是, 输出close 0.1 or 0.9 ,  Niculescu-Mizil and Caruana &lt;a href=&quot;https://scikit-learn.org/stable/modules/calibration.html#id14&quot;&gt;[3]&lt;/a&gt; 认为, 由于随机森林是bagging模型, 如果想让一个样本严格输出为0, 那么就意味着, 所有的base 决策树都要预测这个样本为0, 这通常是不可能的, 因为单个树具有高方差（可能会引入噪声）, 以及最后预测输出的时候是多个树average, 所以结果通常来说不会是 0 or 1&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;什么是高方差？&lt;/p&gt;
&lt;p&gt;通常来说, 高方差指的是, 模型相对复杂, 从而完美的匹配了训练集的数据, 只学到了局部的模型（距离“平均模型”比较远, 方差大）, 即过拟合, 这使得模型对数据很敏感（因为数据可能也是局部的）.&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;为什么单个树的方差很高？&lt;/p&gt;
&lt;p&gt;单个决策树具有高方差的原因主要与其自身的结构和学习方式有关.以下是一些导致决策树高方差的关键因素(Chatgpt语气)：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;不剪枝&lt;/strong&gt;：决策树在构建过程中会持续分裂节点, 直到满足某个停止条件.如果决策树没有适当的剪枝策略, 它会继续生长并尝试完美地拟合训练数据, 包括数据中的噪声和异常值.这种过拟合行为会导致模型对训练数据的微小变化非常敏感, 从而增加了方差.&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;贪婪搜索&lt;/strong&gt;：决策树的构建通常基于贪婪搜索策略, 这意味着在每个节点上, 它会选择局部最优的特征进行分裂, 而不是考虑全局最优解.这种局部最优选择可能导致树过于复杂, 进而增加了模型的方差.&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;数据噪声和异常值&lt;/strong&gt;：由于决策树是基于数据特征进行分裂的, 数据中的噪声和异常值可能会对树的结构产生不成比例的影响.这些数据点可能会导致树在错误的方向上进行分裂, 从而增加了模型的方差.&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;特征选择的随机性&lt;/strong&gt;：在构建决策树时, 通常会从特征集中选择一个特征进行节点分裂.如果没有适当的随机性引入, 即使是很小的数据变化也可能导致选择不同的特征, 从而产生完全不同的树结构, 增加了模型的方差.&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;树的深度&lt;/strong&gt;：决策树的深度也会影响其方差.树越深, 模型就越有可能学习到数据中的噪声和偶然规律, 从而导致高方差.&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;数据表示的选择&lt;/strong&gt;：决策树对数据的表示非常敏感.如果数据的特征没有经过适当的预处理和特征工程, 可能会导致树对数据的某些特定表示过度拟合, 从而增加方差.&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;当这些高方差的树被集成在一起时，由于它们的随机性质，它们彼此之间存在差异。这种差异导致它们在某些数据点上出现错误，但在其他数据点上正确，因此这些错误会相互抵消。通过集成多个具有高方差的模型，随机森林能够平均化这些错误，从而降低整体的方差。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;SVM则是属于那种“差不多就行”、“能分类正确即可”, 因此模型输出大多数在0.5 左右&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;h2&gt;2. Calibrating a classifier&lt;/h2&gt;
&lt;h3&gt;2.1 函数介绍&lt;/h3&gt;
&lt;p&gt;使用 &lt;a href=&quot;https://scikit-learn.org/stable/modules/generated/sklearn.calibration.CalibratedClassifierCV.html#sklearn.calibration.CalibratedClassifierCV&quot;&gt;CalibratedClassifierCV&lt;/a&gt; 来实现校准 , 这个类使用交叉验证来校准模型.首先要注意的是, 校准模型时使用的训练集, 不能和 用来训练未校准模型的训练集一样.同时也是为了样本分布平衡, 需要让校准模型的过程在不同训练子集上重复.其核心思想为：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;When &lt;code&gt;ensemble=True&lt;/code&gt; (default)
&lt;ul&gt;
&lt;li&gt;data is split into k &lt;code&gt;(train_set, test_set)&lt;/code&gt;&lt;/li&gt;
&lt;li&gt;然后  CCV类中的&lt;code&gt;base_estimator&lt;/code&gt;（比如决策树）,  首先独立的复制k份, 分别在相应的 &lt;code&gt;train_set&lt;/code&gt; 上进行训练, 然后在相应的&lt;code&gt;test_set&lt;/code&gt; 的预测结果, 会进一步被用来 fit a calibrator (either a sigmoid or isotonic regressor).  each calibrator maps the output of its corresponding classifier into [0, 1].&lt;/li&gt;
&lt;li&gt;fit 好后的 calibrator 存在&lt;code&gt;calibrated_classifiers_&lt;/code&gt; attribute, where each entry is a calibrated classifier with a &lt;a href=&quot;https://scikit-learn.org/stable/glossary.html#term-predict_proba&quot;&gt;predict_proba&lt;/a&gt; method that outputs calibrated probabilities.&lt;/li&gt;
&lt;li&gt;然后 CCV 这个类本身有一个函数：&lt;a href=&quot;https://scikit-learn.org/stable/glossary.html#term-predict_proba&quot;&gt;predict_proba&lt;/a&gt; ,  调用时结果为：average of the predicted probabilities of the &lt;code&gt;k&lt;/code&gt; estimators in the &lt;code&gt;calibrated_classifiers_&lt;/code&gt; list.&lt;/li&gt;
&lt;li&gt;The output of &lt;a href=&quot;https://scikit-learn.org/stable/glossary.html#term-predict&quot;&gt;predict&lt;/a&gt; is the class that has the highest probability.&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;when  &lt;code&gt;ensemble=False&lt;/code&gt;
&lt;ul&gt;
&lt;li&gt;一眼丁真, 鉴定为最好别用&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;h3&gt;2.2 校准方法&lt;/h3&gt;
&lt;h4&gt;2.2.1 sigmoid 方法&lt;/h4&gt;
&lt;p&gt;$$
p(y_i = 1 | f_i) = \frac{1}{1 + \exp(A f_i + B)}
$$&lt;/p&gt;
&lt;p&gt;A and B are real numbers to be determined when fitting the regressor via maximum likelihood.&lt;/p&gt;
&lt;p&gt;该方法适用于&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;calibration error is symmetrica
&lt;ul&gt;
&lt;li&gt;meaning the classifier output for each binary class is normally distributed with the same variance&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;This can be a problem for highly imbalanced classification problems, where outputs do not have equal variance.&lt;/strong&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;small sample sizes&lt;/li&gt;
&lt;li&gt;un-calibrated model is under-confident and has similar calibration errors for both high and low outputs.&lt;/li&gt;
&lt;/ul&gt;
&lt;h4&gt;2.2.2 isotonic 方法&lt;/h4&gt;
&lt;p&gt;fits a non-parametric isotonic regressor, which outputs a step-wise non-decreasing function, see &lt;a href=&quot;https://scikit-learn.org/stable/modules/classes.html#module-sklearn.isotonic&quot;&gt;sklearn.isotonic&lt;/a&gt;. It minimizes:&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;该方法输出map function 是 严格单调递增的 : &lt;strong&gt;即未校准的模型 认为 p(A) &amp;lt; p(B) , 那么经过校准后的的 p(A) 应该还是小于 p(B)&lt;/strong&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;$$
\sum_{i=1}^{n} (y_i - \hat{f}_i)^2
$$&lt;/p&gt;
&lt;p&gt;subject to&lt;/p&gt;
&lt;p&gt;$$
\hat{f}_i \geq \hat{f}_j
$$&lt;/p&gt;
&lt;p&gt;whenever&lt;/p&gt;
&lt;p&gt;$$
f_i \geq f_j
$$&lt;/p&gt;
&lt;p&gt;$y_i$ is the true label of sample and $\hat{f}_i$ is the output of the calibrated classifier for sample
(i.e., the calibrated probability).&lt;/p&gt;
&lt;p&gt;Overall, ‘isotonic’ will perform as well as or better than ‘sigmoid’ when there is enough data (greater than ~ 1000 samples) to avoid overfitting&lt;/p&gt;
&lt;p&gt;It is not advised to use isotonic calibration with too few calibration samples &lt;code&gt;(&amp;lt;&amp;lt;1000)&lt;/code&gt; since it tends to overfit.&lt;/p&gt;
&lt;h2&gt;References&lt;/h2&gt;
&lt;p&gt;[1] &lt;a href=&quot;https://scikit-learn.org/stable/modules/calibration.html&quot;&gt;Sklearn-calibration&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;[2] &lt;a href=&quot;https://zhuanlan.zhihu.com/p/502959226&quot;&gt;模型校准(Calibration of Models)技术&lt;/a&gt;&lt;/p&gt;
</content:encoded></item><item><title>Perplexity</title><link>https://xuchenhui.cc/posts/2024-04-10-perplexity/</link><guid isPermaLink="true">https://xuchenhui.cc/posts/2024-04-10-perplexity/</guid><description>深入浅出地解释 NLP 中困惑度（Perplexity）的概念，理解它如何衡量语言模型对样本的预测能力，以及其与概率的关系。</description><pubDate>Wed, 10 Apr 2024 00:00:00 GMT</pubDate><content:encoded>&lt;h2&gt;0. 前言&lt;/h2&gt;
&lt;p&gt;在 NLP 中,  经常可以看到使用&quot;困惑度&quot;来描述一个 LLM 的能力. 那么什么是&quot;困惑度&quot;?&lt;/p&gt;
&lt;p&gt;简单理解,  困惑度就是&quot;模型对样本预测结果的信心&quot;. 具体的,  模型对这个样本结果的预测概率越高,  表明信心越高,  对应困惑度越低.&lt;/p&gt;
&lt;p&gt;:::note
本文介绍的Perplexity 特指 &quot;Perplexity of a probability model&quot;.
:::&lt;/p&gt;
&lt;h2&gt;1. 举个栗子&lt;/h2&gt;
&lt;p&gt;假设我们的 vocabulary 就只有6个单词,  &lt;code&gt;“a”,  “the”,  “red”,  “fox”,  “dog”,  and “.” &lt;/code&gt;. 模型需要从这里边预测输出句子 &lt;code&gt;W : &quot;a red fox .&quot;&lt;/code&gt;&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
P(W) &amp;amp; = P(w_1,  w_2,  \ldots,  w_n) \
&amp;amp; = P(w_n|w_1,  w_2,  \ldots,  w_{n-1}) \times P(w_1,  w_2,  \ldots,  w_{n-1})
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;对于这句话就是:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
P(&apos; a\ red\ fox\ . &apos;) =  P(&apos; a &apos;) \times P(&apos; red &apos; | &apos; a &apos;) \times P(&apos; fox &apos; | &apos; a\ red &apos;) \times P(&apos; . &apos;|&apos; a\ red\ fox &apos;)
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;假设模型,  预测第一个字的概率分布如下 :&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/10/IfNJ1tRBwbTH8lP.png&quot; alt=&quot;第1个字.png&quot; width=&quot;600&quot; height=&quot;400&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;则有
$P( &apos; a &apos; ) = 0.4$
,  进一步的
$P( w_2 | &apos; a &apos; )$
分布如下&lt;/p&gt;
&lt;p&gt;&amp;lt;img src=&quot;https://s2.loli.net/2024/04/10/vgHxO3nFumXrQAc.png&quot; alt=&quot;第2个字.png&quot; width=&quot;600&quot; height=&quot;400&quot; /&amp;gt;&lt;/p&gt;
&lt;p&gt;于是
$P( &apos; red &apos;  |  &apos; a &apos; ) = 0.27$
, 同理,  根据以下分布&lt;/p&gt;
&lt;p&gt;&amp;lt;div style=&quot;display: flex;&quot;&amp;gt;
&amp;lt;img src=&quot;https://s2.loli.net/2024/04/10/UwFikWIL9tNPJRA.png&quot; alt=&quot;Image 1&quot; style=&quot;width: 100%;&quot;&amp;gt;
&amp;lt;img src=&quot;https://s2.loli.net/2024/04/10/k8DHmfxSJIuOTpY.png&quot; alt=&quot;Image 2&quot; style=&quot;width: 100%;&quot;&amp;gt;&lt;/p&gt;
&lt;p&gt;我们有如下结果:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
P(&apos; a\ red\ fox\ . &apos;) &amp;amp;=  P(&apos; a &apos;) \times P(&apos; red &apos; | &apos; a &apos;) \times P(&apos; fox &apos; | &apos; a\ red &apos;) \times P(&apos; . &apos;|&apos; a\ red\ fox &apos;)  \
&amp;amp;= 0.4 * 0.27 * 0.55 * 0.79 \
&amp;amp;= 0.0469
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;0.0469则表示当前这个模型对于预测 &quot;a red fox.&quot; 的信心如何, 不过有一个问题 : 因为这个信心是概率的连乘, 于是导致理论上, 句子越长, 信心越小. 因此需要进行一个 &quot;Normalize&quot;的操作 . 我们可以使用 &lt;a href=&quot;https://en.wikipedia.org/wiki/Geometric_mean&quot;&gt;几何平均数&lt;/a&gt; 来实现上述功能,  从而得到一个新的量化标准:&lt;/p&gt;
&lt;p&gt;$$
P_{norm}(W) = P(W)^{1/n}
$$&lt;/p&gt;
&lt;p&gt;这里的n表示句子的单词(token)数量.于是&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
P_{norm}(&apos;a\ red\ fox\ .&apos;) &amp;amp;= P(&apos;a\ red\ fox\ .&apos;)^{1/n} \
&amp;amp;= 0.0469 ^ {1/4} \
&amp;amp;= 0.465
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;这样我们就可以使用 $P_{norm}$ 来度量模型对不同长度句子的预测输出&quot;信心&quot;.&lt;/p&gt;
&lt;h2&gt;2. 如何计算&lt;/h2&gt;
&lt;p&gt;前边我们提到,  模型与输出的句子,  信心越足,  困惑度越小. 可以看到,  困惑度的计算公式如下:&lt;/p&gt;
&lt;p&gt;$$
\begin{align*}
PP(W) &amp;amp;= \frac {1} {P_{norm}(W)} \
&amp;amp;= \frac {1} {P(W)^{1/n}} \
&amp;amp;= (\frac {1} {P(W)}) ^{1/n} \
&amp;amp;= P(W) ^{-1/n}
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;对于之前的这个模型,  其 $PP(W) = (1/0.0469)^{1/n} ≈  2.15 $&lt;/p&gt;
&lt;p&gt;而假设有另外一个模型, 给定任意条件下, 对下一个单词的预测概率均相等为 1/6 . 那么这个模型的的困惑度为:&lt;/p&gt;
&lt;p&gt;$$
PP(W) = (\frac {1} {(1/6)^4}) ^{1/4} = 6
$$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;明显比之前的模型困惑度更高,  表明这个模型 更差 ,  因为这个模型就是随机输出.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h2&gt;3. 和交叉熵的关系&lt;/h2&gt;
&lt;p&gt;我们知道,  &lt;a href=&quot;https://zh.wikipedia.org/zh-hans/%E7%86%B5_(%E4%BF%A1%E6%81%AF%E8%AE%BA)&quot;&gt;香农熵&lt;/a&gt;  计算方式为 :&lt;/p&gt;
&lt;p&gt;$$
H(p) = -\sum_{i=1}^{n} p \log_{2} p
$$&lt;/p&gt;
&lt;p&gt;交叉熵的计算方式:&lt;/p&gt;
&lt;p&gt;$$
H(p, q) = -\sum_{i=1}^{n} p \log_{2} q
$$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;事实上:&lt;/p&gt;
&lt;p&gt;$$
KL(p, q) = -H(p) + H(p, q)
$$&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;对$PP(W)$进行拆解, 得以下式子:&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;(1) 注意之前是 &lt;code&gt;n&lt;/code&gt; , 强调一个句子. 这里是 &lt;code&gt;N&lt;/code&gt; ,  强调模型对整个vocabulary的输出分布&lt;/p&gt;
&lt;p&gt;(2) 最后的 q分布 就是下一个单词的分布, 是一个 One-hot 向量&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;$$
\begin{align*}
P(W) ^{-1/ N} &amp;amp;=  \prod_{i=1}^{N}   P(w)^{-  1/N} \
&amp;amp;=   P(w_1)^{-  1/N}  *  P(w_2)^{-  1/N}  * ... *  P(w_N)^{-  1/N}  \
&amp;amp;=   2 ^ { - \frac 1 N\ \sum_{i=1}^{N} \ log_2\ p } (忽略常数 2^{-1/N}) \
&amp;amp;= 2 ^ {\ H(P , \   q)}
\end{align*}
$$&lt;/p&gt;
&lt;p&gt;从这个角度来看,  困惑度越小,  交叉熵越小,  预测越准确.&lt;/p&gt;
&lt;p&gt;最后,  实际计算过程中,  可能使用以e为底的对数,  也有计算其log后作为困惑度,  此外还有一些其他计算方式,  但是本质类似, 就是想表达 &quot;预测输出的概率越大, 困惑度就越小&quot;&lt;/p&gt;
&lt;h2&gt;Reference&lt;/h2&gt;
&lt;p&gt;[1] &lt;a href=&quot;https://medium.com/nlplanet/two-minutes-nlp-perplexity-explained-with-simple-probabilities-6cdc46884584&quot;&gt;Two minutes NLP — Perplexity explained with simple probabilities&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;[2] &lt;a href=&quot;https://en.wikipedia.org/wiki/Perplexity&quot;&gt;Wiki-Perplexity&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;[3] &lt;a href=&quot;https://webcache.googleusercontent.com/search?q=cache:https://towardsdatascience.com/perplexity-intuition-and-derivation-105dd481c8f3&amp;amp;strip=0&amp;amp;vwsrc=1&amp;amp;referer=medium-parser&quot;&gt;Perplexity Intuition (and its derivation)&lt;/a&gt;&lt;/p&gt;
</content:encoded></item></channel></rss>