MOE基本类型

MOE全称Mixtrue of Experts (混合专家)

简单来说,就是使用MOE替换Transformer之前的FFN(前馈网络)结构,从而获取更多的信息。

Transformer直接的堆叠,在层数深了之后会导致提取特征的低秩现象。添加残差和前馈神经网络(FFN)可以避免这种现象。

MOE的基本结构如下图:

这里的FFN1,...,FFN4FFN1,...,FFN4就是不同的4个专家

对于输入X1,X2X_1,X_2,首先会通过一个门控网络决定走哪一个FFN,然后将计算结果与对应的门控权重pp进行乘积,将结果与输入进行残差连接和归一化,最后输出结果。

image-20250302125854893

稀疏MOE与稠密MOE

image-20250302130615089

  1. 门控函数(路由函数):协调专家与其各自输出组合

  2. 门控函数的分类

    1. 稀疏门控:激活部分专家

    2. 稠密门控:激活所有专家

    3. 软门控:输入Token合并和专家合并

      image-20250302130838175

DeepSeek MOE

  1. Expert共享机制:部分Expert在不同Token或层间共享参数,减少模型冗余。
  2. Expert共享多了一个Expert作为Shared Expert,每个Expert的计算都使用它
  3. image-20250302132314094
  4. 内存优化:MLA+KV Cache优化,减少生成任务中的浮点运算量

DeepSeek V2

MLA

multi-head latent attention

对比算法MQA(multi-query attention),GQA(Grouped-Query Attention),MHA(multi-head Attention)

MQA:所有的Q共用一个KV

GQA:对Q进行分组,相同组内共用KV

MHA:对每个Q都使用不同的KV

image-20250302143108329

Flash Attention

参考 [Attention优化][2w字]🔥原理篇: 从Online-Softmax到FlashAttention V1/V2/V3 - 知乎 (zhihu.com)

Online-Softmax

save-softmax

softmax(x1,x2,...,xn)=eximj=1nexjm,m=max(x1,x2,...,xn)softmax({x_1,x_2,...,x_n})=\frac{e^{x_i-m}}{\sum_{j=1}^n e^{x_j-m}},m=max({x_1,x_2,...,x_n})

3-pass softmax

  1. 先算最大值:mimax(mi1,xi)m_i \leftarrow max(m_{i-1}, x_i)
  2. 再算求和: didi1+eximNd_i \leftarrow d_{i-1}+e^{x_i-m_N}
  3. 最后算每个元素的值: resieximdNres_i \leftarrow \frac {e^{x_i}-m}{d_N}

2-pass online-softmax

合并1-2步

i==Ni==N是刚好满足3-pass的第二步

di=j=1iexjmi=(j=1i1exjmi)+eximi=(j=1i1exjmi1×emi1mi)+eximi=(j=1i1exjmi1)emi1mi+eximi=di1emi1mi+eximi\begin{aligned} d_i &= \sum_{j=1}^{i}e^{x_j-m_i}\\ &= (\sum_{j=1}^{i-1}e^{x_j-m_i}) + e^{x_i-m_i}\\ &= (\sum_{j=1}^{i-1}e^{x_j-m_{i-1}} \times e^{m_{i-1}-m_i}) + e^{x_i-m_i} \\ &= (\sum_{j=1}^{i-1}e^{x_j-m_{i-1}} )e^{m_{i-1}-m_i} + e^{x_i-m_i} \\ &= d_{i-1}e^{m_{i-1}-m_i} + e^{x_i-m_i} \end{aligned}

这样只用存储上一步的最小值mi1m_{i-1}和上一步的结果di1d_{i-1}就可以计算当前这一步了,规避了计算全局max

Self-Attention

Standard self-Attention

O=softmax(QKTd)VO = softmax(\frac{QK^T}{\sqrt{d}})V,后续默认d=1\sqrt{d}=1

多阶段计算

