©PaperWeekly 原创 · 作者 | 苏剑林
我们知道,在 RoPE 中频率的计算公式为 ,底数 默认值为 10000。目前 Long Context 的主流做法之一是,先在 上用短文本预训练,然后调大 并在长文本微调,其出发点是《Transformer升级之路:RoPE是一种β进制编码》里介绍的 NTK-RoPE,它本身有较好长度外推性,换用更大的 再微调相比不加改动的微调,起始损失更小,收敛也更快。该过程给人的感觉是:调大 完全是因为“先短后长”的训练策略,如果一直都用长文本训练似乎就没必要调大 了?近期的论文《Base of RoPE Bounds Context Length》[1] 试图回答这个问题,它基于一个期望性质研究了 的下界,由此指出更大的训练长度本身就应该选择更大的底数,与训练策略无关。整个分析思路颇有启发性,接下来我们一起来品鉴一番。
RoPE 这里就不再详细介绍了,它本质上是一个分块对角矩阵给 注入绝对位置信息,并自动实现了相对位置的效果。其中 ,这里的 的取值就是本文要探讨的问题。除了给模型注入位置信息外,我们期望 RoPE 能具备两个理想性质,以达到更好的效果:1)远程衰减,即位置相近的 Token 平均来说获得更多的注意力;2)语义聚合,即语义相似的 Token 平均来说获得更多的注意力。其中第一点我们早在《Transformer升级之路:博采众长的旋转式位置编码》有过相关讨论,RoPE 确实有一定的远程衰减性质。所以接下来我们来分析第二点。
不等关系
所谓语义聚合,指的是当 相近时,不管它们的相对距离 多大,其注意力 平均来说都应该更大(至少要比随机的两个 Token 更大)。为了得到一个量化的结论,我们进一步简化问题,假设 的每个分量都是独立同分布的,每个分量的均值为 ,方差为 。现在我们考虑两种不同的 :一种是在 的基础上,加上一个零均值的扰动 ,我们记 ,代表跟 语义相近的 Token;另一种则是假设 跟 独立同分布,这代表两个随机的 Token。根据第二点理想性质,我们希望有注意我们刚才反复强调了“平均来说”,意味着我们只是期望一个平均的趋势,而不是每一点都能严格成立,所以我们在上式加了取数学期望 。现在根据假设和 RoPE 的定义,我们可以把上式具体地算出来:如果训练长度最大为 L,那么 ,因此第二点理想性质可以用如下不等式近似描述:其中 L 是最大长度,是训练前就要选定的超参,而 是模型的 head_size,按照 LLAMA 的一般设置是 ,这也就意味着,上式的唯一可调参数就是 中的 。在《Transformer升级之路:Sinusoidal位置编码追根溯源》中我们就简单探究过这个函数,它整体趋势是衰减的, 越大则衰减速度越慢,对应的连续非负区间就越大,所以存在一个最小的 b 使得上述不等式恒成立,即
由于 涉及到多个三角函数的求和,并且 关于 还是非线性的,很难想象上述问题会有解析解,因此只能诉诸数值求解了。然而, 越到后面震荡越频繁且不规律,因此即便数值求解也不是那么简单的事情。笔者一开始以为,如果 使得 恒成立,那么 都恒成立 ,所以用二分法就可以了。但事实上这个假设并不成立,所以二分法宣告破产。继续想了一段时间,依然没什么优化思路,期间向原论文作者请教过,他们采用的是逆函数法,即给定 求使得 恒成立的最大 L 是比较简单的,于是我们可以得到很多 对,理论上只要枚举的 足够多,那么对于任意 都可以找出最小的 。然而这里有个精度问题,原论文最大的 计算到了 , 至少要枚举到 ,如果枚举间隔小,那么计算成本非常大,如果枚举间隔大,那么可能漏掉很多解。最后,笔者决定还是用 “Jax + GPU” 进行暴力搜索,以求得到更高精度的结果,大致流程是:2.1 将 等分为 份,遍历等分点,判断 是否恒成立; 1from functools import partial
2import numpy as np
3import jax.numpy as jnp
4import jax
5
6@partial(jax.jit, static_argnums=(2,))
7def f(m, b, d=128):
8 i = jnp.arange(d / 2)
9 return jnp.cos(m[:, None] * b ** (-2 * i[None] / d)).sum(axis=1)
10
11@np.vectorize
12def fmin(L, b):
13 return f(np.arange(L), b).min()
14
15def bmin(L):
16 B = 1000 * L
17 for k in range(1, 6):
18 bs = np.linspace(0, 1, 10**k + 1)[1:] * B
19 ys = fmin(L, bs)
20 for b, y in zip(bs, ys):
21 if y >= 0:
22 B = b
23 break
24 return B
25
26bmin(1024 * 128)
除了数值求解外,我们也可以通过渐近分析来得到一个解析的估计结果,这个估计比数值结果要小,本质上是 的解,但同样能够得出“ 应该随着 增大而增大”的结论。这是被前人研究过的三角积分(参考 Trigonometric integral [2]),利用这个记号,我们可以写出它的第一个零点是 ,对于 ,可以看出 ,所以其实 相对来说是小项,对于渐近估计来说可以忽略,那么问题近似地变成了 对于 恒成立,我们只需要让相应的 都落在 区间内就可以实现,这意味着 ,即或者简单点 。不出意料这个结果比精确的数值结果要小,因为它对应于 ,无限个三角函数叠加会使得函数图像的震荡更少,看起来更加平稳(相比于有限的 ),从而对于固定的 的连续非负区间更长,或者反过来,对于固定的 ,保持 的 都非负的 更小。另一方面,Meta 最新发布的 LLAMA3,训练长度为 8192,但 RoPE 的底数选择了惊人的 500000(5e5),这比前面的数值结果(8.4e4)还要大将近一个数量级,不管从哪个角度看,这个数值笔者都认为是偏大的,可能 LLAMA3 的这个底数本就是给更大文本长度预留的。但不论如何,更大的文本长度选择更大的 RoPE 底数,似乎已经成为了很多训练人员的共识。其实不管是数值结果还是渐近估计,都只是一个参考值,实际上对于给定的L,一个相当大范围内的 b 都应该会有相近的效果。所以具体的数值都不重要,关键是原论文通过语义聚合的出发点和一系列推导,澄清了“ 应该随着 增大而增大”的结论及其原理,这是笔者所认为的原论文的核心贡献。此外,其实语义聚合的出发点和结论也可以用来解释 Position Interpolation [3](PI)。刚才我们说了,同一个 的连续非负区间是固定的,如果要使 都落在非负区间内,就需要随着 的增大而相应的增加 。但反过来,我们也可以不增加 b,而是减少相邻位置的间隔(即位置ID改成 ),那么就可以在同样大小的非负区间内表示 k 倍的位置了,这便是语义聚合视角下的 Position Interpolation。本文简单介绍了论文《Base of RoPE Bounds Context Length》 [1],它从语义聚合的期望性质讨论了 RoPE 的底数下界,由此指出更大的训练长度应该选择更大的底数,而不单单是为了配合“先短后长”的训练策略、继而利用 NTK-RoPE 来降低初始损失的折中选择。
[1] https://papers.cool/arxiv/2405.14591[2] https://en.wikipedia.org/wiki/Trigonometric_integral[3] https://papers.cool/arxiv/2306.15595
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
📝 稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算
📬 投稿通道:
• 投稿邮箱:[email protected]
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
🔍
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