我来详细解释 torch.cuda.graph(graph, pool=xxx) 的原理与用法。

原理

CUDA Graph 是 PyTorch 中用于优化 GPU 计算性能的技术,主要原理是:

  1. 计算图捕获:将一系列 CUDA 操作捕获为一个静态图
  2. 单次启动:整个图只需一次 GPU 启动,而不是多次内核启动
  3. 减少开销:避免了内核启动开销、CPU-GPU 同步等

pool=xxx 参数用于指定内存池,让图形在专用内存中运行,进一步提高性能。

基本用法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch

# 准备示例计算
def example_computation(x, weight):
return torch.nn.functional.conv2d(x, weight)

# 创建输入和参数
x = torch.randn(4, 3, 32, 32, device='cuda')
weight = torch.randn(16, 3, 3, 3, device='cuda')

# 创建 CUDA 图
graph = torch.cuda.CUDAGraph()

# 使用内存池捕获计算图
with torch.cuda.graph(graph, pool=torch.cuda.graphs.graph_pool_handle()):
output = example_computation(x, weight)

# 执行图(极低开销)
graph.replay()

print(output.shape) # torch.Size([4, 16, 30, 30])

完整示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import time

class CUDAGraphModel:
def __init__(self):
self.model = torch.nn.Sequential(
torch.nn.Linear(1024, 512),
torch.nn.ReLU(),
torch.nn.Linear(512, 256),
torch.nn.ReLU(),
torch.nn.Linear(256, 10)
).cuda()

self.graph = None
self.static_input = None
self.static_output = None

def capture_graph(self, batch_size=32):
"""捕获计算图"""
# 创建静态输入
self.static_input = torch.randn(batch_size, 1024, device='cuda')

# 创建 CUDA 图
self.graph = torch.cuda.CUDAGraph()

# 使用专用内存池捕获
with torch.cuda.graph(self.graph,
pool=torch.cuda.graphs.graph_pool_handle()):
self.static_output = self.model(self.static_input)

def inference(self, input_data):
"""使用图进行推理"""
if self.graph is None:
return self.model(input_data)

# 将数据复制到静态输入缓冲区
self.static_input.copy_(input_data)

# 重放图(极低开销)
self.graph.replay()

return self.static_output.clone()

# 性能对比测试
def benchmark():
model = CUDAGraphModel()
batch_size = 32

# 预热
test_input = torch.randn(batch_size, 1024, device='cuda')
for _ in range(10):
_ = model.model(test_input)

# 普通推理基准
torch.cuda.synchronize()
start_time = time.time()
for _ in range(1000):
output = model.model(test_input)
torch.cuda.synchronize()
normal_time = time.time() - start_time

# 图推理基准
model.capture_graph(batch_size)
torch.cuda.synchronize()
start_time = time.time()
for _ in range(1000):
output = model.inference(test_input)
torch.cuda.synchronize()
graph_time = time.time() - start_time

print(f"普通推理时间: {normal_time:.4f}s")
print(f"图推理时间: {graph_time:.4f}s")
print(f"加速比: {normal_time/graph_time:.2f}x")

if __name__ == "__main__":
benchmark()

内存池管理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# 自定义内存池
class GraphMemoryManager:
def __init__(self):
self.pool = None
self.graphs = []

def create_pool(self, size_mb=512):
"""创建专用内存池"""
self.pool = torch.cuda.CUDAGraphPool(
torch.cuda.current_device(),
size_mb * 1024 * 1024 # 转换为字节
)
return self.pool

def capture_graph(self, model, input_shape):
"""使用专用内存池捕获图"""
graph = torch.cuda.CUDAGraph()
static_input = torch.randn(input_shape, device='cuda')
static_output = None

with torch.cuda.graph(graph, pool=self.pool):
static_output = model(static_input)

self.graphs.append({
'graph': graph,
'input': static_input,
'output': static_output
})
return len(self.graphs) - 1

# 使用示例
manager = GraphMemoryManager()
manager.create_pool(1024) # 1GB 内存池

model1 = torch.nn.Linear(1000, 500).cuda()
graph_id1 = manager.capture_graph(model1, (32, 1000))

model2 = torch.nn.Conv2d(3, 64, 3).cuda()
graph_id2 = manager.capture_graph(model2, (16, 3, 224, 224))

注意事项

  1. 静态输入输出:图的输入输出在捕获后是固定的
  2. 内存管理:使用 copy_() 更新输入,clone() 获取输出
  3. 适用场景:适合计算模式固定的推理任务
  4. 限制:不支持动态控制流、形状变化的操作

最佳实践

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def optimized_inference_pipeline():
# 1. 创建模型和数据加载器
model = YourModel().cuda().eval()

# 2. 为每个批次大小创建专用图
graph_dict = {}
batch_sizes = [1, 4, 16, 32]

for bs in batch_sizes:
graph = torch.cuda.CUDAGraph()
static_input = torch.randn(bs, *input_shape, device='cuda')

with torch.cuda.graph(graph, pool=torch.cuda.graph_pool_handle()):
static_output = model(static_input)

graph_dict[bs] = {
'graph': graph,
'input': static_input,
'output': static_output
}

# 3. 根据输入批次选择对应的图执行
def inference(data):
bs = data.shape[0]
if bs in graph_dict:
graph_info = graph_dict[bs]
graph_info['input'].copy_(data)
graph_info['graph'].replay()
return graph_info['output'].clone()
else:
# 回退到普通模式
return model(data)

return inference

这种技术特别适合在线服务、实时推理等对延迟敏感的场景。