S=QKTP=softmax(S)O=PVS = QK^T\\ P=softmax(S)\\ O=PV

Multi-pass Self-Attention

利用2-pass online-softmax的基础上添加QKV计算

  1. 计算QK

    xiQ[k,:]KT[:,i]mimax(mi1,xi)didi1emi1mi+eximi\begin{aligned} x_i & \leftarrow Q[k,:] K^T[:, i] \\ m_i & \leftarrow \max \left(m_{i-1}, x_i\right) \\ d_i^{\prime} & \leftarrow d_{i-1}^{\prime} e^{m_{i-1}-m_i}+e^{x_i-m_i} \end{aligned}

  2. 计算结果

    aieximNdNoioi1+aiV[i,:]\begin{aligned} & a_i \leftarrow \frac{e^{x_i-m_N}}{d_N^{\prime}} \\ & \boldsymbol{o}_i \leftarrow \boldsymbol{o}_{i-1}+a_i V[i,:] \end{aligned}

    可以变化为

    oioi1+exjmNdNV[i,:]o_i \leftarrow \boldsymbol{o}_{i-1}+\frac{e^{x_j-m_N}}{d_N^{\prime}} V[i,:]

1-pass FlashAttention v1(核心)

oi=(j=1iexjmidiV[j,:])\boldsymbol{o}_i=\left(\sum_{j=1}^i \frac{e^{x_j-m_i}}{d_i^{\prime}} V[j,:]\right)

i==Ni==N时,刚好和multi-pass self-attention的第二步相同

同样的,可以推导出oio_ioi1o_{i-1}的关系

oi=(j=1iexjmidiV[j,:])=(j=1i1exjmidiV[j,:])+eximidiV[i,:]=(j=1i1exjmi1di1exjmiexjmi1di1diV[j,:])+eximidiV[i,:]=(j=1i1exjmi1di1di1emi1midiV[j,:])+eximidiV[i,:]=(j=1i1exjmi1di1V[j,:])di1emi1midi+eximidiV[i,:]=oi1di1emi1midi+eximidiV[i,:]\begin{aligned} o_i &=(\sum_{j=1}^i\frac{e^{x_j-m_i}}{d^{\prime}_i} V[j,:])\\ &=(\sum_{j=1}^{i-1}\frac{e^{x_j-m_i}}{d^{\prime}_i} V[j,:]) + \frac{e^{x_i-m_i}}{d^{\prime}_i} V[i,:]\\ &=(\sum_{j=1}^{i-1}\frac{e^{x_j-m_{i-1}}}{d^{\prime}_{i-1}} \frac{e^{x_j-m_i}}{e^{x_j-m_{i-1}}} \frac{d^{\prime}_{i-1}}{d^{\prime}_i} V[j,:]) + \frac{e^{x_i-m_i}}{d^{\prime}_i} V[i,:] \\ &=(\sum_{j=1}^{i-1}\frac{e^{x_j-m_{i-1}}}{d^{\prime}_{i-1}} \frac{d^{\prime}_{i-1}e^{m_{i-1}-m_i}}{d^{\prime}_i} V[j,:]) + \frac{e^{x_i-m_i}}{d^{\prime}_i} V[i,:] \\ &=(\sum_{j=1}^{i-1}\frac{e^{x_j-m_{i-1}}}{d^{\prime}_{i-1}} V[j,:])\frac{d^{\prime}_{i-1}e^{m_{i-1}-m_i}}{d^{\prime}_i} + \frac{e^{x_i-m_i}}{d^{\prime}_i} V[i,:] \\ &=o_{i-1}\frac{d^{\prime}_{i-1}e^{m_{i-1}-m_i}}{d^{\prime}_i} + \frac{e^{x_i-m_i}}{d^{\prime}_i} V[i,:] \\ \end{aligned}

这样就可以看到oio_i只依赖于上一次的di1,mi1,oi1d_{i-1},m_{i-1},o_{i-1}与本次的di,mid_i,m_i,可以在一个循环中全部计算完成,就不用两阶段了

所以得到1-pass的计算为

xiQ[k,:]KT[:,i]mimax(mi1,xi)didi1emi1mi+eximioioi1di1emi1midi+eximidiV[i,:]\begin{aligned} x_i & \leftarrow Q[k,:] K^T[:, i] \\ m_i & \leftarrow \max \left(m_{i-1}, x_i\right) \\ d_i^{\prime} & \leftarrow d_{i-1}^{\prime} e^{m_{i-1}-m_i}+e^{x_i-m_i}\\ o_i & \leftarrow o_{i-1}\frac{d^{\prime}_{i-1}e^{m_{i-1}-m_i}}{d^{\prime}_i} + \frac{e^{x_i-m_i}}{d^{\prime}_i} V[i,:] \end{aligned}

上面的伪代码是按列进行计算的,外层循环是要遍历列的。当然这个步骤可以进行分块计算(tiling)这样可以减少外层循环的次数,还可以增强数据的访问效率。

这里取K时多取了几列,计算mim_i时先计算每一列的最大值,最后是did_i后面加eximie^{x_i-m_i}的部分修改成了对当前块所有列的求和

xiQ[k,:]KT[:,(i1)b:ib]mi(local)=maxj=1b(xi[j]),每一列的最大值mimax(mi1,mi(local))didi1emi1mi+j=1bexi[j]mioioi1di1emi1midi+j=1bexi[j]midiV[j+(i1)b,:]\begin{aligned} x_i & \leftarrow Q[k,:] K^T[:, (i-1)b:ib] \\ m_i^{(local)} & =max_{j=1}^{b}(x_i[j]),每一列的最大值\\ m_i & \leftarrow \max \left(m_{i-1}, m_i^{(local)}\right) \\ d_i^{\prime} & \leftarrow d_{i-1}^{\prime} e^{m_{i-1}-m_i}+ \sum_{j=1}^{b} e^{x_i[j]-m_i}\\ o_i & \leftarrow o_{i-1}\frac{d^{\prime}_{i-1}e^{m_{i-1}-m_i}}{d^{\prime}_i} + \sum_{j=1}^{b}\frac{e^{x_i[j]-m_i}}{d^{\prime}_i} V[j+(i-1)b,:] \end{aligned}

相比于Standard Self-Attention的计算流程节省了S和P矩阵的显存,减少Q,K的HBM IO

block size的设置

MM是SRAM的大小,也是L1 cache的大小,通过这样的计算方式,控制Q,O,K,V的大小不会超过SRAM的大小,实现高效的访存。

下面可以推断出通过这样设置BcB_cBrB_r确保Q,O,K,V的中间变量大小都不会超过M4\frac M4

当然这里也会有一些剩余的部分,其实就是给mim_idid_i预留使用,基本上都可以把SRAM打满了

Bc=M4d,Br=min(M4d,d)SRAM(Qi)=Br×d=min(M4d,d)×d<M4SRAM(Oi)=Br×d=min(M4d,d)×d<M4SRAM(Kj,Vj)=2×Bc×d=2×M4d×d<M2B_c=\left\lceil\frac{M}{4 d}\right\rceil, B_r=\min \left(\left\lceil\frac{M}{4 d}\right\rceil, d\right)\\ \begin{aligned} & S R A M\left(Q_i\right)=B_r \times d=\min \left(\left\lceil\frac{M}{4 d}\right\rceil, d\right) \times d<\left\lceil\frac{M}{4}\right\rceil \\ & S R A M\left(O_i\right)=B_r \times d=\min \left(\left\lceil\frac{M}{4 d}\right\rceil, d\right) \times d<\left\lceil\frac{M}{4}\right\rceil \\ & S R A M\left(K_j, V_j\right)=2 \times B_c \times d=2 \times\left\lceil\frac{M}{4 d}\right\rceil \times d<\left\lceil\frac{M}{2}\right\rceil \end{aligned}

