最近一两年,大模型(LLM)如何突破训练长度限制(Length Extrapolation),支持更长文本的讨论非常热烈。 在众多方法中,基于RoPE (Rotary Positional Embedding) 的外推技术尤为流行,代表方法包括 Position Interpolation (PI)、NTK-Aware Interpolation、Dynamic NTK、NTK-by-parts 以及 YaRN 等。 这些方法各有千秋,但它们的数学原理和工程实现细节往往让人摸不着头脑。 本文将从数学推导和工程实现两个角度,系统地讲解这些 RoPE 外推方法,帮助读者彻底理解它们的本质区别和适用场景。
为了保证讲解的清晰和一致性,我们首先定义一套统一的数学符号,后续所有方法都基于这套符号进行推导。
基础符号定义 (Notation)
假设我们有一个 Query 或 Key 向量 ,维度为 。 RoPE (Rotary Positional Embedding) 将向量两两分组,在复平面上旋转。 对于第 组分量(其中 ),位置为 ,基础频率为 。
- 位置 (Position): (当前 token 的位置索引)。
- 维度索引: 。
- 基数 (Base): (LLaMA 等模型的默认值)。
- 频率 (Frequency): 。
- 旋转角度: 。
- 训练最大长度: (例如 2048 或 4096)。
- 推理/目标长度: 。
- 扩展倍率 (Scale): 。
原始 RoPE (Baseline)
在讲外推之前,必须先看原始 RoPE 做了什么。
数学实现: RoPE 的核心是对向量 应用一个旋转矩阵。对于第 对分量 ,旋转角度为 :
注意,这里的 公式为:
- 高频分量( 较小):旋转速度极快,关注局部位置信息。
- 低频分量( 较大):旋转速度极慢,关注长距离全局信息。
外推问题: 当 时, 会出现训练时从未见过的数值分布,导致注意力机制崩塌(PPL 爆炸)。
PPL 是什么?PPL (Perplexity):困惑度,是衡量语言模型预测能力的指标。PPL 越低,模型预测下一个词的能力越强。数学上,PPL 是交叉熵损失的指数形式:
其中 是词的数量, 是模型预测第 个词的概率。PPL可以理解为模型在预测时的“平均分支数”,PPL 越高,表示模型越不确定。
Position Interpolation (PI, 线性插值)
核心思想: 既然模型没见过大于 的 ,那我们把新的长位置 硬塞回 的范围内。也就是“把尺子刻度变密”。
数学实现: 将位置 替换为 。 新的频率项变为:
工程实现:
在代码中,不需要改变模型权重,只需要在预计算 cos 和 sin 表(Frequency Buffer)时修改即可。
# 伪代码scale = L_new / L_train # 例如 4096->8192, scale=2t = torch.arange(L_new)t = t / scale # 将 0~8191 压缩回 0~4095.5freqs = torch.outer(t, inv_freq)优缺点:
- 优点: 极其简单,微调(Fine-tune)几百步就能让模型适应长文本。
- 缺点: “高频分辨率丢失”。
- 对于高频部分( 很小),相邻 token 的旋转角度差值变小了(被除以了 ),导致模型无法精确区分相邻的词(例如 “cats” 和 “cat” 的区别)。这就是所谓的“分辨率(Resolution)分辨率”。
NTK-Aware Interpolation
核心思想: PI 简单粗暴地对所有频率都除以 ,这伤害了高频部分。 根据 Neural Tangent Kernel (NTK) 理论,高频特征学习快,低频特征学习慢。 我们不应该改变位置 (保持分辨率),而应该改变“基数”。 通过改变 ,我们可以让高频部分保持不变(不插值),只对低频部分进行插值。
数学实现: 我们要寻找一个新的基数 ,使得最低频()的效果等同于 PI 插值,而高频变化最小。 PI 的低频项是:。 NTK-Aware 保持 不变,改变 为 :。
令两者在最低频处相等( 最大时):
注:推导细节为了简化忽略了,实际公式考虑所有维度,推导出的修正基数如下:
最终公式:
通常 ,所以很多实现里直接近似 。
工程实现: 不用改 (位置索引),只改计算频率时的 base。
# 伪代码alpha = L_new / L_train# 计算新的 basenew_base = 10000 * (alpha ** (dim / (dim - 2)))# 重新计算 inv_freqinv_freq = 1.0 / (new_base ** (torch.arange(0, dim, 2) / dim))# 生成 cos/sin 表优缺点:
- 优点: 解决了 PI 在高频部分分辨率丢失的问题。无需微调也能获得比 PI 更好的 PPL,微调后效果更佳。
Dynamic NTK Interpolation
核心思想: 前面两种方法 是固定的。如果你设 ,但在推理时只输入了 1k 个词,模型性能反而会下降,因为它被不必要地“压缩”了。 Dynamic NTK 根据当前生成的序列长度 动态计算 。
数学实现: 在每一步推理(或者每隔一段):
- 获取当前序列长度 。
- 如果 ,使用原始 RoPE ()。
- 如果 ,计算 。
- 使用 NTK-Aware 重新计算 和频率表。
工程实现: 这对 KV Cache 系统提出了挑战,因为 RoPE 以前是静态的,现在变成了动态的。 如果 RoPE 是在 Attention 之前做的(通常是),意味着以前存进 KV Cache 的 Key 向量是基于旧的 旋转的。当 变化时,之前的 KV Cache 失效了?
- 实际上: 大多数实现为了性能,在推理时如果采用 Dynamic NTK,通常不重算 KV Cache,这会导致一定的误差,但在 RoPE 的语境下,模型往往能容忍这种渐变。或者采用“不旋转存 KV,取出时旋转”的策略(但计算量大)。
优缺点:
- 优点: 在短文本上保持原始性能,随着长度增加平滑过渡到外推模式。
NTK-by-parts Interpolation (分段 NTK)
核心思想: NTK-Aware 即使改变了 base,对高频部分还是有微小的影响。 “By-parts” 提出一个更激进的假设:波长(Wavelength,)。
- 如果波长 很短(远小于训练长度):这是高频,包含局部信息。绝对不要插值(保持原始 RoPE)。
- 如果波长 很长(大于训练长度):这是低频,包含长距离信息。必须插值(类似 PI)。
- 中间部分:过渡。
数学实现: 定义波长 。 定义两个阈值 (例如 )。 对于每个维度 :
- 如果 :不插值()。
- 如果 :全量插值()。
- 中间:线性混合。
由于直接对 混合比较麻烦,通常会将这个逻辑转换为改变频率 :
其中 是根据波长计算出的 ramp 函数(0 到 1)。
工程实现: 需要为每个维度 计算一个缩放因子。
# CodeLLaMA 采用的就是类似思路# 计算每个维度的波长wavelen = 2 * math.pi / inv_freq# 这种方法不改 base,而是直接改 freq# 对 wavelen 小的不变,大的除以 scaleYaRN (Yet another RoPE extensioN)
这是目前的集大成者(CodeLlama 等模型深受其影响)。
核心思想: YaRN 结合了 NTK-by-parts 的分段思想,并解决了一个隐藏问题:熵(Entropy)的变化。 当我们将所有 token 的 RoPE 距离拉开(或压缩)时,Attention 矩阵中的点积数值分布会发生变化(通常变得更“平”或更“尖”),这会导致 Softmax 后的分布熵变大,模型困惑度增加。 YaRN 引入了 temperature scaling 来修正注意力分布。
数学实现(三大步骤):
- NTK-by-parts: 使用分段策略(高频不转,低频转,中间插值)。
- 不像原始 NTK 改 base,YaRN 直接通过修正 frequency 来实现。
- Length Scaling (注意力修正): 在计算 Attention Score 时,乘上一个额外的因子 。 或者简单地理解为,在 RoPE 之后,把 向量本身乘以 (即 )。 因子 通常取值: 如果不加这个修正,长文本外推时,注意力分布会因为上下文变长而变得过于“平坦”,模型不知道该关注谁。
工程实现: YaRN 的实现通常包含两个部分:
Get_Rotary_Embedding: 生成修改后的 cos/sin 表(基于分段逻辑)。Model Forward: 在应用 FlashAttention 之前或之后,对 Logits 进行 scaling,或者直接对 RoPE 后的 output 进行 scaling。
# YaRN 简易逻辑# 1. 计算 ramp function (ramp) 基于波长# 2. inv_freq_new = inv_freq * (1 - ramp) + (inv_freq / scale) * ramp# 3. 还有一个 scale factorsqrt_t = 0.1 * math.log(scale) + 1# 在应用 RoPE 后,将 q, k 乘以 sqrt(sqrt_t) (或者在 attention socre 乘 1/t)q = apply_rope(q, inv_freq_new) * t_correctionk = apply_rope(k, inv_freq_new) * t_correction总结与对比 (Cheat Sheet)
| 方法 | 核心手段 | (位置) | (频率/基数) | (向量值) | 评价 |
|---|---|---|---|---|---|
| RoPE | 旋转 | 不变 | 默认 | 不变 | 基础,无法直接外推 |
| PI | 线性压缩位置 | 不变 | 不变 | 简单,但高频分辨率差,微调必须 | |
| NTK-Aware | 拉伸基数 | 不变 | 不变 | 数学优雅,高频无损,微调效果好 | |
| Dynamic NTK | 动态拉伸 | 不变 | 动态变化 | 不变 | 推理时自适应长度,无需微调即可用 |
| NTK-by-parts | 分频段处理 | 混合 | 高频不变,低频插值 | 不变 | CodeLlama 基础,保留局部能力 |
| YaRN | 分段 + 修改熵 | 混合 | 同上 | 乘以 | 目前最强,最稳定的长文本方案 |
部分信息可能已经过时