从标准位置编码、复数、欧拉公式到旋转位置编码RoPE

openEuler 社区

一文通透位置编码:从标准位置编码、复数、欧拉公式到旋转位置编码RoPE(含其推导与代码实现)

前言

关于位置编码和RoPE

  1. 应用广泛,是很多大模型使用的一种位置编码方式,包括且不限于LLaMA、baichuan、ChatGLM等等
  2. 我之前在本博客中的另外两篇文章中有阐述过(一篇是关于LLaMA解读的,一篇是关于transformer从零实现的),但自觉写的不是特别透彻好懂
    再后来在我参与主讲的类ChatGPT微调实战课中也有讲过,但有些学员依然反馈RoPE不是特别好理解

考虑到只要花足够多的时间 心思 投入,没有写不清楚的,讲课更是如此,故为彻底解决这个位置编码/RoPE的问题,我把另外两篇文章中关于位置编码的内容抽取出来,并不断深入、扩展、深入,比如其中最关键的改进是两轮改进,一个12.16那天,一个12.21那天

  • (应该是24年) 12.16那天
    小的改进是把“ 1.1 标准位置编码的起源 ”中,关于i、2i、2i+1的一系列计算结果用表格规整了下
    如此,相比之前把一堆数字一堆,表格更加清晰、一目了然
    大的改进是把“ 3.1.1 第一种形式的推导(通俗易懂版) ”的细节重新梳理了以下,以更加一目了然、一看即懂,可能是全网关于RoPE最通俗细致的推导
  • (也应该是24年) 12.21那天
    把RoPE的本质给强调出来

最终成为本文

大家在阅读本文的过程中,不要小看本文前两部分的前置知识与铺垫,正如本文一读者danlgag所评论的:“感谢前文的知识铺垫,基本上能看懂rope的推导了”

  1. 对此,我自己在学每个技术/模型/论文时,少部分可能一看即懂,但大部分都非一看即懂,有的甚至经历过痛苦的挣扎,反复琢磨、思考到底是咋回事
    正因为经历过痛苦、挣扎,所以我深知后面的后学者,其中不少可能也得经历类似的痛苦、挣扎
  2. 可如果我把我的来时路,即 分享「我当初是怎么从不懂到懂,以及如何一步一步搞懂」的过程的话,那便可以大大减少很多人的痛苦/挣扎,加速很多人的理解速度、与理解深度
    因我一人之痛苦/挣扎,免去万千人之痛苦/挣扎,实乃我之荣幸、功德

第一部分 transformer原始论文中的标准位置编码

如此篇文章《 Transformer通俗笔记:从Word2Vec、Seq2Seq逐步理解到GPT、BERT 》所述,RNN的结构包含了序列的时序信息,而Transformer却完全把时序信息给丢掉了,比如“他欠我100万”,和“我欠他100万”,两者的意思千差万别,故为了解决时序的问题,Transformer的作者用了一个绝妙的办法:位置编码(Positional Encoding)

1.1 标准位置编码的起源

即将每个位置编号,从而每个编号对应一个向量,最终通过结合位置向量和词向量,作为输入embedding,就给每个词都引入了一定的位置信息,这样Attention就可以分辨出不同位置的词了,具体怎么做呢?

  1. 如果简单粗暴的话,直接给每个向量分配一个数字,比如1到1000之间
  2. 也可以用one-hot编码表示位置
  3. transformer论文中作者通过sin函数和cos函数交替来创建 positional encoding,其计算positional encoding的公式如下PE_{(pos,2i+1)} = cos\left ( \frac{pos}{10000^{\frac{2i}{d_{model}}}} \right ) \end{aligned}PE_{(pos,2i)} = sin\left ( \frac{pos}{10000^{\frac{2i}{d_{model}}}} \right )至于是embedding向量的位置下标对2求商并取整(可用双斜杠表示整数除法,即求商并取整),它的取值范围是 [0,...,dmodel2] ,比如

    位置向量的第多少维 (0 2 4等偶数维用sin函数计算)
    0
    1
    2
    3
    4
    5
    6
    ….
    510
    511
    相当于
    是指向量维度中的偶数维,即第0维、第2维、第4维…,第510维,用sin函数计算
    是向量维度中的奇数维,即第1维、第3维、第5维..,第511维,用cos函数计算

不要小看transformer的这个位置编码,不少做NLP多年的人也不一定对其中的细节有多深入,而网上大部分文章谈到这个位置编码时基本都是千篇一律、泛泛而谈,很少有深入,故本文还是细致探讨下

1.2 标准位置编码的示例:多图多举例

考虑到一图胜千言 一例胜万语,举个例子,当我们要编码「我 爱 你」的位置向量,假定每个token都具备512维,如果位置下标从0开始时,则根据位置编码的计算公式可得 且为让每个读者阅读本文时一目了然,我计算了 每个单词对应的位置编码示例 (在此之前,这些示例在其他地方基本没有)

  • 当对上的单词「我」进行位置编码时,它本身的维度有512维
    PE0=[sin(0100000512),cos(0100000512),sin(0100002512),cos(0100002512),sin(0100004512),cos(0100004512),...,sin(010000510512),cos(010000510512)]
  • 当对上的单词「爱」进行位置编码时,它本身的维度有512维
  • 当对上的单词「你」进行位置编码时,它本身的维度有512维
    PE2=[sin(2100000512),cos(2100000512),sin(2100002512),cos(2100002512),sin(2100004512),cos(2100004512),...,sin(210000510512),cos(210000510512)]
  • ….

最终得到的可视化效果如下图所示

1.3 标准位置编码的coding实现

代码实现如下

“”“位置编码的实现,调用父类nn.Module的构造函数”“”
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()  
        self.dropout = nn.Dropout(p=dropout)  # 初始化dropout层
        
        # 计算位置编码并将其存储在pe张量中
        pe = torch.zeros(max_len, d_model)                # 创建一个max_len x d_model的全零张量
        position = torch.arange(0, max_len).unsqueeze(1)  # 生成0到max_len-1的整数序列,并添加一个维度
        # 计算div_term,用于缩放不同位置的正弦和余弦函数
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             -(math.log(10000.0) / d_model))

        # 使用正弦和余弦函数生成位置编码,对于d_model的偶数索引,使用正弦函数;对于奇数索引,使用余弦函数。
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)                  # 在第一个维度添加一个维度,以便进行批处理
        self.register_buffer('pe', pe)        # 将位置编码张量注册为缓冲区,以便在不同设备之间传输模型时保持其状态
        
    # 定义前向传播函数
    def forward(self, x):
        # 将输入x与对应的位置编码相加
        x = x + Variable(self.pe[:, :x.size(1)], 
                         requires_grad=False)
        # 应用dropout层并返回结果
        return self.dropout(x)