稀疏矩阵的拓展

简单来说就是在原本的基础上,检测每次分块的稀疏度,如果稀疏度为0就跳过对这个小块的计算

反向计算

反向最重要的技术就是recompute,前向计算中省略了中间结果S和P,但是反向需要用他们计算梯度值。

所以反向计算时也会进行tiling,将Q,K,V分块加载到SRAM,再通过recompute的方式计算出当前块的S和P的值,用于求取梯度。

无论是否有recompute,都要去load对应的数据到SRAM中,如果不用recompute就要从HBM中拉取(load Q,K,V,dO,dS, )+ (load P,dP) +(write dS, dP, dQ, dV, dK)

但使用了tiling+recompute之后,只用从HBM拉取(load Q,K,V,dO) + (write dQ,dV,dK),节省了dS,dP,P的IO,虽然recompute技术增加了计算量,但计算过程都是在SRAM中进行的,对比与从HBM拉取数据到SRAM中速度能快很多。

FlashAttention V2

主要优化:

  1. 减少大量非matmul的冗余计算,增加Tensor Cores运算比例
  2. forward pass/backward pass均增加seqlen维度的并行,forward pass交替Q,K,V循环顺序
  3. 更好的Warp Partitioning策略,避免Split-K

减少非matmul的冗余计算

为啥要减少非matmul?因为matmul有专门的硬件(tensor core),可以算得更快。

哪里能减少matmul?就是V1中每一轮迭代都进行了rescale(就是softmax中的分母部分),这东西在V1版本中每一轮都用了,但其实可以在QKV都算完之后再除以的

FA2的计算过程如下:

m(1)=rowmax(S(1))RBr,第一块每一行的最大值(1)=rowsum(eS(1)m(1))RBr,第一块每一行的求和O~(1)=eS(1)m(1)V(1)RBr×d,第一块的结果(无rescalem(2)=max(m(1),rowmax(S(2)))=m,前两块的最大值(2)=em(1)m(2)(1)+rowsum(eS(2)m(2))=rowsum(eS(1)m)+rowsum(eS(2)m)=,前两块的求和O~(2)=es(1)mV(1)+es(2)mV(2),前两块的结果(无rescale)O(2)=diag((2))1O~(2)=O,前两块的结果(带rescale\begin{aligned} & m^{(1)}=\operatorname{rowmax}\left(\mathbf{S}^{(1)}\right) \in \mathbb{R}^{B_r}, 第一块每一行的最大值 \\ & \ell^{(1)}=\operatorname{rowsum}\left(e^{\mathbf{S}^{(1)}-m^{(1)}}\right) \in \mathbb{R}^{B_r},第一块每一行的求和 \\ & \tilde{\mathbf{O}}^{(1)}=e^{\mathbf{S}^{(1)}-m^{(1)}} \mathbf{V}^{(1)} \in \mathbb{R}^{B_r \times d},第一块的结果(无rescale) \\ & m^{(2)}=\max \left(m^{(1)}, \operatorname{rowmax}\left(\mathbf{S}^{(2)}\right)\right)=m ,前两块的最大值\\ & \ell^{(2)}=e^{m^{(1)}-m^{(2)}} \ell^{(1)}+\operatorname{rowsum}\left(e^{\mathbf{S}^{(2)}-m^{(2)}}\right)=\operatorname{rowsum}\left(e^{\mathbf{S}^{(1)}-m}\right)+\operatorname{rowsum} \left(e^{\mathbf{S}^{(2)}-m}\right)=\ell,前两块的求和 \\ & \tilde{\mathbf{O}}^{(2)}=e^{s^{(1)}-m} \mathbf{V}^{(1)}+e^{s^{(2)}-m} \mathbf{V}^{(2)} ,前两块的结果(无rescale)\\ & \mathbf{O}^{(2)}=\operatorname{diag}\left(\ell^{(2)}\right)^{-1} \tilde{\mathbf{O}}^{(2)}=\mathbf{O} ,前两块的结果(带rescale) \end{aligned}

对比与FA1,计算区别在于那几个OO的计算,在FA1中每一轮的计算都为,多了diag(l)diag(l)的计算步骤

Oidiag(inew )1(diag(i)emiminew Oi+em~ijminew P~ijVj)\mathbf{O}_i\leftarrow\operatorname{diag}\left(\ell_i^{\text {new }}\right)^{-1}\left(\operatorname{diag}\left(\ell_i\right) e^{m_i-m_i^{\text {new }}} \mathbf{O}_i+e^{\tilde{m}_{i j}-m_i^{\text {new }}} \tilde{\mathbf{P}}_{i j} \mathbf{V}_j\right)

反向计算不再保存m(j)m^{(j)}(j)\ell^{(j)},而是保存logsumexpL(j)=m(j)+log((j))\operatorname{logsumexp} L^{(j)}=m^{(j)}+\log \left(\ell^{(j)}\right)

从而减少PijP_{ij}重计算的计算量。FA1→FA Pij=diag(li)1exp(Sijmasked mi)RBr×BcPi(j)=exp(SijLi)RBr×Bc\left.\begin{array}{c} \mathbf{P}_{i j}=\operatorname{diag}\left(l_i\right)^{-1} \exp \left(\mathbf{S}_{i j}^{\text {masked }}\right. \end{array}-m_i\right) \in \mathbb{R}^{B_r \times \boldsymbol{B}_c} \rightarrow \mathbf{P}_i^{(j)}=\exp \left(\mathbf{S}_{i j}-L_i\right) \in \mathbb{R}^{B_r \times \boldsymbol{B}_c}

增加seqlen维度的并行

在FA1中,是先load K,V子块,再load Q子块,这使得内循环每轮迭代都只是计算了Q的子结果,想要计算完所有的Q的每一行则是需要整个计算过程结束,此外每一次内循环,都要将结果写入到全局内存中,访问开销很大。

在FA2中,是先load Q子块,再load K,V字块,这使得只要内循环结束,Q的一部分行就能计算完成,而不只是Q的子结果。如果我们在外循环中使用一个本地内存去存储O,然后内循环计算的结果全部写入本地内存中,内循环结束后再写入全局内存,阁下又该如何应对?这种方式是对的,这样每个Q行就可以并行起来了,独立去计算这一行Q的结果,每一行Q都会在内循环结束后得到完整的结果,此外每次内循环访问的都是本地内存,访存开销也大大减少

反观FA1,只在batch_size和headnum做并行

image-20250323191408636

反向中使用的仍然是先load K,V再load O, 为啥呢,因为每个KV都参与了O的计算,因此要得到dK和dV则需要考虑所有的O

image-20250323192951932

warp partition

FA1可以看做一种split-k的计算,FA2可以看做splitQ的计算,那这两个计算的差异有什么区别呢?

image-20250323194638759

这里看到split-K的做法,会导致在计算QKV时需要跨warp同步通信才能得到最终的结果O,而split-Q的做法,则直接在warp内进行访存计算,消除了通信开销

image-20250323194740692

causal mask的处理

causal mask就是在attention计算结果上乘以一个下三角为1上三角为0的矩阵,这时候flash attention计算过程可以包含这一步骤。

他有三种情况

  1. FA的子块对应的causal mask全为0,那这个FA的子块就可以不计算了,因为就是0
  2. FA的子块对应的causal mask全为1,那这个FA的子块计算完成后,可以不与causal mask相乘,因为就是原值
  3. 比较麻烦的就是FA的子块在全局网格的对角线上,这样的子块对应的casual mask一半是1,一半是0,需要在最后计算结果中乘以causal mask

但如果子块不是矩阵又该如何处理,这里引入一个右对齐的概念,看下图吧:

image-20250323201849316