LOADING
2469 字
12 分钟
Minimind 手动实操 1: RoPE 长度外推全解析

最近一两年,大模型(LLM)如何突破训练长度限制(Length Extrapolation),支持更长文本的讨论非常热烈。 在众多方法中,基于RoPE (Rotary Positional Embedding) 的外推技术尤为流行,代表方法包括 Position Interpolation (PI)、NTK-Aware Interpolation、Dynamic NTK、NTK-by-parts 以及 YaRN 等。 这些方法各有千秋,但它们的数学原理和工程实现细节往往让人摸不着头脑。 本文将从数学推导和工程实现两个角度,系统地讲解这些 RoPE 外推方法,帮助读者彻底理解它们的本质区别和适用场景。

为了保证讲解的清晰和一致性,我们首先定义一套统一的数学符号,后续所有方法都基于这套符号进行推导。

基础符号定义 (Notation)#

假设我们有一个 Query 或 Key 向量 xx,维度为 dd。 RoPE (Rotary Positional Embedding) 将向量两两分组,在复平面上旋转。 对于第 ii 组分量(其中 i[0,d/2)i \in [0, d/2)),位置为 mm,基础频率为 θi\theta_i

  • 位置 (Position): mm (当前 token 的位置索引)。
  • 维度索引: i[0,1,...,d/21]i \in [0, 1, ..., d/2 - 1]
  • 基数 (Base): b=10000b = 10000 (LLaMA 等模型的默认值)。
  • 频率 (Frequency): θi=b2i/d\theta_i = b^{-2i/d}
  • 旋转角度: mθim \cdot \theta_i
  • 训练最大长度: LtrainL_{train} (例如 2048 或 4096)。
  • 推理/目标长度: LnewL_{new}
  • 扩展倍率 (Scale): s=LnewLtrains = \frac{L_{new}}{L_{train}}

原始 RoPE (Baseline)#

在讲外推之前,必须先看原始 RoPE 做了什么。

数学实现: RoPE 的核心是对向量 xx 应用一个旋转矩阵。对于第 ii 对分量 (x2i,x2i+1)(x_{2i}, x_{2i+1}),旋转角度为 mθim \theta_i

f(x,m)i=(cos(mθi)sin(mθi)sin(mθi)cos(mθi))(x2ix2i+1)f(x, m)_i = \begin{pmatrix} \cos(m\theta_i) & -\sin(m\theta_i) \\ \sin(m\theta_i) & \cos(m\theta_i) \end{pmatrix} \begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix}

注意,这里的 θi\theta_i 公式为:

θi=100002i/d\theta_i = 10000^{-2i/d}
  • 高频分量(ii 较小):旋转速度极快,关注局部位置信息。
  • 低频分量(ii 较大):旋转速度极慢,关注长距离全局信息。

外推问题:m>Ltrainm > L_{train} 时,cos(mθi)\cos(m\theta_i) 会出现训练时从未见过的数值分布,导致注意力机制崩塌(PPL 爆炸)。

PPL 是什么?

PPL (Perplexity):困惑度,是衡量语言模型预测能力的指标。PPL 越低,模型预测下一个词的能力越强。数学上,PPL 是交叉熵损失的指数形式:

PPL=exp(1Ni=1NlogP(wiw<i))\text{PPL} = \exp\left(-\frac{1}{N} \sum_{i=1}^{N} \log P(w_i | w_{<i})\right)

其中 NN 是词的数量,P(wiw<i)P(w_i | w_{<i}) 是模型预测第 ii 个词的概率。PPL可以理解为模型在预测时的“平均分支数”,PPL 越高,表示模型越不确定。


Position Interpolation (PI, 线性插值)#

核心思想: 既然模型没见过大于 LtrainL_{train}mm,那我们把新的长位置 LnewL_{new} 硬塞回 [0,Ltrain][0, L_{train}] 的范围内。也就是“把尺子刻度变密”。

数学实现: 将位置 mm 替换为 m=msm' = \frac{m}{s}。 新的频率项变为:

Angle=msθi\text{Angle} = \frac{m}{s} \theta_i

工程实现: 在代码中,不需要改变模型权重,只需要在预计算 cossin 表(Frequency Buffer)时修改即可。

# 伪代码
scale = L_new / L_train # 例如 4096->8192, scale=2
t = torch.arange(L_new)
t = t / scale # 将 0~8191 压缩回 0~4095.5
freqs = torch.outer(t, inv_freq)