本文发布之后,有同学留言问,上面中的第11行、12行代码

div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))

为什么先转换为了等价的指数+对数运算,而不是直接幂运算?是效率、精度方面有差异吗?

这里使用指数和对数运算的原因是为了确保数值稳定性和计算效率

  • 一方面,直接使用幂运算可能会导致数值上溢或下溢。当d_model较大时,10000.0 ** (-i / d_model)中的幂可能会变得非常小,以至于在数值计算中产生下溢。通过将其转换为指数和对数运算,可以避免这种情况,因为这样可以在计算过程中保持更好的数值范围
  • 二方面,在许多计算设备和库中,指数和对数运算的实现通常比幂运算更快。这主要是因为指数和对数运算在底层硬件和软件中有特定的优化实现,而幂运算通常需要计算更多的中间值

所以,使用指数和对数运算可以在保持数值稳定性的同时提高计算效率。

既然提到了这行代码,我们干脆就再讲更细致些,上面那行代码对应的公式为

其中的中括号对应的是一个从 0 到 dmodel1 的等差数列(步长为 2),设为

且上述公式与这个公式是等价的

为何,原因在于 ax=e(xln(a)) ,从而有 10000idmodel=e(idmodellog(10000))

最终,再通过下面这两行代码完美实现位置编码

# 使用正弦和余弦函数生成位置编码,对于d_model的偶数索引,使用正弦函数;对于奇数索引,使用余弦函数。
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

第二部分 从复数到欧拉公式

先复习下复数的一些关键概念

  1. 我们一般用表示 复数 ,实数 叫做复数的实部,实数 叫做复数的虚部
  2. 复数的辐角是指复数在复平面上对应的向量和正向实数轴所成的有向角
  3. 的共轭复数定义为:,也可记作 z¯ ,复数与其共轭的乘积等于它的模的平方,即 z \times z^* = a^2 + b^2 = |z|^2 ,这是一个实数

2.1 如何通俗易懂的理解复数

在我们的日常生活中,经常会遇到各种平移运动,为了描述这些平移运动,数学上定义了加减乘除,然还有一类运动是旋转运动,而加减乘除无法去描述旋转运动,而有了复数之后,便不一样了,此话怎讲?

根据复数的定义: i=1 ,可以看出来: i2=1×i×i=1 ,而这个展开过程就揭示了虚数 背后的本质,因为这个展开过程中的两次乘法可以看成连续的操作

  • 即把 1 经过2次完全一样的操作: ×i ,变成了 −1 ,那什么样的操作能得到这个效果呢?
  • 你两眼一亮,直呼:旋转啊,先旋转 90度,再旋转 90 度就可以了啊,如下图所示

so, 就代表了旋转(至此,可能你已经隐隐约约意识到,为何我们在解释旋转位置编码时,为何要扯上复数了),为形象说明,再举两个例子

  • 比如对于 e^{i \pi}+1=0 ,自然数 1,绕坐标中心旋转180度(eiπ),再平移1 ,就回到坐标原点
  • 再比如对于

2.2 如何快速理解欧拉公式

2.2.1 什么是欧拉公式

当 表示任意实数, 是自然对数的底数, 是复数中的虚数单位,则根据欧拉公式有

eix=cosxplus;isinx

表达的含义在于该指数函数可以表示为实部为,虚部为的一个复数

该欧拉公式相当于建立了指数函数、三角函数和复数之间的桥梁,但怎么推导出来的呢,其实很简单

  1. 由于有

ex=1plus;xplus;12!x2plus;13!x3plus;sin(x)=x13!x3plus;15!x5plus;cos(x)=112!x2plus;14!x4plus;

  1. 所以,如果 x=iθ ,则有

\begin{aligned} e^{i \theta} & =1+i \theta+\frac{(i \theta)^{2}}{2 !}+\frac{(i \theta)^{3}}{3 !}+\frac{(i \theta)^{4}}{4 !}+\frac{(i \theta)^{5}}{5 !}+\frac{(i \theta)^{6}}{6 !}+\frac{(i \theta)^{7}}{7 !}+\frac{(i \theta)^{8}}{8 !}+\cdots \\ & =1+i \theta-\frac{\theta^{2}}{2 !}-\frac{i \theta^{3}}{3 !}+\frac{\theta^{4}}{4 !}+\frac{i \theta^{5}}{5 !}-\frac{\theta^{6}}{6 !}-\frac{i \theta^{7}}{7 !}+\frac{\theta^{8}}{8 !}+\cdots \\ & =\left(1-\frac{\theta^{2}}{2 !}+\frac{\theta^{4}}{4 !}-\frac{\theta^{6}}{6 !}+\frac{\theta^{8}}{8 !}-\cdots\right)+i\left(\theta-\frac{\theta^{3}}{3 !}+\frac{\theta^{5}}{5 !}-\frac{\theta^{7}}{7 !}+\cdots\right) \\ & =\cos \theta+i \sin \theta \end{aligned}

2.2.2 欧拉公式与三角函数

如何直观的理解这个欧拉公式呢?

其实,可以把 eiθ 看作通过单位圆的圆周运动来描述单位圆上的点, \cos \theta+i \sin \theta 通过复平面的坐标来描述单位圆上的点,是同一个点不同的描述方式,所以有 e^{i \theta}=\cos \theta+i \sin \theta ,如下图所示

根据欧拉公式 e^{i \theta}=\cos \theta+i \sin \theta ,可以轻易推出:

sinθ=eiθeiθ2i\begin{aligned} \cos \theta=\frac{e^{i \theta}+e^{-i \theta}}{2} \end{aligned}

我们把复数当作向量来看待,复数的实部是方向,虚部是方向,很容易观察出其几何意义,如下图所示

还在思考怎么得来的?很简单哦,还记得向量的加减法么?

第三部分 旋转位置编码(RoPE)的推导与实现

3.1 旋转位置编码的原理与推导

所谓旋转位置编码,其在位置编码上删除了绝对位置嵌入,而在网络的每一层增加了苏剑林等人(2021)提出的 旋转位置嵌入(RoPE) ,其思想是采用绝对位置编码的形式 实现相对位置编码,且RoPE主要借助了复数的思想

具体来说,当咱们给self-attention中的向量都加入了位置信息后,便可以表示为

qm=fq(xm,m)kn=fk(xn,n)vn=fv(xn,n)

其中

  • qm 表示「第 个 token 对应的词向量 」集成「位置信息 」之后的 query 向量
  • 而 、 则分别表示 第 个 token 对应的词向量 集成 位置信息 之后 的 key 向量、 value 向量

3.1.1 第一种形式的推导(可能是全网最通俗易懂版)

接着论文中提出为了能利用上 token 之间的相对位置信息,假定 query 向量 和 key 向量 之间的内积操作可以被一个函数 表示,该函数 的输入是词嵌入向量 、 ,和它们之间的相对位置 :

<fq(xm,m),fk(xn,n)>=g(xm,xn,mn)

这里面其实有很大的一个关键,但大部分资料甚至RoPE原始论文都不会给你特别强调出来,即为何要构造这么一个等式呢?

  • 原因在于左边算是q和k向量的内积,而这恰好是transformer计算自注意力机制的核心一步,右边等式则意味着m与n的相对位置
    如此一来,该等式便把“q和k的内积”与“它们的相对位置”给串起来了
  • 也如阿荀所说,左边是含有各自绝对位置信息的q向量和k向量,而这个等式就是RoPE追求的目标,物理含义就是 通过显式传入绝对位置信息实现与传入相对位置信息对等 的情况

假定现在词嵌入向量的维度是两维 ,然后RoPE利用2维度平面上的向量的几何性质,再结合复数的性质,神奇般的找到了满足上述等式的 和 ,其形式如下:

fq(xm,m)=(Wqxm)eimθfk(xn,n)=(Wkxn)einθg(xm,xn,mn)=Re[(Wqxm)(Wkxn)ei(mn)θ]

这里面的 Re 表示复数的实部

  • 进一步地, 可以表示成下面的式子(如果此刻你觉得你有点懵,没事,下文马上会一步一步的详细推导):可能有的同学还没反应过来,怎么就叫「 query 向量乘以了一个旋转矩阵 」了呢?我再举一个来自 这里 的例子,以一目了然
    如下图所示,考虑一个矩阵,它把向量在固定坐标系中逆时针旋转一个角度 θ ,得到 v

    *而这个旋转矩阵就是
  • 同理, 可以表示成下面的式子:

fk(xm,m)=(cosmθsinmθ)sinmθcosmθ)(Wk(1,1)Wk(1,2)Wk(2,1)Wk(2,2))(xm(1)xm(2))=(cosmθsinmθ)sinmθcosmθ)(km(1)km(2))

  • 最终 g(xm,xn,mn) 可以表示如下:

g(xm,xn,mn)=(qm(1)qm(2))(cos((mn)θ)sin((mn)θ)sin((mn)θ)cos((mn)θ))(kn(1)kn(2))

然上述分别关于、、 g(xm,xn,mn) 的三个式子,咋一步一步推导来的?为做细致说明,特参考 此文 一步一步解释下


首先看第一个式子,对于 fq(xm,m)=(Wqxm)eimθ ,这个式子的右边项有两部分,一部分是、一部分是 eimθ

  1. 对于前者,可知其中的是个二维矩阵,是个二维向量,自然相乘的结果也必然是一个二维向量,用表示

qm=(qm(1)qm(2))=Wqxm=(Wq(11)Wq(12)Wq(21)Wq(22))(xm(1)xm(2))

  1. 对于后者 eimθ ,根据欧拉公式 e^{i x}=\cos x&plus;i \sin x ,可得

eimθ=cos(mθ)plus;isin(mθ)einθ=cos(nθ)plus;isin(nθ)ei(mn)θ=cos((mn)θ)plus;isin((mn)θ)

  1. 基于上面第1点结论,可知
    fq(xm,m)=(Wqxm)eimθ=qmeimθ
    然后将表示成复数形式,可得
    q_{m}=\left[q_{m}^{(1)}, q_{m}^{(2)}\right]=\left[q_{m}^{(1)}&plus;i q_{m}^{(2)}\right]
    从而有
    f_{q}\left(x_{m}, m\right)= q_{m} e^{i m \theta} = \left[q_{m}^{(1)}&plus;i q_{m}^{(2)}\right] e^{i m \theta}

基于上面第2点结论,可知 fq(xm,m) 即是两个复数相乘
f_{q}\left(x_{m}, m\right) = q_{m} e^{i m \theta}=\left(q_{m}^{(1)}&plus;i q_{m}^{(2)}\right) *(\cos (m \theta)&plus;i \sin (m \theta))
4. 考虑到以下两个关于复数的背景知识
(a&plus;i b) \cdot(c&plus;i d)=a c&plus;i b c&plus;i a d&plus;i^{2} b d=(a c-b d)&plus;i(b c&plus;a d)

可得

\begin{aligned} q_{m} e^{i m \theta} & =\left(q_{m}^{(1)}&plus;i q_{m}^{(2)}\right) *(\cos (m \theta)&plus;i \sin (m \theta)) \\ =\left(q_{m}^{(1)} \cos (m \theta)\right. & \left.-q_{m}^{(2)} \sin (m \theta)\right)&plus;i\left(q_{m}^{(2)} \cos (m \theta)&plus;q_{m}^{(1)} \sin (m \theta)\right) \end{aligned}

将这个结果表达成实数向量形式,即是
q_{m} e^{i m \theta}=\left[q_{m}^{(1)} \cos (m \theta)-q_{m}^{(2)} \sin (m \theta), q_{m}^{(2)} \cos (m \theta)&plus;q_{m}^{(1)} \sin (m \theta)\right]

至此,你也就不难发现,这不就是 query向量乘以了一个旋转矩阵

\begin{array}{c} f_{q}\left(x_{m}, m\right)=\left(W_{q} x_{m}\right) e^{i m \theta}=q_{m} e^{i m \theta} \\ =\left[q_{m}^{(1)} \cos (m \theta)-q_{m}^{(2)} \sin (m \theta), q_{m}^{(2)} \cos (m \theta)&plus;q_{m}^{(1)} \sin (m \theta)\right] \\ =\left(\begin{array}{cc} \cos (m \theta) & -\sin (m \theta) \\ \sin (m \theta) & \cos (m \theta) \end{array}\right)\left(\begin{array}{c} q_{m}^{(1)} \\ q_{m}^{(2)} \end{array}\right) \end{array}

至于第二个式子,根据上述过程同理,可得key向量

  • g(xm,xn,mn)=Re[(Wqxm)(Wkxn)ei(mn)θ]

其中,表示一个复数的实数部分,而 (Wkxn) 则表示复数的共轭

  1. 考虑到

z=aplus;ibz=aib

再结合上面第一个式子中的推导,可得

