跳到主要内容

序列并行(Sequence Parallelism)

序列并行(SP)由 Korthikanti et al.(MLSys 2023)提出,在张量并行(TP)组内对序列维度进行切分,与 TP 协同工作。SP 不改变总通信量,但将 TP 中无法 overlap 的 AllReduce 拆分为可以与计算 overlap 的 AllGather + ReduceScatter,提升了 overlap 潜力。


基本原理

在标准 TP 中,LayerNorm 和 Dropout 等算子需要完整的序列数据,因此在这些算子之前必须先将序列合并(AllGather),之后再分散(ReduceScatter)。

SP 将序列维度切分到 TP 组内各节点,在需要完整序列的算子边界插入通信:

[各节点持有 s/N 的序列分片]
→ ReduceScatter(收拢序列分片) ← 实际是 TP 后的 ReduceScatter
→ LayerNorm(完整序列归一化) # 需要完整序列;或分片后分别 LN
→ AllGather(展开完整序列)
→ Attention/MLP(TP 内并行计算)
→ ReduceScatter(代替原来的 AllReduce)
→ [各节点持有 s/N 的序列分片]

SP 与 TP 的替换关系

$\text{TP AllReduce} \;\to\; \text{SP ReduceScatter(前向)} + \text{SP AllGather(反向)}$

从通信原语角度,SP 将 TP 中每次 AllReduce 替换为一次 ReduceScatter 和一次 AllGather。


通信时机

位置通信原语说明
Attention/MLP 计算前AllGather将各节点的序列分片合并为完整序列
Attention/MLP 计算后ReduceScatter将输出归约并重新分片,替代 AllReduce

每个 Transformer 层的 SP 通信次数与 TP 相同(每层各 2 次 AG 和 RS),但通信的调度位置不同,与计算的关系更松耦合。


消息大小

SP 的 AllGather 和 ReduceScatter 消息大小与 TP AllReduce 相同:

$M_{\text{SP}} = b \times s \times h \times \text{dtype\_size}$

但两者的数据流向不同:

  • AllGather:每个节点持有 $M/N$ 的分片,AllGather 后每个节点得到完整 $M$
  • ReduceScatter:每个节点持有完整 $M$ 的部分积,ReduceScatter 后每个节点得到 $M/N$ 的归约结果

两个操作的通信量各为 $\frac{N-1}{N} \cdot M$,合计等于 AllReduce 的 $\frac{2(N-1)}{N} \cdot M$

SP 通信特征汇总

特征
通信原语AllGather + ReduceScatter
消息大小10 ~ 100 MB
通信组大小与 TP 相同(2 ~ 8)
频率每层各 1 次(AG 和 RS 各一次)
延迟敏感性中(有 overlap 机会)

与张量并行的协同

总通信量不变

TPSP
通信原语AllReduceAllGather + ReduceScatter
总通信量$\frac{2(N-1)}{N} M$$\frac{N-1}{N} M + \frac{N-1}{N} M = \frac{2(N-1)}{N} M$

两者总通信量完全相同,SP 不是通过减少通信量来提升性能。

Overlap 潜力的差异

TPSP
通信与计算的关系AllReduce 在计算关键路径上,无法 overlapAllGather 可在前一层计算完成前预取,ReduceScatter 可与下一层 LayerNorm 重叠
Overlap 潜力

SP 的优势在于:AllGather 在 Attention/MLP 计算之前触发,可以提前发出,与前一层计算重叠;ReduceScatter 在 Attention/MLP 计算之后触发,可以与 LayerNorm 计算重叠。

内存优化

SP 同时减少了激活值的峰值显存:在 TP 中,每个节点需要保存完整序列长度的激活值;在 SP 中,每个节点只保存 $s/N$ 的序列分片,激活值显存减少 $N$ 倍(对 SP 部分的层)。


参考文献