优缺点:

  • 优点: 极其简单,微调(Fine-tune)几百步就能让模型适应长文本。
  • 缺点: “高频分辨率丢失”
    • 对于高频部分(ii 很小),相邻 token 的旋转角度差值变小了(被除以了 ss),导致模型无法精确区分相邻的词(例如 “cats” 和 “cat” 的区别)。这就是所谓的“分辨率(Resolution)分辨率”。

NTK-Aware Interpolation#

核心思想: PI 简单粗暴地对所有频率都除以 ss,这伤害了高频部分。 根据 Neural Tangent Kernel (NTK) 理论,高频特征学习快,低频特征学习慢。 我们不应该改变位置 mm(保持分辨率),而应该改变“基数”bb 通过改变 bb,我们可以让高频部分保持不变(不插值),只对低频部分进行插值。

数学实现: 我们要寻找一个新的基数 bb',使得最低频(i=d/21i=d/2-1)的效果等同于 PI 插值,而高频变化最小。 PI 的低频项是:msb2i/d\frac{m}{s} \cdot b^{-2i/d}。 NTK-Aware 保持 mm 不变,改变 bbbb'm(b)2i/dm \cdot (b')^{-2i/d}

令两者在最低频处相等(ii 最大时):

1sb1=(b)1    b=bs\frac{1}{s} b^{-1} = (b')^{-1} \implies b' = b \cdot s

注:推导细节为了简化忽略了2i/d2i/d,实际公式考虑所有维度,推导出的修正基数如下:

最终公式:

b=bsdd2b' = b \cdot s^{\frac{d}{d-2}}

通常 dd21\frac{d}{d-2} \approx 1,所以很多实现里直接近似 bbsb' \approx b \cdot s

工程实现: 不用改 mm(位置索引),只改计算频率时的 base。

# 伪代码
alpha = L_new / L_train
# 计算新的 base
new_base = 10000 * (alpha ** (dim / (dim - 2)))
# 重新计算 inv_freq
inv_freq = 1.0 / (new_base ** (torch.arange(0, dim, 2) / dim))
# 生成 cos/sin 表

优缺点:

  • 优点: 解决了 PI 在高频部分分辨率丢失的问题。无需微调也能获得比 PI 更好的 PPL,微调后效果更佳。

Dynamic NTK Interpolation#

核心思想: 前面两种方法 ss 是固定的。如果你设 Lnew=16kL_{new}=16k,但在推理时只输入了 1k 个词,模型性能反而会下降,因为它被不必要地“压缩”了。 Dynamic NTK 根据当前生成的序列长度 LcurrL_{curr} 动态计算 ss

数学实现: 在每一步推理(或者每隔一段):

  1. 获取当前序列长度 LcurrL_{curr}
  2. 如果 LcurrLtrainL_{curr} \le L_{train},使用原始 RoPE (scale=1scale=1)。
  3. 如果 Lcurr>LtrainL_{curr} > L_{train},计算 s=LcurrLtrains = \frac{L_{curr}}{L_{train}}
  4. 使用 NTK-Aware 重新计算 bb' 和频率表。

工程实现: 这对 KV Cache 系统提出了挑战,因为 RoPE 以前是静态的,现在变成了动态的。 如果 RoPE 是在 Attention 之前做的(通常是),意味着以前存进 KV Cache 的 Key 向量是基于旧的 ss 旋转的。当 ss 变化时,之前的 KV Cache 失效了

  • 实际上: 大多数实现为了性能,在推理时如果采用 Dynamic NTK,通常不重算 KV Cache,这会导致一定的误差,但在 RoPE 的语境下,模型往往能容忍这种渐变。或者采用“不旋转存 KV,取出时旋转”的策略(但计算量大)。

优缺点:

  • 优点: 在短文本上保持原始性能,随着长度增加平滑过渡到外推模式。

NTK-by-parts Interpolation (分段 NTK)#

核心思想: NTK-Aware 即使改变了 base,对高频部分还是有微小的影响。 “By-parts” 提出一个更激进的假设:波长(Wavelength,λ=2π/θi\lambda = 2\pi / \theta_i)。

  1. 如果波长 λ\lambda 很短(远小于训练长度):这是高频,包含局部信息。绝对不要插值(保持原始 RoPE)。
  2. 如果波长 λ\lambda 很长(大于训练长度):这是低频,包含长距离信息。必须插值(类似 PI)。
  3. 中间部分:过渡。

数学实现: 定义波长 λi=2πb2i/d\lambda_i = 2\pi \cdot b^{2i/d}。 定义两个阈值 α,β\alpha, \beta(例如 α=1,β=32\alpha=1, \beta=32)。 对于每个维度 ii

  • 如果 λi/Ltrain<α\lambda_i / L_{train} < \alpha:不插值(mmm \to m)。
  • 如果 λi/Ltrain>β\lambda_i / L_{train} > \beta:全量插值(mm/sm \to m/s)。
  • 中间:线性混合。

由于直接对 mm 混合比较麻烦,通常会将这个逻辑转换为改变频率 θi\theta_i

θi=(1γi)θi+γiθis\theta_i' = (1 - \gamma_i) \theta_i + \gamma_i \frac{\theta_i}{s}

其中 γi\gamma_i 是根据波长计算出的 ramp 函数(0 到 1)。

工程实现: 需要为每个维度 ii 计算一个缩放因子。

# CodeLLaMA 采用的就是类似思路
# 计算每个维度的波长
wavelen = 2 * math.pi / inv_freq
# 这种方法不改 base,而是直接改 freq
# 对 wavelen 小的不变,大的除以 scale

YaRN (Yet another RoPE extensioN)#

这是目前的集大成者(CodeLlama 等模型深受其影响)。

核心思想: YaRN 结合了 NTK-by-parts 的分段思想,并解决了一个隐藏问题:熵(Entropy)的变化。 当我们将所有 token 的 RoPE 距离拉开(或压缩)时,Attention 矩阵中的点积数值分布会发生变化(通常变得更“平”或更“尖”),这会导致 Softmax 后的分布熵变大,模型困惑度增加。 YaRN 引入了 temperature scaling 来修正注意力分布。

数学实现(三大步骤):

  1. NTK-by-parts: 使用分段策略(高频不转,低频转,中间插值)。
    • 不像原始 NTK 改 base,YaRN 直接通过修正 frequency 来实现。
  2. Length Scaling (注意力修正): 在计算 Attention Score qkTd\frac{q \cdot k^T}{\sqrt{d}} 时,乘上一个额外的因子 t\sqrt{t}。 或者简单地理解为,在 RoPE 之后,把 q,kq, k 向量本身乘以 t\sqrt{\sqrt{t}}(即 t0.25t^{0.25})。 因子 tt 通常取值: t0.1ln(s)+1\sqrt{t} \approx 0.1 \ln(s) + 1 如果不加这个修正,长文本外推时,注意力分布会因为上下文变长而变得过于“平坦”,模型不知道该关注谁。

工程实现: YaRN 的实现通常包含两个部分:

  1. Get_Rotary_Embedding: 生成修改后的 cos/sin 表(基于分段逻辑)。
  2. Model Forward: 在应用 FlashAttention 之前或之后,对 Logits 进行 scaling,或者直接对 RoPE 后的 output 进行 scaling。
# YaRN 简易逻辑
# 1. 计算 ramp function (ramp) 基于波长
# 2. inv_freq_new = inv_freq * (1 - ramp) + (inv_freq / scale) * ramp
# 3. 还有一个 scale factor
sqrt_t = 0.1 * math.log(scale) + 1
# 在应用 RoPE 后,将 q, k 乘以 sqrt(sqrt_t) (或者在 attention socre 乘 1/t)
q = apply_rope(q, inv_freq_new) * t_correction
k = apply_rope(k, inv_freq_new) * t_correction

总结与对比 (Cheat Sheet)#

方法核心手段mm (位置)θ\theta (频率/基数)xx (向量值)评价
RoPE旋转不变默认 b=10000b=10000不变基础,无法直接外推
PI线性压缩位置mm/sm \to m/s不变不变简单,但高频分辨率差,微调必须
NTK-Aware拉伸基数不变bbsdd2b \to b \cdot s^\frac{d}{d-2}不变数学优雅,高频无损,微调效果好
Dynamic NTK动态拉伸不变bb 动态变化不变推理时自适应长度,无需微调即可用
NTK-by-parts分频段处理混合高频不变,低频插值不变CodeLlama 基础,保留局部能力
YaRN分段 + 修改熵混合同上乘以 t\sqrt{t}目前最强,最稳定的长文本方案
Minimind 手动实操 1: RoPE 长度外推全解析
/posts/minimind动手实操1_rope/
作者
Olynx
发布于
2025-12-23
许可协议
CC BY-NC-SA 4.0

部分信息可能已经过时

Profile Image of the Author
Olynx
在世界的正文与脚注之间,执笔第三种可能。

统计加载中...
译者手记
有些章节需要重读,有些代码需要重写。理解的本质,或许就是不断为世界添加新的脚注。