MOE基本类型
MOE全称Mixtrue of Experts (混合专家)
简单来说,就是使用MOE替换Transformer之前的FFN(前馈网络)结构,从而获取更多的信息。
Transformer直接的堆叠,在层数深了之后会导致提取特征的低秩现象。添加残差和前馈神经网络(FFN)可以避免这种现象。
MOE的基本结构如下图:
这里的FFN1,...,FFN4就是不同的4个专家
对于输入X1,X2,首先会通过一个门控网络决定走哪一个FFN,然后将计算结果与对应的门控权重p进行乘积,将结果与输入进行残差连接和归一化,最后输出结果。

稀疏MOE与稠密MOE

-
门控函数(路由函数):协调专家与其各自输出组合
-
门控函数的分类
-
稀疏门控:激活部分专家
-
稠密门控:激活所有专家
-
软门控:输入Token合并和专家合并

DeepSeek MOE
- Expert共享机制:部分Expert在不同Token或层间共享参数,减少模型冗余。
- Expert共享多了一个Expert作为Shared Expert,每个Expert的计算都使用它

- 内存优化:MLA+KV Cache优化,减少生成任务中的浮点运算量
DeepSeek V2
MLA
multi-head latent attention
对比算法MQA(multi-query attention),GQA(Grouped-Query Attention),MHA(multi-head Attention)
MQA:所有的Q共用一个KV
GQA:对Q进行分组,相同组内共用KV
MHA:对每个Q都使用不同的KV

