从Mistral 7B到MoE模型Mixtral 8x7B的全面解析:从原理分析到代码解读
前言
第一部分 23年5月Mistral AI发布的Mistral 7B
1.1 Mistral 7B:通过分组查询注意力 + 滑动窗口注意力超越13B模型
1.1.1 Mistral 7B:超过llama2 13B、GQA、SWA、RoPE
- Mistral 7B在所有评估基准中均胜过了目前最好的13B参数模型(Llama 2,对标的第二代),并在推理、数学和代码生成方面超越了Llama 34B(对,这里其对标Llama第一代的34B,原因是当时Llama 2 34B 尚未发布)
Mistral 7B outperforms the previous best 13B model (Llama 2, [Llama 2: Open foundation and fine-tuned chat models]) across all testedbenchmarks, and surpasses the best 34B model (LLaMa 34B, [Llama: Open and efficient foundation language models]) in mathematics and codegeneration. - 该模型采用了分组查询注意力(GQA),GQA显著加快了推理速度,还减少了解码期间的内存需求,允许更高的批处理大小,从而提高吞吐量
GQA significantly accelerates the inference speed, and also reduces the memory requirement during decoding, allowing for higher batch sizes hence higher throughput
关于GQA的更多介绍,请参见《一文通透各种注意力:从多头注意力MHA到分组查询注意力GQA、多查询注意力MQA》 - 同时结合滑动窗口注意力(sliding window attention,简称SWA)以有效处理任意长度的序列
SWA is designed to handle longer sequences more effectively at a reduced computational cost
当然,SWA也不是Mistral的首创,而是基于这两篇论文实现的:Generating Long Sequences with Sparse Transformers、Longformer: The Long-Document Transformer
具体而言,你再看上上张图所示的「模型参数图」,可知context_len 8192是说它训练的时候,传进来的数据最大只能到8192个tokens,也就是训练时的上下文长度上限
windows_size 4096是sliding windows attention的滑窗大小,1次attention计算的上下文范围只4096个tokens
言外之意是,每个token只最多计算4096的范围
第5000个token只计算[905: 5000]这个范围的attention
第5001个token只计算[906: 5001]这个范围的attention
以此类推.. - 位置编码方面,和llama统一用的RoPE(顺带插一嘴,包括后来Google开源的gemma也用的RoPE,所以RoPE算是标配了,至于关于位置编码和RoPE的详尽细致的介绍,请参见此文)
RoPE所对应的代码如下所示(代码来源:mistral-src/mistral /rope.py)
1.1.2 Mistral 7B-Instruct
1.2 Mistral 7B更多细节:滑动窗口注意力、滚动缓冲区缓存、预填充与分块
1.2.1 滑动窗口注意力:扩展上下文长度
- 每个token最多可以关注来自上一层的W个token(上图中,W = 3)。请注意,滑动窗口之外的token仍然影响下一个单词预测
each token can attend to at most W tokens from the previous layer (here, W = 3). Note that tokensoutside the sliding window still influence next word prediction.
举个例子,在面对这个序列时:The cat sat on the
如果是标准注意力,在计算最后一个token “the”时,得计算the本身所对应的query与整个上文每个token对应的key的内积,当序列长度一长时,该计算量还是比较大的
但如果是滑动窗口注意力,则在计算最后一个token “the”时,只需计算the本身所对应的query与上文中3个token对应的key的内积(这里说的上文中的3个token 包括the自己在内) - 在每个注意力层,信息可以向前移动W个token。因此,在k层注意力之后,信息最多可以向前移动k个×W个token
At each attention layer, information can moveforward by W tokens. Hence, after k attention layers, information can move forward by up to k ×W tokens.
1.2.2 滚动缓冲区缓存(Rolling Buffer Cache)
- 缓存的大小是固定的W,时间步长i的键和值存储在缓存的位置i mod W中。因此,当位置i大于W时,缓存中过去的值就会被覆盖,缓存的大小就会停止增加
The cache has a fixed size of W, and the keys and values for the timestep i are storedin position i mod W of the cache. As a result, when the position i is larger than W, past valuesin the cache are overwritten, and the size of the cache stops increasing
以“The cat sat on the mat”为例..
当 i = 0 时,指The,0 mod 3=0
当 i = 1 时,指cat,1 mod 3=1
当 i = 2 时,指sat,2 mod 3=2
当 i = 3 时,指on,3 mod 3=0
当 i = 4 时,指the,4 mod 3=1
当 i = 5 时,指mat,5 mod 3 = 2 - 在32k token的序列长度上,这减少了8倍的缓存内存使用,而不影响模型质量
On a sequence length of 32k tokens, this reduces the cache memory usageby 8x, without impacting the model quality.
1.2.3 预填充与分块:减少重复运算
- 如果prompt非常大,可以把它分成更小的块,用每个块预填充缓存。为此,可以选择窗口大小作为分块大小。因此,对于每个块,需要计算缓存和块上的注意力
- 下图展示了注意力掩码在缓存和分块上的工作原理在预填充缓存时,长序列被分块,以限制内存使用
我们把一个序列分成三个块来处理,“The cat sat on”,“the mat and saw”,“the dog go to”。上图中显示了第三块(“the dog go to”)发生的情况:它使用因果掩码(最右块)来关注自己,使用滑动窗口(中心块)来关注缓存,并且不关注过去的token,因为它们在滑动窗口之外(左块)
第二部分 首个开源MoE大模型Mixtral 8x7B
2.1 Mixtral 8x7B的整体架构与模型细节
- 8 个专家总数,而不是 16 名(减少一半)
- 每个专家为 7B 参数,而不是 166B(减少 24 倍)
- 47B 总参数(估计)而不是 1.8T(减少 42 倍)
- 与原始 GPT-4 相同的 32K 上下文
- 今年10月发布的Mistral 7B
- 今年12月则发布的混合专家模型,称之为Mixtral 8x7B
- 其中前馈块从一组 8 个不同的参数组中进行选择(It is a decoder-only model where the feedforward block picks from a set of 8 distinct groups of parameters)
- 在每一层,对于每个token,路由器网络选择其中的两个组(“专家”)来处理token并通过组合相加得到它们的输出(At every layer, for every token, a router network chooses two of these groups (the “experts”) to process the token and combine their output additively)
这点可能很多朋友不会特别在意,但你仔细品味下,你会发现大有天地,即:每个token 都由某两个专家负责完成,最后整个序列 则是由一系列「不同的两两专家」组合完成,下文还会详述该点 - 上下文长度达到32K
Mixtral is pretrained with multilingual data using a context size of 32k tokens
2.1.1 Mixtral 8x7B是一个稀疏的专家混合网络
- 表示第 个专家的门控网络的n维输出(denotes the n-dimensional output of the gating network for the i-th expert)
- 是第个专家网络的输出(the output of the i-th expert network)
- 如果在logits的top-K坐标中,则,否则
where if is among the top-K coordinates of logits and otherwise. - 每个token所使用的专家数量是可调的参数
当保持不变但增加时,可以增加模型的总参数数量,同时保持计算成本有效不变
The value of K – the number of experts used per token – is a hyper-parameter that modulates the amount of compute used to process each token. If one increases while keeping fixed, one can increase the model’s parameter count while keeping its computational cost effectively constant.
这引出了「总参数数量(通常称为稀疏参数数量)」与用于「处理单个token的活动参数数量」之间的区别
对总参数数量而言,随着的增加而增加;而对于活动参数数量而言,直到逐渐增加
This motivates a distinction between the model’s total parameter count (commonly referenced as the sparse parameter count), which grows with n, and the number of parameters used for processing an individual token (called the active parameter count), which grows with K up to n.
- 例如,Megablocks将MoE层的前馈网络(FFN)操作转换为大型稀疏矩阵乘法(Megablocks [13] casts the feed-forward network (FFN) operations of the MoE layer as large sparse matrix multiplications),从而显著提升了执行速度
并且可以自动处理不同专家被分配可变数量token的情况(naturally handling cases where different experts get a variable number of tokens assigned to them.)- 此外,通过标准模型并行技术和一种名为专家并行(EP)的特殊分区策略,MoE层可以在多个GPU上进行分布
Moreover, the MoE layer can be distributed to multiple GPUs through standard Model Parallelism techniques, and through a particular kind of partitioning strategy called Expert Parallelism (EP) [28].
在MoE层执行过程中,旨在由特定专家处理的token会被路由到相应的GPU进行处理,并将专家输出返回到原始token位置During the MoE layer’s execution, tokens meant to be processed by a specific expert are routed to the corresponding GPU for processing, and the expert’s output is returned to the original token location.
需要注意的是,在负载平衡方面,EP带来了挑战,因为均匀地分配工作负载至关重要以避免单个GPU过载或遇到计算瓶颈
Note that EP introduces challenges in load balancing, as it is essential to distribute the workload evenly across the GPUs to prevent overloading individual GPUs or hitting computational bottlenecks.
- 采用与专家函数相同的SwiGLU架构,并设置K = 2
- 这意味着每个token被路由到两个具有不同权重集的SwiGLU子块
For Mixtral we use the same SwiGLU architecture as the expert function Ei(x) and set K = 2
2.1.2 Mixtral的参数总量为何是46.7B而非56B
- 即,虽然Mixtral模型的完整名称为“Mixtral-8x7B-v0.1”,看似有“8x7B=56B”的参数量,但实际的参数量应当是约47B而非56B,因为在各个层中仅有experts部分(FFN)是独立存在的,其余的部分(Attention等)则是各个expert均有共享的
- 可以想象成一个“纺锤状”的样式,数据由共享模块传输至expert模块对应于纺锤中部发散的部分,对expert的输出进行加权聚合则对应纺锤末端收束的部分
2.1.3 Mixtral中所采取的GQA机制
2.1.4 Mixtral中的路由(Gating/Router)
- Sentence-Level是对各个样本分别进行路由
- Token-Level是对样本中的各个token分别进行路由
- Task-Level要求不同的expert明确负责不同任务
- 至于首次在NLP任务中使用Token-Level的MOE可以追溯至2017年的《Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer》
- 该论文展示了Token-Level的一些有趣现象,通过观察各个expert所负责token的统计特征,不同的expert确实掌握了一些语法层面理解, 当需要不定冠词“a”在重要的动词短语中引入直接宾语时,则会有专门的752号expert来负责输出这个“a”
2.2 模型表现:匹配或超越Llama 2 70B 以及 GPT3.5
2.3 指令遵循模型Mixtral 8x7B Instruct
第三部分 Mixtral(MOE架构)的实现细节:代码解读
- 对于每个token,路由器网络选择其中的两个组(“专家”)来处理token并通过组合相加得到它们的输出「At every layer, for every token, a router network chooses two of these groups (the “experts”) to process the token and combine their output additively」
- 啥意思,就是如果不仔细了解的话,很容易误以为是“输入的一整个序列”分给TOP 2专家,结果事实是每个token都各自分配TOP 2专家,而且当你仔细抠完mixtral的代码之后,你会发现还真是如此..
3.1 MOE模块的前向传播:整体流程
3.1.1 获取各token对应的top2 expert及其权重
- 由于hidden_states的维度,通常包括批大小(batch_size)、序列长度(sequence_length)和隐藏层维度(hidden_dim),故有
- 将hidden_states的形状重构为一个二维张量,用于将其处理为每个token的表示
- 通过一个门控(gate)机制来生成路由逻辑(router_logits),用于后续决定每个token应由哪些专家(experts)处理
- 对每个token的路由逻辑应用softmax函数,计算每个专家对每个token的处理权重
- 选取每个token的前top_k个最重要的专家及其权重
- 对选出的每个token的专家权重进行归一化处理,确保每个token的专家权重之和为1
3.1.2 将各token传入对应的expert模型中进行前向传播得到输出
- 首先
- 根据给定的selected_experts作为元素1所在位置的索引,构建向量长度为num_experts的one-hot编码
好比24个token,需要由8个expert两两组合处理,那我针对每一个token都构建长度为8的0 1编码,这个编码分别代表8个expert
故,每个token选择了哪两个expert,则对应的编码位上变为1,否则为0
比如July这个token选择3 7两个expert,则July对应的0 1编码位:0 0 1 0 0 0 1 0
再比如Edu这个token如果选择了2 4两个expert,则其01编码为:0 1 0 1 0 0 0 0
依此类推.. - 使用相对取巧方法来进行前向传播
torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0).shape: (num_experts, topk, bs*seq_len)
的物理含义是由“每个token分别选取了哪topk个expert”变成了“每个expert分别作为各个排位存在的时候,对应需要处理哪些token”
这样做的好处在于:后续循环的时候只需要进行num_experts次前向传播就能得到结果,而无需进行bs*seq_len次前向传播
为方便大家更好的理解上面那行代码的含义,我特地画了个示意图以加快理解
A B C D E F G H I J K L M N O P Q R S T U V W X Y Z,是需要处理的token
1 2 3 4 5 6 7 8,代表8个expert
(如阿荀所说,如此,便把关注视角从“各个token”变成了“各个专家”,当然,大部分情况下 token数远远不止下图这5个,而是比专家数多很多。总之,这么一转换,最终可以省掉很多循环 ) 具体而言,下面这个张量 - 所以接下来只需要进行num_experts次循环
由于expert_mask记录有各个expert分别作为各个排位存在的时候,对应需要处理哪些token,故expert_mask[expert_idx].shape: (topk, bs*seq_len),便是从expert_mask中取出其对应的,详见下文的【代码块B】
故上面三行的最后一行中等式中的右边项:torch.where(expert_mask[expert_idx]),则是辨析出expert_mask[expert_idx]值为1的位置索引,详见下文的【代码块C】
至于:idx.shape: (bs * seq_len, ),则代表expert_mask[expert_idx]中(每列)元素值为1的索引位置
以及:top_x.shape: (bs * seq_len, ),则代表expert_mask[expert_idx]中(每行)元素值为1的索引位置
继续分析该for循环之后的代码,如下 上面这几行代码得好好解释下 - for循环结束后,相当于所有expert均处理完毕后,将维护好的final_hidden_states由(bs * seq_len, hidden_dim)转为(bs, seq_len, hidden_dim),并将作为本批次运行的返回
更多详见下文的【代码块E】
3.2 MOE前向传播中五个代码块的细致分析:鞭辟入里
3.2.1 代码块A:routing_weights的具体样例
3.2.2 代码块B:expert_mask[expert_idx]
3.2.3 代码块C:idx, top_x = torch.where(expert_mask[expert_idx])
- 因此top_x将作为索引用于从全部token的隐向量hidden_states中取出对应token的隐向量
- 而idx和top_x也会组合起来被用于从expert权重张量routing_weights中取出对应的权重
3.2.4 代码块D:expert内部的前向传播
3.2.5 代码块E:final_hidden_states
- 最初final_hidden_states是全0张量
- 使用.index_add_函数后在指定位置(top_x)加上了指定值(current_hidden_states)
- 再次查看与当前expert有关的final_hidden_states部分,即
第四部分 混合专家模型MOE的发展史与更多实践细节
第五部分 MoE-Mamba模型:将 Mamba 和混合专家层组合起来
参考文献与推荐阅读
- 一条磁力链接席卷AI圈,87GB种子直接开源8x7B MoE模型
- Mistral AI对Mixtral of experts的介绍:Mixtral of experts | Mistral AI | Open source models
- 开源大模型超越GPT-3.5!爆火MoE实测结果出炉
- https://github.com/nateraw/replicate-examples/tree/main/mixtral
- 预训练大模型:百度UFO(Unified Feature Optimization)
- 集4学员且友人wstart推荐的三篇论文
LoRAMoE: Revolutionizing Mixture of Experts for Maintaining World Knowledge in Language Model Alignment
MegaBlocks: Efficient Sparse Training with Mixture-of-Experts
Weak-to-Strong Generalization: Eliciting Strong Capabilities With Weak Supervision - Mixtral 8x7B论文终于来了:架构细节、参数量首次曝光
一条磁力链爆全网,Mixtral 8x7B论文来了!碾压Llama 2 70B,每token仅需激活13B参数 - Mixtral of Experts论文,是本文中此节“1.1.1 Mixtral 8x7B是一个稀疏的专家混合网络”的核心参考
- 图解Mixtral 8 * 7b推理优化原理与源码实现
0 Comments