这是一个非常好的问题,它触及了现代大模型推理优化的核心。我们来详细拆解一下 torch.compile() 生成的 CUDA Graph 是如何在分布式多卡环境下协同工作的。

首先,要明确一个关键点:torch.compile 和 CUDA Graph 是两种不同但可以协同工作的优化技术。

  1. torch.compile: 它是一个编译器。它捕获你的 PyTorch 模型(一个动态图),将其编译成一个优化的、静态的计算图。这个优化后的图可以由不同的后端执行,其中最强大的后端就是 inductor
  2. CUDA Graph: 它是一种执行模式。它通过捕获在 GPU 上执行的一系列内核(kernel)来创建一个“图”,然后可以多次、高效地重放这个图,避免了 CPU 驱动开销、启动内核的开销等。

torch.compileinductor 后端中,一个非常重要的优化就是自动生成并利用 CUDA Graph 来执行编译好的计算图。

现在,我们来看多卡(分布式)推理的场景。这里通常使用 Tensor Parallelism(张量并行)Pipeline Parallelism(流水线并行)。我们以更细粒度的张量并行为例。

整体工作流程

整个过程的链条是:PyTorch 模型 -> torch.compile 编译 -> 生成一个多卡协同的、优化的静态计算图 -> 图的执行被 CUDA Graph 捕获 -> 部署时,重放多个同步的 CUDA Graph

第1步:模型定义与分布式设置

假设你有一个大模型,比如 Transformer。你使用 torch.distributedtorch.nn.parallel.DistributedDataParallel 或者更底层的通信原语(如 all_reduce)将其切分到多个 GPU 上。

例如,在一个 4 卡的张量并行设置中:

  • 一个大的线性层(如 4096 -> 4096)会被水平切分成 4 个小的矩阵(如 4096 -> 1024),每个 GPU 上一个。
  • 在前向传播中,当需要整个层的输出时,你需要一个 all_reduce 通信操作来聚合所有 GPU 的部分结果。
1
2
3
4
5
6
7
8
# 伪代码,示意张量并行中的线性层
class ColumnParallelLinear(nn.Module):
def forward(self, x):
# x 在每个GPU上都有
local_output = F.linear(x, self.local_weight)
# 需要聚合所有GPU的local_output
global_output = torch.distributed.all_reduce(local_output, op=ReduceOp.SUM)
return global_output

第2步:应用 torch.compile

当你对整个这个分布式的模型包装上 torch.compile 时,魔法开始了。

1
2
3
4
from torch import compile

# model 已经是一个分布式的模型(例如,被 DDP 包装或手动实现了模型并行)
compiled_model = compile(model, backend="inductor")

inductor 后端会做以下事情:

  1. 追踪计算图:它会在你的模型上进行示例运行(比如用一些随机输入),动态地追踪所有发生在 GPU 上的操作。这包括:

    • 计算操作:如矩阵乘法、激活函数等。
    • 通信操作:如 all_reduce, all_gather 等,这些操作在追踪时被表示为 wait_tensor 算子。
  2. 图优化与 lowering:Inductor 会对这个捕获到的全局计算图进行大量优化(如算子融合、缓冲区分配等),然后将其“降低”到最终要运行的 GPU 内核代码。

  3. 自动生成 CUDA Graph:在 inductor 的代码生成阶段,它会为计算图的大部分或关键部分自动生成 CUDA Graph 的捕获和执行逻辑。它并不是为整个多卡系统生成一个巨大的图,而是更智能地:

    • 在每个 GPU 上,Inductor 会生成一个针对该 GPU 本地操作的、高度优化的子图
    • 这个子图包含了计算内核和对通信操作的“等待”

第3步:协同计算与通信——核心所在

这是最关键的部分。在运行时,当你调用 compiled_model(input)

  1. 图捕获:在第一个或前几个“预热”迭代中,系统会实际执行 Inductor 生成的内核。在此期间,CUDA Graph 在每个 GPU 上独立地但同步地捕获属于它自己的那个计算子流。

    • 每个 GPU 的流上都有:计算 -> 通信发起 -> 通信等待 -> 更多计算 ...
    • 由于所有 GPU 都在执行同一个编译好的、静态的图,所以它们的捕获过程是逻辑上同步的。GPU 0 在某个点发起 all_reduce,GPU 1 也在同一个图节点的位置发起 all_reduce
  2. 图重放:预热之后,进入推理部署阶段。对于每一个新的输入,系统不再逐个启动内核,而是命令每个 GPU 重放它之前捕获到的那个 CUDA Graph

    • 通信的隐式同步:因为所有 GPU 都在重放一个在捕获期就确定好的、同步的图序列,所以通信操作自然会被对齐。例如,当重放到 all_reduce 节点时,所有 GPU 都会在图的同一个“位置”发起并等待通信,确保了正确的协同。
    • 效率极致:这种方式达到了极致的效率:
      • 最小化 CPU 开销:CPU 只需要发出一个“启动整个图”的命令,而不是成百上千个启动单个内核的命令。
      • 锁页内存与持久化:通信库(如 NCCL)在 CUDA Graph 模式下可以使用更优的通信策略,比如使用锁页内存以避免额外的数据传输,并将通信操作持久化,进一步减少开销。

技术优势与挑战

优势:

  • 极低的开销:将整个前向传播(包括计算和通信)变成一个“单个操作”,大大降低了框架层面的开销。这对于小批量、高频率的推理请求至关重要。
  • 确定性的执行:CUDA Graph 保证了执行顺序的确定性,避免了动态调度可能带来的性能波动,这对于推理服务的延迟稳定性非常有益。
  • 端到端优化torch.compile + CUDA Graph 实现了从 Python 到底层内核的端到端优化,特别是 Inductor 的自动融合和代码生成能力。

挑战与注意事项:

  • 动态性:CUDA Graph 捕获的是静态的执行流。如果你的模型有动态控制流(如输入依赖的 if-else),这会给捕获带来困难。torch.compile 会尝试处理一些动态情况,但最理想的情况是静态形状和静态图。
  • 内存:CUDA Graph 在捕获期间分配的所有内存(输入、输出、中间结果)在重放期间都会被复用。这意味着它不适合动态改变批量大小的情况,除非为每个不同的配置捕获不同的图。
  • 调试复杂性:一个被编译和 CUDA Graph 捕获的模型,其执行流与原始的 Python 代码相差甚远,调试性能问题或错误会更加复杂。

总结

在分布式多卡推理中,torch.compile()(通过其 inductor 后端)的作用是:

  1. 编译出一个全局的、静态的、包含多卡计算与通信操作的计算图。
  2. 优化这个图,并生成在每个 GPU 上独立运行但逻辑上协同的子图代码
  3. 利用 CUDA Graph 在运行时捕获并重放这些子图。

最终,多卡之间的协同是通过在编译期确定一个静态的、同步的执行计划,并在运行时由所有 GPU 同步重放各自被捕获的 CUDA Graph 来实现的。计算和通信被完美地编织在这些图中,使得整个分布式推理过程像一个精密的钟表一样协同工作,从而实现了极低的延迟和极高的吞吐量。