Wqxm=qm=qm(1)plus;iqm(2)Wkxn=kn=kn(1)plus;ikn(2)(Wkxn)=kn=kn(1)ikn(2)ei(mn)θ=cos((mn)θ)plus;isin((mn)θ)

继续结合上面第一个式子中的推导(比如 (a&plus;i b) \cdot(c&plus;i d)=a c&plus;i b c&plus;i a d&plus;i^{2} b d=(a c-b d)&plus;i(b c&plus;a d) ,及),继续可知,我们现在要证明的是存在

\begin{array}{c} g\left(x_{m}, x_{n}, m-n\right)=\operatorname{Re}\left[\left(W_{q} x_{m}\right)\left(W_{k} x_{n}\right)^{*} e^{i(m-n) \theta}\right] \\ =\operatorname{Re}\left[\left(q_{m}^{(1)}&plus;i q_{m}^{(2)}\right)\left(k_{n}^{(1)}-i k_{n}^{(2)}\right)(\cos ((m-n) \theta)&plus;i \sin ((m-n) \theta))\right] \\ =\operatorname{Re}\left[\left(\left(q_{m}^{(1)} k_{n}^{(1)}&plus;q_{m}^{(2)} k_{n}^{(2)}\right)&plus;i\left(q_{m}^{(2)} k_{n}^{(1)}-q_{m}^{(1)} k_{n}^{(2)}\right)\right)(\cos ((m-n) \theta)&plus;i \sin ((m-n) \theta))\right] \\ =\left(q_{m}^{(1)} k_{n}^{(1)}&plus;q_{m}^{(2)} k_{n}^{(2)}\right) \cos ((m-n) \theta)-\left(q_{m}^{(2)} k_{n}^{(1)}-q_{m}^{(1)} k_{n}^{(2)}\right) \sin ((m-n) \theta) \end{array}

  1. 总之,接下来我们就要证明上述函数 g 的计算公式是成立的
    首先,回顾一下attention操作,位置m的query和位置n的key会做一个内积操作
    即由

\begin{array}{c} f_{q}\left(x_{m}, m\right)=\left[q_{m}^{(1)} \cos (m \theta)-q_{m}^{(2)} \sin (m \theta), q_{m}^{(2)} \cos (m \theta)&plus;q_{m}^{(1)} \sin (m \theta)\right] \\ f_{k}\left(x_{n}, n\right)=\left[k_{n}^{(1)} \cos (n \theta)-k_{n}^{(2)} \sin (n \theta), k_{n}^{(2)} \cos (n \theta)&plus;k_{n}^{(1)} \sin (n \theta)\right] \end{array}

可得

\begin{array}{c} <f_{q}\left(x_{m}, m\right), f_{k}\left(x_{n}, n\right)> \\ = \left(q_{m}^{(1)} \cos (m \theta)-q_{m}^{(2)} \sin (m \theta)\right)\left(k_{n}^{(1)} \cos (n \theta)-k_{n}^{(2)} \sin (n \theta)\right) \\ &plus;\left(q_{m}^{(2)} \cos (m \theta)&plus;q_{m}^{(1)} \sin (m \theta)\right)\left(k_{n}^{(2)} \cos (n \theta)&plus;k_{n}^{(1)} \sin (n \theta)\right) \\ =q_{m}^{(1)} \cos (m \theta) k_{n}^{(1)} \cos (n \theta)-q_{m}^{(1)} \cos (m \theta) k_{n}^{(2)} \sin (n \theta) \\ -q_{m}^{(2)} \sin (m \theta) k_{n}^{(1)} \cos (n \theta)&plus;q_{m}^{(2)} \sin (m \theta) k_{n}^{(2)} \sin (n \theta) \\ &plus;q_{m}^{(2)} \cos (m \theta) k_{n}^{(2)} \cos (n \theta)&plus;q_{m}^{(2)} \cos (m \theta) k_{n}^{(1)} \sin (n \theta) \\ &plus;q_{m}^{(1)} \sin (m \theta) k_{n}^{(2)} \cos (n \theta)&plus;q_{m}^{(1)} \sin (m \theta) k_{n}^{(1)} \sin (n \theta) \end{array}

相当于[A,B]与[C,D]做内积,则相当于A B横着,C D竖着,最终结果为AC BD,最后再把括号里的项全部对应相乘、展开
3. 首先,把上面第二点的式子整理一下,总计8项,为了把相关的项提取出来,第1项 8项合并处理、第2项 7项合并处理、第3项 6项合并处理、第4项 5项合并处理
其次,考虑到

sin(aplus;b)=sinacosbplus;cosasinbsin(ab)=sinacosbcosasinbcos(aplus;b)=cosacosbsinasinbcos(ab)=cosacosbplus;sinasinb

最后,再把相关项的特点,两次调整下顺序即可

依据以上三点,从而有

\begin{array}{c} <f_{q}\left(x_{m}, m\right), f_{k}\left(x_{n}, n\right)>\\ = q_{m}^{(1)} k_{n}^{(1)}(\cos (m \theta) \cos (n \theta)&plus;\sin (m \theta) \sin (n \theta)) \\ &plus;q_{m}^{(1)} k_{n}^{(2)}(-\cos (m \theta) \sin (n \theta)&plus;\sin (m \theta) \cos (n \theta)) \\ &plus;q_{m}^{(2)} k_{n}^{(1)}(-\sin (m \theta) \cos (n \theta)&plus;\cos (m \theta) \sin (n \theta)) \\ &plus;q_{m}^{(2)} k_{n}^{(2)}(\sin (m \theta) \sin (n \theta)&plus;\cos (m \theta) \cos (n \theta)) \\ =q_{m}^{(1)} k_{n}^{(1)} \cos ((m-n) \theta) \\ &plus;q_{m}^{(1)} k_{n}^{(2)} \sin ((m-n) \theta) \\ -q_{m}^{(2)} k_{n}^{(1)} \sin ((m-n) \theta) \\ &plus;q_{m}^{(2)} k_{n}^{(2)} \cos ((m-n) \theta) \\ =\left(q_{m}^{(1)} k_{n}^{(1)}&plus;q_{m}^{(2)} k_{n}^{(2)}\right) \cos ((m-n) \theta)&plus;\left(q_{m}^{(1)} k_{n}^{(2)}-q_{m}^{(2)} k_{n}^{(1)}\right) \sin ((m-n) \theta) \\ =\left(q_{m}^{(1)} k_{n}^{(1)}&plus;q_{m}^{(2)} k_{n}^{(2)}\right) \cos ((m-n) \theta)-\left(q_{m}^{(2)} k_{n}^{(1)}-q_{m}^{(1)} k_{n}^{(2)}\right) \sin ((m-n) \theta) \\ =g\left(x_{m}, x_{n}, m-n\right) \end{array}

完美! 如此,也就证明了,位置 m 的 query 和位置 n 的 key 的内积就是函数 g

最后,把上面的式子一、式子二的最终结果都分别用矩阵向量乘的形式来表达就是:

<fq(xm,m),fk(xn,n)>=((cos(mθ)sin(mθ)sin(mθ)cos(mθ))(qm(1)qm(2)))T((cos(nθ)sin(nθ)sin(nθ)cos(nθ))(kn(1)kn(2)))=(qm(1)qm(2))(cos(mθ)sin(mθ)sin(mθ)cos(mθ))(cos(nθ)sin(nθ)sin(nθ)cos(nθ))(kn(1)kn(2))

接下来,我们要计算两个旋转矩阵的乘积,即中间部分的这个式子

(cos(mθ)sin(mθ)sin(mθ)cos(mθ))(cos(nθ)sin(nθ)sin(nθ)cos(nθ))

展开之后,可得

(cos(mθ)cos(nθ)plus;sin(mθ)sin(nθ)cos(mθ)sin(nθ)plus;sin(mθ)cos(nθ)sin(mθ)cos(nθ)plus;cos(mθ)sin(nθ)sin(mθ)sin(nθ)plus;cos(mθ)cos(nθ))

从而有

<fq(xm,m),fk(xn,n)>=(qm(1)qm(2))(cos((mn)θ)sin((nm)θ)sin((nm)θ)cos((mn)θ))(kn(1)kn(2))


之前上图中第三个大括号里的「两个旋转矩阵相乘」是下面这么写的, 是不对的

<fq(xm,m),fk(xn,n)>=(qm(1)qm(2))(cos((mn)θ)sin((mn)θ)sin((mn)θ)cos((mn)θ))(kn(1)kn(2))

  • 后经本文的两位读者指出,实际应该是
    [cos(m – n)θ,-sin(n-m)θ]
    [sin(n-m)θ, cos(m – n)θ]
    也就是原论文里的R(n – m), 即下面才是对的

<fq(xm,m),fk(xn,n)>=(qm(1)qm(2))(cos((mn)θ)sin((nm)θ)sin((nm)θ)cos((mn)θ))(kn(1)kn(2))

  • 原因在于
    右上:-cos(mθ)sin(nθ) + sin(mθ)cos(nθ) = sin((m-n)θ) = -sin(n-m)θ ——正弦差公式
    左下:-sin(mθ)cos(nθ) + cos(mθ)sin(nθ) = -sin((m-n)θ) = sin(n-m)θ ——正弦差公式的负值

上面都还只是针对词嵌入维度为2的情况,那对于的通用情况呢,将2维推广到任意维度,可以表示如下:

f{q,k}(xm,m)=RΘ,mdW{q,k}xm

内积满足线性叠加性,因此任意偶数维的RoPE,我们都可以表示为二维情形的拼接,即将词嵌入向量元素按照两两一组分组

RΘ,md=(cosmθ0sinmθ00000sinmθ0cosmθ0000000cosmθ1sinmθ10000sinmθ1cosmθ1000000cosmθd/21sinmθd/210000sinmθd/21cosmθd/21)Wm

每组应用同样的旋转操作且每组的旋转角度计算方式如下:

Θ={θi=100002(i1)/d,i[1,2,,d/2]}

所以简单来说 RoPE 的 self-attention 操作的流程是

  1. 对于 token 序列中的每个词嵌入向量,首先计算其对应的 query 和 key 向量
  2. 然后对每个 token 位置都计算对应的旋转位置编码
  3. 接着对每个 token 位置的 query 和 key 向量的元素按照 两两一组 应用旋转变换
  4. 最后再计算 query 和 key 之间的内积得到 self-attention 的计算结果

3.1.2 第二种形式的推导(苏剑林版)

与上面第一种形式的推导类似,为了引入复数,首先假设了在加入位置信息之前,原有的编码向量是二维行向量和,其中和是绝对位置,现在需要构造一个变换,将和引入到和中,即寻找变换:

qm~=f(q,m),kn~=f(k,n)

也就是说,我们分别为、设计操作 f(,m)f(,n) ,使得经过该操作后, qm~kn~ 就带有了位置、的绝对位置信息
考虑到Attention的核心计算是内积:

Attention(Q,K,V)=softmax(QKTdk)V

故我们希望的内积的结果带有相对位置信息,即寻求的这个变换,应该具有特性:

f(q,m),f(k,n)=g(q,k,mn)