Flash Attention
参考 [Attention优化][2w字]🔥原理篇: 从Online-Softmax到FlashAttention V1/V2/V3 - 知乎 (zhihu.com)
Online-Softmax
save-softmax
softmax(x1,x2,...,xn)=∑j=1nexj−mexi−m,m=max(x1,x2,...,xn)
3-pass softmax
- 先算最大值:mi←max(mi−1,xi)
- 再算求和: di←di−1+exi−mN
- 最后算每个元素的值: resi←dNexi−m
2-pass online-softmax
合并1-2步
i==N是刚好满足3-pass的第二步
di=j=1∑iexj−mi=(j=1∑i−1exj−mi)+exi−mi=(j=1∑i−1exj−mi−1×emi−1−mi)+exi−mi=(j=1∑i−1exj−mi−1)emi−1−mi+exi−mi=di−1emi−1−mi+exi−mi
这样只用存储上一步的最小值mi−1和上一步的结果di−1就可以计算当前这一步了,规避了计算全局max
Self-Attention
Standard self-Attention
O=softmax(dQKT)V,后续默认d=1
多阶段计算
S=QKTP=softmax(S)O=PV
Multi-pass Self-Attention
利用2-pass online-softmax的基础上添加QKV计算
-
计算QK
ximidi′←Q[k,:]KT[:,i]←max(mi−1,xi)←di−1′emi−1−mi+exi−mi
-
计算结果
ai←dN′exi−mNoi←oi−1+aiV[i,:]
可以变化为
oi←oi−1+dN′exj−mNV[i,:]
1-pass FlashAttention v1(核心)
令oi=(∑j=1idi′exj−miV[j,:])
当i==N时,刚好和multi-pass self-attention的第二步相同
同样的,可以推导出oi与oi−1的关系
oi=(j=1∑idi′exj−miV[j,:])=(j=1∑i−1di′exj−miV[j,:])+di′exi−miV[i,:]=(j=1∑i−1di−1′exj−mi−1exj−mi−1exj−midi′di−1′V[j,:])+di′exi−miV[i,:]=(j=1∑i−1di−1′exj−mi−1di′di−1′emi−1−miV[j,:])+di′exi−miV[i,:]=(j=1∑i−1di−1′exj−mi−1V[j,:])di′di−1′emi−1−mi+di′exi−miV[i,:]=oi−1di′di−1′emi−1−mi+di′exi−miV[i,:]
这样就可以看到oi只依赖于上一次的di−1,mi−1,oi−1与本次的di,mi,可以在一个循环中全部计算完成,就不用两阶段了
所以得到1-pass的计算为
ximidi′oi←Q[k,:]KT[:,i]←max(mi−1,xi)←di−1′emi−1−mi+exi−mi←oi−1di′di−1′emi−1−mi+di′exi−miV[i,:]
上面的伪代码是按列进行计算的,外层循环是要遍历列的。当然这个步骤可以进行分块计算(tiling)这样可以减少外层循环的次数,还可以增强数据的访问效率。
这里取K时多取了几列,计算mi时先计算每一列的最大值,最后是di后面加exi−mi的部分修改成了对当前块所有列的求和
ximi(local)midi′oi←Q[k,:]KT[:,(i−1)b:ib]=maxj=1b(xi[j]),每一列的最大值←max(mi−1,mi(local))←di−1′emi−1−mi+j=1∑bexi[j]−mi←oi−1di′di−1′emi−1−mi+j=1∑bdi′exi[j]−miV[j+(i−1)b,:]
相比于Standard Self-Attention的计算流程节省了S和P矩阵的显存,减少Q,K的HBM IO
block size的设置
M是SRAM的大小,也是L1 cache的大小,通过这样的计算方式,控制Q,O,K,V的大小不会超过SRAM的大小,实现高效的访存。
下面可以推断出通过这样设置Bc和Br确保Q,O,K,V
的中间变量大小都不会超过4M。
当然这里也会有一些剩余的部分,其实就是给mi和di预留使用,基本上都可以把SRAM打满了
Bc=⌈4dM⌉,Br=min(⌈4dM⌉,d)SRAM(Qi)=Br×d=min(⌈4dM⌉,d)×d<⌈4M⌉SRAM(Oi)=Br×d=min(⌈4dM⌉,d)×d<⌈4M⌉SRAM(Kj,Vj)=2×Bc×d=2×⌈4dM⌉×d<⌈2M⌉
稀疏矩阵的拓展
简单来说就是在原本的基础上,检测每次分块的稀疏度,如果稀疏度为0就跳过对这个小块的计算
反向计算
反向最重要的技术就是recompute,前向计算中省略了中间结果S和P,但是反向需要用他们计算梯度值。
所以反向计算时也会进行tiling,将Q,K,V分块加载到SRAM,再通过recompute的方式计算出当前块的S和P的值,用于求取梯度。
无论是否有recompute,都要去load对应的数据到SRAM中,如果不用recompute就要从HBM中拉取(load Q,K,V,dO,dS, )+ (load P,dP) +(write dS, dP, dQ, dV, dK)
但使用了tiling+recompute之后,只用从HBM拉取(load Q,K,V,dO) + (write dQ,dV,dK),节省了dS,dP,P的IO,虽然recompute技术增加了计算量,但计算过程都是在SRAM中进行的,对比与从HBM拉取数据到SRAM中速度能快很多。
FlashAttention V2
主要优化:
- 减少大量非matmul的冗余计算,增加Tensor Cores运算比例
- forward pass/backward pass均增加seqlen维度的并行,forward pass交替Q,K,V循环顺序
- 更好的Warp Partitioning策略,避免Split-K
减少非matmul的冗余计算
为啥要减少非matmul?因为matmul有专门的硬件(tensor core),可以算得更快。
哪里能减少matmul?就是V1中每一轮迭代都进行了rescale(就是softmax中的分母部分),这东西在V1版本中每一轮都用了,但其实可以在QKV都算完之后再除以的
FA2的计算过程如下:
m(1)=rowmax(S(1))∈RBr,第一块每一行的最大值ℓ(1)=rowsum(eS(1)−m(1))∈RBr,第一块每一行的求和O~(1)=eS(1)−m(1)V(1)∈RBr×d,第一块的结果(无rescale)m(2)=max(m(1),rowmax(S(2)))=m,前两块的最大值ℓ(2)=em(1)−m(2)ℓ(1)+rowsum(eS(2)−m(2))=rowsum(eS(1)−m)+rowsum(eS(2)−m)=ℓ,前两块的求和O~(2)=es(1)−mV(1)+es(2)−mV(2),前两块的结果(无rescale)O(2)=diag(ℓ(2))−1O~(2)=O,前两块的结果(带rescale)
对比与FA1,计算区别在于那几个O的计算,在FA1中每一轮的计算都为,多了diag(l)的计算步骤
Oi←diag(ℓinew )−1(diag(ℓi)emi−minew Oi+em~ij−minew P~ijVj)
反向计算不再保存m(j)和ℓ(j),而是保存logsumexpL(j)=m(j)+log(ℓ(j))
从而减少Pij重计算的计算量。FA1→FA Pij=diag(li)−1exp(Sijmasked −mi)∈RBr×Bc→Pi(j)=exp(Sij−Li)∈RBr×Bc
增加seqlen维度的并行
在FA1中,是先load K,V子块,再load Q子块,这使得内循环每轮迭代都只是计算了Q的子结果,想要计算完所有的Q的每一行则是需要整个计算过程结束,此外每一次内循环,都要将结果写入到全局内存中,访问开销很大。
在FA2中,是先load Q子块,再load K,V字块,这使得只要内循环结束,Q的一部分行就能计算完成,而不只是Q的子结果。如果我们在外循环中使用一个本地内存去存储O,然后内循环计算的结果全部写入本地内存中,内循环结束后再写入全局内存,阁下又该如何应对?这种方式是对的,这样每个Q行就可以并行起来了,独立去计算这一行Q的结果,每一行Q都会在内循环结束后得到完整的结果,此外每次内循环访问的都是本地内存,访存开销也大大减少。
反观FA1,只在batch_size和headnum做并行

反向中使用的仍然是先load K,V再load O, 为啥呢,因为每个KV都参与了O的计算,因此要得到dK和dV则需要考虑所有的O

warp partition
FA1可以看做一种split-k的计算,FA2可以看做splitQ的计算,那这两个计算的差异有什么区别呢?

这里看到split-K的做法,会导致在计算QKV时需要跨warp同步通信才能得到最终的结果O,而split-Q的做法,则直接在warp内进行访存计算,消除了通信开销

causal mask的处理
causal mask就是在attention计算结果上乘以一个下三角为1上三角为0的矩阵,这时候flash attention计算过程可以包含这一步骤。
他有三种情况
- FA的子块对应的causal mask全为0,那这个FA的子块就可以不计算了,因为就是0
- FA的子块对应的causal mask全为1,那这个FA的子块计算完成后,可以不与causal mask相乘,因为就是原值
- 比较麻烦的就是FA的子块在全局网格的对角线上,这样的子块对应的casual mask一半是1,一半是0,需要在最后计算结果中乘以causal mask
但如果子块不是矩阵又该如何处理,这里引入一个右对齐的概念,看下图吧:
