大模型分布式并行
本文介绍了目前大模型分布式并行的几种方式,包括流水并行、张量并行、数据并行,并介绍了相关的里程碑工作。
参数服务器
参数服务器简单来讲就是用一个集中的或分布的服务器来管理统一的权重参数,在训练过程中将梯度聚合到参数服务器,将参数更新后再广播到各个计算结点中去。
文章发表于2014年:链接:https://www.usenix.org/system/files/conference/osdi14/osdi14-paper-li_mu.pdf
文章中有这样一张图,这样图很好的体现了参数服务器的架构和能力。
- server group:服务组;
- server manager: 与resource manager通信,告知资源需求,并管理server node,控制server node管理哪些参数;
- server node:实际管理参数的结点;
- worker group:计算组;
- task scheduler:用来控制任务的调度;
- work node:进行实际的计算任务;
技术之外,胡扯几句,参数服务器厉害的地方不仅在于它对于分布式在AI领域应用的优化,还在于它的前瞻性。在2014年前后,深度学习远没有如今这么火爆,更是没有如今大模型遍地开花的现象。当时的模型几乎都是稠密的,单卡就可以完成训练,很少会有分布式的需求,就在这样一个背景之下,李沐老师做出了如此前瞻性的工作。参数服务器放到今天依然是不过时的,李沐老师也曾说过,这属于是“风水轮流转”,转到了AI领域对分布式能力需求如此之大的今天。
并行方式
并行的方式大致可以分为两大类,数据并行和模型并行。
数据并行
每个设备上都有完整的网络及其权重参数,将模型输入拆分,每个拆分部分都在不同的设备上进行计算,最后再将所有设备上的梯度相加。
- 优点:简单易行,适用于大规模数据训练。
- 缺点:每个设备都必须存储整个网络及其参数,同时在更新梯度时需要在设备之间通信整个网络的参数。
模型并行
每个设备上只有部分的网络及其权重参数,将模型本身拆分,每个拆分部分都在不同的设备上进行计算,每层计算完成后在设备间同步计算结果。
- 优点:每个设备只需存储部分网络及其参数,存储压力较小。
- 缺点:每层计算完成都需要进行通信,并且有些情况无法均匀的拆分模型。
模型并行
流水并行
流水并行的一个里程碑工作就是Gpipe。
文章发表于2018年,链接:https://arxiv.org/abs/1811.06965
以文章中的图片为例,图(a)是一个神经网络,由一系列连续的层构成。
图(b)是传统的模型并行所做的工作,即把不同的层放在不同的设备上进行计算。由于正反向传播之间是有数据依赖的,所以可以观察到整体耗时其实和单设备计算耗时没有区别,甚至可能因为卡间通信等原因导致性能更差。这样做仅有的好处就是,对每张卡的显存需求更低了,可以放下更大的模型。
图(c)就是Gpipe的工作,它将每个小批量进一步切分成微批量,提高了GPU的计算利用率,理论上来说,微批量切分的越多,流水线中的bubble会越小。
Gpipe的思想可以理解为数据并行+模型并行。
除了上面的基本思想之外,Gpipe还对内存使用进行了优化。
传统的模型计算过程需要将所有的中间结果保留下来,作为反向传播运算的输入。但这部分中间结果很占空间,它与输入、隐藏层宽度和模型层数相关。Gpipe在这里基于微批量使用了一种优化,即只保留微批量的输入和最后一个微批量的中间结果。这样的代价就是每个微批量需要计算两次的forward过程。
张量并行
张量并行的一个里程碑工作是Megatron。
文章发表于2019年,链接:https://arxiv.org/abs/1909.08053
Megatron的优化工作主要是针对于Transformer架构的语言模型,它主要讨论了两个关键组建的模型拆分逻辑,也是Transformer中最重要的两个部分,即Self-Attention和MLP层。
对于MLP层,公式可以看作
这样做使得每个设备都只完整的输入输出和一部分中间结果,存储压力较小。这样做的代价是,模型的每个层计算完成后都需要一次All-reduce来完成通信,同时由于层与层之间存在数据依赖,故All-reduce无法异步执行。还有一个小的缺点就是,Self-Attention头的个数和隐藏层大小必须能被设备的数量整除。
数据并行
数据并行的一个里程碑工作是ZeRO。
文章发表于2019年,链接:https://arxiv.org/abs/1910.02054
ZeRO的思想和参数服务器有些类似,由于它是数据并行,所以不对模型进行拆分,但模型较大的时候依然会面临单个设备放不下的问题,所以ZeRO提出在计算的时候,将涉及计算的部分模型加载到设备上,从而解决模型规模较大带来的问题。
文章主要提到两个优化,ZeRO-DP即数据并行优化,以及ZeRO-R即设备内存优化。
ZeRO-DP又分为三个方面,分别是
:每个设备只存储一部分高精度(如fp32)优化器状态; :每个设备只存储一部分低精度(如fp16)梯度; :每个设备只存储一部分低精度(如fp16)权重参数;
在工程实现中:
- ZeRO1:
; - ZeRO2:
+ ; - ZeRO3:
+ + ;
ZeRO-R也分为三方面,分别是
:每个设备只存储一部分输入; :每个设备有一个通信Buffer,数据量达到一定程度后才进行通信; :内存碎片整理;
这里的
和 其实都比较偏向特定应用(PyTorch),都是在上层的实现,一般性不是很好。