怎么理解?很简单,当m和n表示了绝对位置之后,m与n在句子中的距离即位置差m-n,就可以表示为相对位置了,且对于复数,内积通常定义为一个复数与另一个复数的共轭的乘积」

  1. 为合理的求出该恒等式的一个尽可能简单的解,可以设定一些初始条件,比如、,然后可以先考虑二维情形,然后借助复数来求解
    在复数中有 q,k=Re[qk] ,表示取实部的操作(复数 和“ 复数 的共轭即 ”之积仍是一个复数)
    *因论文100课的群里有学员对该点存在疑问,故借用七月黄老师的回复补充下:这个等式和复数乘法和向量乘积的联系有关
    考虑两个复数

    ,的共轭是
    一方面,对于等式的右边项而言
    q和k*的乘积是
    这个结果的实部是
    二方面,对于等式的左边项而言
    其对应于对应的实数向量和对应的实数向量的乘积
    [a, b] \cdot [c, d] = ac &plus; bd
    综合以上两点,可知右边项所表示的“复数q和复数k的共轭k*的乘积”,和左边项做表示的“q、k所对应向量的乘积”是一样的*

    总之,我们需要寻找一种变换,使得

  2. 简单起见,我们假设存在复数,使得,然后我们用复数的指数形式,设

    方程1:
    方程2: Θf(q,m)−Θf(k,n) = Θg(q,k,m−n)

    对于方程1,代入得到(接着,再把和都设为0)
    Rf(q,m)Rf(k,m)=Rg(q,k,0)=Rf(q,0)Rf(k,0)=qk
    最后一个等号源于初始条件和,所以现在我们可以很简单地设 Rf(q,m)=qRf(k,m)=k ,即它不依赖于

    至于方程2,同样代入得到
    Θf(q,m)−Θf(k,m) = Θg(q,k,0) = Θf(q,0)−Θf(k,0)= Θ(q)−Θ(k)

    这里的 Θ(q)Θ(k) 是、本身的幅角,而最后一个等号同样源于初始条件
    根据上式Θf(q,m)−Θf(k,m) = Θ(q)−Θ(k),可得Θf(q,m)−Θ(q)=Θf(k,m)−Θ(k),所以Θf(q,m)−Θ(q)的结果是一个只与m相关、跟q无关的函数,记为φ(m),即Θf(q,m)=Θ(q)+φ(m)

  3. 接着令n=m−1代入 Θf(q,m)−Θf(k,n) = Θg(q,k,m−n) ,可以得到 Θf(q,m)−Θf(k,m-1) = Θg(q,k,1)
    然后将 Θf(q,m) 和 Θf(k,m-1) 的等式代入Θf(q,m)=Θ(q)+φ(m),我们可以得到 Θ(q) + φ(m) – (Θ(k) + φ(m-1)) = Θg(q,k,1),整理一下就得到
    \varphi(m)-\varphi(m-1)=\Theta g(q, k, 1)&plus;\Theta(k)-\Theta(q)
    即{φ(m)}是等差数列,设右端为θ,那么就解得 φ(m)=mθ

    综上,我们得到二维情况下用复数表示的RoPE:
    \boldsymbol{f}(\boldsymbol{q}, m)=R_{f}(\boldsymbol{q}, m) e^{\mathrm{i} \Theta f(\boldsymbol{q}, m)}=\|q\| e^{\mathrm{i}(\Theta(\boldsymbol{q})&plus;m \theta)}=\boldsymbol{q} e^{\mathrm{i} m \theta}

  4. 所以说,寻求的变换就是 qmeimθ ,也就是给乘以 eimθ ,相应地,乘以 einθ
    做了这样一个变换之后,根据复数的特性,有: qm,kn=Re[qmkn] 也就是,如果把二维向量看做复数,那么它们的内积,等于一个复数乘以另一个复数的共轭,得到的结果再取实部,代入上面的变换,也就有:

    换言之,经过这样一番操作,通过给Embedding添加绝对位置信息,可以使得两个token的编码,经过内积变换(self-attn)之后,得到结果是受它们位置的差值,即相对位置影响的

于是,对于任意的位置为的二维向量,把它看做复数,乘以 eimθ ,而根据欧拉公式,有:

eimθ=cosmθplus;isinmθ

从而上述的相乘变换也就变成了(过程中注意:):

(xplus;iy)eimθ=(xplus;iy)(cosmθplus;isinmθ)=xcosmθplus;ixsinmθplus;iycosmθysinmθ=(xcosmθysinmθ)plus;i(xsinmθplus;ycosmθ)

把上述式子写成矩阵形式:

而这个变换的几何意义,就是在二维坐标系下,对向量进行了旋转,因而这种位置编码方法,被称为旋转位置编码

根据刚才的结论,结合内积的线性叠加性,可以将结论推广到高维的情形。可以理解为,每两个维度一组,进行了上述的“旋转”操作,然后再拼接在一起:

由于矩阵的稀疏性,会造成计算上的浪费,所以在计算时采用逐位相乘再相加的方式进行:

其中 为矩阵逐位相乘操作

更多还可以查看此文《 大模型长度扩展综述:从直接外推ALiBi、插值PI、NTK-aware插值(Meta称之为RoPE ABF)、YaRN到S2-Attention 》的此节「3.2.1 LLaMA2 Long相比LLaMA 2的变化:修改位置编码 长度达到32K」

3.2 旋转位置编码的coding实现(分非LLaMA版和LLaMA版两种)

原理理解了,接下来可以代码实现旋转位置编码,考虑到LLaMA本身的实现不是特别好理解,所以我们先通过一份非LLaMA实现的版本,最后再看下LLaMA实现的版本

对于,非LLaMA版的实现,其核心就是实现下面这三个函数 (再次强调,本份关于RoPE的非LLaMA版的实现 与上面和之后的代码并非一体的,仅为方便理解RoPE的实现)

3.2.1 非LLaMA版的实现

3.2.1.1 sinusoidal_position_embedding的编码实现

sinusoidal_position_embedding:这个函数用来生成正弦形状的位置编码。这种编码用来在序列中的令牌中添加关于相对或绝对位置的信息

def sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, device):
    # (max_len, 1)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(-1)

    # (output_dim//2)
    # 即公式里的i, i的范围是 [0,d/2]
    ids = torch.arange(0, output_dim // 2, dtype=torch.float)  
    theta = torch.pow(10000, -2 * ids / output_dim)

    # (max_len, output_dim//2)
    # 即公式里的:pos / (10000^(2i/d))
    embeddings = position * theta 

    # (max_len, output_dim//2, 2)
    embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)

    # (bs, head, max_len, output_dim//2, 2)
    # 在bs维度重复,其他维度都是1不重复
    embeddings = embeddings.repeat((batch_size, nums_head, *([1] * len(embeddings.shape))))  

    # (bs, head, max_len, output_dim)
    # reshape后就是:偶数sin, 奇数cos了
    embeddings = torch.reshape(embeddings, (batch_size, nums_head, max_len, output_dim))
    embeddings = embeddings.to(device)
    return embeddings

一般的文章可能解释到这个程度基本就完事了,但 为了让初学者一目了然计,我还是再通过一个完整的示例,来一步步说明上述各个步骤都是怎么逐一结算的 ,整个过程和上文中介绍过的“transformer的位置编码” 本质上是一回事..

为方便和transformer的位置编码做对比,故这里也假定output_dim = 512

  1. 首先,我们有 ids 张量,当 output_dim 为 512 时,则









    ids = [0,0, 1,1, 2,2,…, 254,254, 255,255]
    然后我们有一个基数为10000的指数运算,使用了公式 torch.pow(10000, -2 * ids / output_dim)[1100000512,1100000512,1100002512,1100002512,1100004512,1100004512,...,110000510512,110000510512]
    [2100000512,2100000512,2100002512,2100002512,2100004512,2100004512,...,210000510512,210000510512]
  2. 接下来我们将对 embeddings 的每个元素应用 torch.sin 和 torch.cos 函数
    对于 torch.sin(embeddings),我们将取 embeddings 中的每个元素的正弦值:
    [sin(0100000512),sin(0100002512),sin(0100004512),...,sin(010000510512)]
    [sin(1100000512),sin(1100002512),sin(1100004512),...,sin(110000510512)]
    [sin(2100000512),sin(2100002512),sin(2100004512),...,sin(210000510512)]
    对于 torch.cos(embeddings),我们将取 embeddings 中的每个元素的余弦值:
    [cos(0100000512),cos(0100002512),cos(0100004512),...,,cos(010000510512)]
    [cos(1100000512),cos(1100002512),cos(1100004512),...,,cos(110000510512)]
    [cos(2100000512),cos(2100002512),cos(2100004512),...,,cos(210000510512)]
    最后,torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1) 将这两个新的张量沿着一个新的维度堆叠起来,得到的 embeddings如下PE_1 = [sin(\frac{1}{10000^{\frac{0}{512}}}),cos(\frac{1}{10000^{\frac{0}{512}}}), sin(\frac{1}{10000^{\frac{2}{512}}}),cos(\frac{1}{10000^{\frac{2}{512}}}), sin(\frac{1}{10000^{\frac{4}{512}}}), cos(\frac{1}{10000^{\frac{4}{512}}}),…, sin(\frac{1}{10000^{\frac{510}{512}}}),cos(\frac{1}{10000^{\frac{510}{512}}})]PE_2 = [sin(\frac{2}{10000^{\frac{0}{512}}}),cos(\frac{2}{10000^{\frac{0}{512}}}), sin(\frac{2}{10000^{\frac{2}{512}}}),cos(\frac{2}{10000^{\frac{2}{512}}}), sin(\frac{2}{10000^{\frac{4}{512}}}), cos(\frac{2}{10000^{\frac{4}{512}}}),…, sin(\frac{2}{10000^{\frac{510}{512}}}),cos(\frac{2}{10000^{\frac{510}{512}}})]

    [
      [
        [
          [sin(\frac{0}{10000^{\frac{0}{512}}}), cos(\frac{0}{10000^{\frac{0}{512}}}), sin(\frac{0}{10000^{\frac{2}{512}}}), cos(\frac{0}{10000^{\frac{2}{512}}}), ..., cos(\frac{0}{10000^{\frac{510}{512}}})],
          [sin(\frac{1}{10000^{\frac{0}{512}}}), cos(\frac{1}{10000^{\frac{0}{512}}}), sin(\frac{1}{10000^{\frac{2}{512}}}), cos(\frac{1}{10000^{\frac{2}{512}}}), ..., cos(\frac{1}{10000^{\frac{510}{512}}})],
          [sin(\frac{2}{10000^{\frac{0}{512}}}), cos(\frac{2}{10000^{\frac{0}{512}}}), sin(\frac{2}{10000^{\frac{2}{512}}}), cos(\frac{2}{10000^{\frac{2}{512}}}), ..., cos(\frac{2}{10000^{\frac{510}{512}}})]
        ]
      ]
    ]
    
3.2.1.2 RoPE的编码实现

RoPE:这个函数将相对位置编码(RoPE)应用到注意力机制中的查询和键上。这样,模型就可以根据相对位置关注不同的位置

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def RoPE(q, k):
    # q,k: (bs, head, max_len, output_dim)
    batch_size = q.shape[0]
    nums_head = q.shape[1]
    max_len = q.shape[2]
    output_dim = q.shape[-1]

    # (bs, head, max_len, output_dim)
    pos_emb = sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, q.device)

    # cos_pos,sin_pos: (bs, head, max_len, output_dim)
    # 看rope公式可知,相邻cos,sin之间是相同的,所以复制一遍。如(1,2,3)变成(1,1,2,2,3,3)
    cos_pos = pos_emb[...,  1::2].repeat_interleave(2, dim=-1)  # 将奇数列信息抽取出来也就是cos 拿出来并复制
    sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)  # 将偶数列信息抽取出来也就是sin 拿出来并复制

    # q,k: (bs, head, max_len, output_dim)
    q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1)
    q2 = q2.reshape(q.shape)  # reshape后就是正负交替了

    # 更新qw, *对应位置相乘
    q = q * cos_pos + q2 * sin_pos

    k2 = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1)
    k2 = k2.reshape(k.shape)
    # 更新kw, *对应位置相乘
    k = k * cos_pos + k2 * sin_pos

    return q, k

老规矩,为一目了然起见,还是一步一步通过一个示例来加深理解

  1. sinusoidal_position_embedding函数生成位置嵌入。在output_dim=512的情况下,每个位置的嵌入会有512个维度,但为了简单起见,我们只考虑前8个维度,前4个维度为sin编码,后4个维度为cos编码。所以,我们可能得到类似以下的位置嵌入
    # 注意,这只是一个简化的例子,真实的位置嵌入的值会有所不同。
    pos_emb = torch.tensor([[[[0.0000, 0.8415, 0.9093, 0.1411, 1.0000, 0.5403, -0.4161, -0.9900],
                              [0.8415, 0.5403, 0.1411, -0.7568, 0.5403, -0.8415, -0.9900, -0.6536],
                              [0.9093, -0.4161, -0.8415, -0.9589, -0.4161, -0.9093, -0.6536, 0.2836]]]])
    
  2. 然后,我们提取出所有的sin位置编码和cos位置编码,并在最后一个维度上每个位置编码进行复制
    sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)  # 提取出所有sin编码,并在最后一个维度上复制
    cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1)  # 提取出所有cos编码,并在最后一个维度上复制
    
  3. 更新query向量
    我们首先构建一个新的q2向量,这个向量是由原来向量的负的cos部分和sin部分交替拼接而成的
    我们用cos_pos对q进行元素级乘法,用sin_pos对q2进行元素级乘法,并将两者相加得到新的query向量

    q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1).flatten(start_dim=-2)
    # q2: tensor([[[[-0.2,  0.1, -0.4,  0.3, -0.6,  0.5, -0.8,  0.7],
    #               [-1.0,  0.9, -1.2,  1.1, -1.4,  1.3, -1.6,  1.5],
    #               [-1.8,  1.7, -2.0,  1.9, -2.2,  2.1, -2.4,  2.3]]]])
    q = q * cos_pos + q2 * sin_pos
    

    公式表示如下

  4. 更新key向量
    对于key向量,我们的处理方法与query向量类似

    k2 = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1).flatten(start_dim=-2)
    # k2: tensor([[[[-0.15,  0.05, -0.35,  0.25, -0.55,  0.45, -0.75,  0.65
    
3.2.1.3 attention的编码实现

attention:这是注意力机制的主要功能

  • 首先,如果use_RoPE被设置为True,它会应用RoPE,通过取查询和键的点积(并进行缩放)
  • 然后,进行softmax操作来计算注意力分数,以得到概率,输出是值的加权和,权重是计算出的概率
  • 最后,旋转后的q和k计算点积注意力后,自然就具备了相对位置信息
def attention(q, k, v, mask=None, dropout=None, use_RoPE=True):
    # q.shape: (bs, head, seq_len, dk)
    # k.shape: (bs, head, seq_len, dk)
    # v.shape: (bs, head, seq_len, dk)

    if use_RoPE:
        # 使用RoPE进行位置编码
        q, k = RoPE(q, k)

    d_k = k.size()[-1]

    # 计算注意力权重
    # (bs, head, seq_len, seq_len)
    att_logits = torch.matmul(q, k.transpose(-2, -1))  
    att_logits /= math.sqrt(d_k)

    if mask is not None:
        # 对权重进行mask,将为0的部分设为负无穷大
        att_scores = att_logits.masked_fill(mask == 0, -1e-9)  

    # 对权重进行softmax归一化
    # (bs, head, seq_len, seq_len)
    att_scores = F.softmax(att_logits, dim=-1)  

    if dropout is not None:
        # 对权重进行dropout
        att_scores = dropout(att_scores)

    # 注意力权重与值的加权求和
    # (bs, head, seq_len, seq_len) * (bs, head, seq_len, dk) = (bs, head, seq_len, dk)
    return torch.matmul(att_scores, v), att_scores

if __name__ == '__main__':
    # (bs, head, seq_len, dk)
    q = torch.randn((8, 12, 10, 32))
    k = torch.randn((8, 12, 10, 32))
    v = torch.randn((8, 12, 10, 32))

    # 进行注意力计算
    res, att_scores = attention(q, k, v, mask=None, dropout=None, use_RoPE=True)

    # 输出结果的形状
    # (bs, head, seq_len, dk),  (bs, head, seq_len, seq_len)
    print(res.shape, att_scores.shape)

3.2.2 LLaMA版的实现

接下来,我们再来看下LLaMA里是怎么实现这个旋转位置编码的,具体而言, LLaMA 的model.py 文件里面实现了旋转位置编码(为方便大家理解,我给相关代码 加了下注释)
首先,逐一实现这三个函数
precompute_freqs_cis
reshape_for_broadcast
apply_rotary_emb

# 预计算频率和复数的函数
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))    # 计算频率
    t = torch.arange(end, device=freqs.device)    # 根据结束位置生成序列
    freqs = torch.outer(t, freqs).float()    # 计算外积得到新的频率
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)    # 计算复数
    return freqs_cis    # 返回复数
# 重塑的函数
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim    # 获取输入张量的维度
    assert 0 <= 1 < ndim    # 检查维度的合理性
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])    # 检查复数的形状
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]    # 计算新的形状
    return freqs_cis.view(*shape)    # 重塑复数的形状并返回
# 应用旋转嵌入的函数
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))    # 将xq视为复数
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))    # 将xk视为复数
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)    # 重塑复数的形状
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)    # 计算xq的输出
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)    # 计算xk的输出
    return xq_out.type_as(xq), xk_out.type_as(xk)    # 返回xq和xk的输出

之后,在注意力机制的前向传播函数中调用上面实现的第三个函数 apply_rotary_emb,赋上位置信息

# 对Query和Key应用旋转嵌入
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

后记

最后,说明下为何像开头说的是「23年12.16日这天对本文做了大修」呢,原因在于

  1. 我司《 论文审稿GPT第2版 》即将进入模型训练阶段,其涉及到三个候选模型:mistral-yarn、mistral、llama-longlora
    故准备解析下YaRN,顺带把外推、内插都全面介绍下,而过程中不可避免会提到RoPE,故也总算把RoPE彻底写清楚了
  2. 这些东西,哪怕是近期最新的技术、模型等理解了后 会发现都不难,但我总想把理解的门槛无限降低,所以 想真正写清楚或讲清楚一个东西,必须得反复琢磨、反复修改,以让更多人因此看懂,更何况当我和我的团队每天看paper、做项目,更可以帮到大家不断进阶、深入

如今博客的访问PV2000万,希望明年达到2000万UV以上,以上视为后记

参考文献与推荐阅读

  1. 马同学关于向量和欧拉公式的几篇科普文章
    向量的加法
    欧拉公式,复数域的成人礼
  2. 关于欧拉公式的几篇文章
    被众人膜拜的欧拉恒等式是个什么东东?
    怎么向小学生解释欧拉公式 e^(πi)+1=0?
  3. 读懂旋转编码(RoPE)
  4. LLM学习记录(五)–超简单的RoPE理解方式 ,这篇文章很不错
  5. 苏剑林: Transformer升级之路:2、博采众长的旋转式位置编码
  6. LLaMA的解读与其微调:Alpaca-LoRA/Vicuna/BELLE/中文LLaMA/姜子牙/LLaMA 2
  7. 关于ALiBi的两篇文章
    [速读经典]ALiBi – 给注意力加上线性偏置
    关于Transformer中的位置编码-ALiBi
  8. 最强LLaMA突然来袭!只改一个超参数,实现上下文3.2万token,多个任务打败ChatGPT、Claude 2

openEuler 社区

openEuler 是由开放原子开源基金会孵化的全场景开源操作系统项目,面向数字基础设施四大核心场景(服务器、云计算、边缘计算、嵌入式),全面支持 ARM、x86、RISC-V、loongArch、PowerPC、SW-64 等多样性计算架构

v_JULY_v

已为社区贡献1条内容

开源官网

目录

回到
顶部

暂无评论

发送评论 编辑评论


				
|´・ω・)ノ
ヾ(≧∇≦*)ゝ
(☆ω☆)
(╯‵□′)╯︵┴─┴
 ̄﹃ ̄
(/ω\)
∠( ᐛ 」∠)_
(๑•̀ㅁ•́ฅ)
→_→
୧(๑•̀⌄•́๑)૭
٩(ˊᗜˋ*)و
(ノ°ο°)ノ
(´இ皿இ`)
⌇●﹏●⌇
(ฅ´ω`ฅ)
(╯°A°)╯︵○○○
φ( ̄∇ ̄o)
ヾ(´・ ・`。)ノ"
( ง ᵒ̌皿ᵒ̌)ง⁼³₌₃
(ó﹏ò。)
Σ(っ °Д °;)っ
( ,,´・ω・)ノ"(´っω・`。)
╮(╯▽╰)╭
o(*////▽////*)q
>﹏<
( ๑´•ω•) "(ㆆᴗㆆ)
😂
😀
😅
😊
🙂
🙃
😌
😍
😘
😜
😝
😏
😒
🙄
😳
😡
😔
😫
😱
😭
💩
👻
🙌
🖕
👍
👫
👬
👭
🌚
🌝
🙈
💊
😶
🙏
🍦
🍉
😣
Source: github.com/k4yt3x/flowerhd
颜文字
Emoji
小恐龙
花!
上一篇
下一篇