Here is a detailed explanation of how torch.cuda.graph(graph, pool=xxx) works and how to use it.
Principles
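CUDA Graphs are a PyTorch feature, built on the CUDA Graphs API, for reducing the CPU-side overhead of running GPU work. The main ideas are:
- Capture: the CUDA work (kernels and memory copies) issued inside the capture region is recorded once into a static graph instead of being executed.
- Single launch: replaying the graph submits all recorded kernels to the GPU with one CPU call instead of one launch per kernel. The kernels still execute individually on the GPU.
- Lower overhead: per-kernel launch cost and Python/framework dispatch overhead disappear on replay, which helps most when a workload consists of many short kernels. Replay does not remove synchronization points that you add yourself.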
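The pool=xxx argument takes a memory-pool token, obtained from torch.cuda.graph_pool_handle() or from another graph's pool() method. Each capture allocates its tensors from a private memory pool that stays reserved for that graph. Passing pool hints that this capture may share that private pool with other captures, which reduces total GPU memory use when several graphs are kept alive. It saves memory; it does not make the kernels themselves run faster.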
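Below is a minimal sketch of the capture and replay cycle described above, assuming a CUDA-capable PyTorch build. The names fn and replay_matches_eager are illustrative, not PyTorch APIs. The sketch records a few small kernels once, writes new data into the captured input buffer, replays, and checks the result against eager execution.

```python
import torch

def fn(x):
    # A few small element-wise and reduction kernels: the kind of workload
    # where per-kernel launch overhead dominates
    return (x * 2 + 1).relu().sum(dim=1)

def replay_matches_eager():
    """Capture fn once, replay it with new input, and compare with eager execution."""
    if not torch.cuda.is_available():
        return None  # CUDA Graphs require a CUDA device
    static_x = torch.zeros(8, 8, device="cuda")

    # Warm up on a side stream before capture, as the PyTorch docs recommend
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        fn(static_x)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_y = fn(static_x)  # recorded into the graph, not executed

    new_x = torch.randn(8, 8, device="cuda")
    static_x.copy_(new_x)  # write new data into the captured input buffer
    g.replay()             # one CPU call launches every captured kernel
    torch.cuda.synchronize()
    return torch.allclose(static_y, fn(new_x))

result = replay_matches_eager()
print(result)
```

On a machine without a GPU the function returns None instead of attempting a capture.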
Basic usage
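```python
import torch

def example_computation(x, weight):
    return torch.nn.functional.conv2d(x, weight)

# Static inputs: the graph always reads from these exact buffers
x = torch.randn(4, 3, 32, 32, device='cuda')
weight = torch.randn(16, 3, 3, 3, device='cuda')

# Warm up on a side stream before capture (lazy initialization, cuDNN algorithm selection, etc.)
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        example_computation(x, weight)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
# pool is optional; a fresh handle is equivalent to the default private pool
with torch.cuda.graph(graph, pool=torch.cuda.graph_pool_handle()):
    output = example_computation(x, weight)

# Capture only records kernels; output holds valid data after replay
graph.replay()
torch.cuda.synchronize()
print(output.shape)  # torch.Size([4, 16, 30, 30])
```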
Complete example
```python
import time
import torch

class CUDAGraphModel:
    def __init__(self):
        self.model = torch.nn.Sequential(
            torch.nn.Linear(1024, 512), torch.nn.ReLU(),
            torch.nn.Linear(512, 256), torch.nn.ReLU(),
            torch.nn.Linear(256, 10),
        ).cuda().eval()
        self.graph = None
        self.static_input = None
        self.static_output = None

    @torch.no_grad()
    def capture_graph(self, batch_size=32):
        """Capture the forward pass into a CUDA graph."""
        self.static_input = torch.randn(batch_size, 1024, device='cuda')
        # Warm up on a side stream before capture
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(3):
                self.model(self.static_input)
        torch.cuda.current_stream().wait_stream(s)

        self.graph = torch.cuda.CUDAGraph()
        # No pool argument: a single graph gets its own private pool by default
        with torch.cuda.graph(self.graph):
            self.static_output = self.model(self.static_input)

    @torch.no_grad()
    def inference(self, input_data):
        """Run inference through the graph; fall back to eager if not captured or shape differs."""
        if self.graph is None or input_data.shape != self.static_input.shape:
            return self.model(input_data)
        self.static_input.copy_(input_data)   # write new data into the captured buffer
        self.graph.replay()
        return self.static_output.clone()     # copy out before the next replay overwrites it

@torch.no_grad()
def benchmark(batch_size=32, iters=1000):
    model = CUDAGraphModel()
    test_input = torch.randn(batch_size, 1024, device='cuda')

    # Eager: warm up, then time
    for _ in range(10):
        model.model(test_input)
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    for _ in range(iters):
        output = model.model(test_input)
    torch.cuda.synchronize()
    normal_time = time.perf_counter() - start_time

    # Graph: capture, warm up replay, then time
    model.capture_graph(batch_size)
    for _ in range(10):
        model.inference(test_input)
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    for _ in range(iters):
        output = model.inference(test_input)
    torch.cuda.synchronize()
    graph_time = time.perf_counter() - start_time

    print(f"Eager inference time: {normal_time:.4f}s")
    print(f"Graph inference time: {graph_time:.4f}s")
    print(f"Speedup: {normal_time / graph_time:.2f}x")

if __name__ == "__main__":
    benchmark()
```
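Wall-clock timing around torch.cuda.synchronize() works for this comparison. CUDA events are another common option: they timestamp the GPU stream directly. The sketch below assumes a CUDA device and a small MLP similar to the one above. The names time_with_events and compare are illustrative. Absolute numbers depend on the GPU, driver, and model size, so none are claimed here.

```python
import torch

def time_with_events(fn, iters=200):
    """Average time per call of fn() in milliseconds, measured with CUDA events.

    The elapsed time spans the GPU stream from the start event to the end event,
    so idle gaps caused by slow CPU-side launches are included.
    """
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(10):  # warmup
        fn()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    end.synchronize()    # block until the end event has been reached
    return start.elapsed_time(end) / iters

@torch.no_grad()
def compare():
    if not torch.cuda.is_available():
        return None  # CUDA Graphs require a CUDA device
    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 512), torch.nn.ReLU(), torch.nn.Linear(512, 10)
    ).cuda().eval()
    static_x = torch.randn(32, 1024, device="cuda")

    # Warm up on a side stream before capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            model(static_x)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_y = model(static_x)  # keeps the graph's output buffer alive

    return {
        "eager_ms": time_with_events(lambda: model(static_x)),
        "graph_ms": time_with_events(g.replay),
    }

timings = compare()
print(timings)
```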
Memory pool management
```python
import torch

class GraphMemoryManager:
    def __init__(self):
        self.pool = None
        self.graphs = []

    def create_pool(self):
        """Create one pool token shared by all captures below.

        torch.cuda.graph_pool_handle() takes no size argument; the pool grows as captures allocate.
        """
        self.pool = torch.cuda.graph_pool_handle()
        return self.pool

    @torch.no_grad()
    def capture_graph(self, model, input_shape):
        """Capture model into a graph that allocates from the shared pool."""
        static_input = torch.randn(input_shape, device='cuda')
        # Warm up on a side stream before capture
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(3):
                model(static_input)
        torch.cuda.current_stream().wait_stream(s)

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, pool=self.pool):
            static_output = model(static_input)
        self.graphs.append({'graph': graph, 'input': static_input, 'output': static_output})
        return len(self.graphs) - 1

manager = GraphMemoryManager()
manager.create_pool()

model1 = torch.nn.Linear(1000, 500).cuda().eval()
graph_id1 = manager.capture_graph(model1, (32, 1000))

model2 = torch.nn.Conv2d(3, 64, 3).cuda().eval()
graph_id2 = manager.capture_graph(model2, (16, 3, 224, 224))

# Graphs in one pool: replay them one at a time, and clone each output
# before replaying another graph from the same pool
```
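The PyTorch documentation also shows a pool token taken from an existing graph via CUDAGraph.pool() instead of a fresh graph_pool_handle(). Below is a minimal sketch of that pattern, assuming a CUDA device. The two Linear stages and the function name are illustrative. The first graph's output is the second graph's input, and both graphs allocate from one pool.

```python
import torch

@torch.no_grad()
def chained_graphs_share_pool():
    """Capture two stages as separate graphs that share one private memory pool."""
    if not torch.cuda.is_available():
        return None  # CUDA Graphs require a CUDA device
    stage1 = torch.nn.Linear(64, 32).cuda().eval()
    stage2 = torch.nn.Linear(32, 8).cuda().eval()
    static_in = torch.randn(16, 64, device="cuda")

    # Warm up both stages on a side stream before capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            stage2(stage1(static_in))
    torch.cuda.current_stream().wait_stream(s)

    g1 = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g1):
        static_mid = stage1(static_in)

    g2 = torch.cuda.CUDAGraph()
    # pool=g1.pool(): hint that g2 may allocate from g1's private pool.
    # g2 reads g1's output buffer directly, so the two graphs form a pipeline.
    with torch.cuda.graph(g2, pool=g1.pool()):
        static_out = stage2(static_mid)

    new_in = torch.randn(16, 64, device="cuda")
    static_in.copy_(new_in)
    g1.replay()  # replay in capture order, one graph at a time
    g2.replay()
    torch.cuda.synchronize()
    return torch.allclose(static_out, stage2(stage1(new_in)), atol=1e-4)

result = chained_graphs_share_pool()
print(result)
```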
Caveats
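- Static inputs and outputs: after capture, the graph always reads from and writes to the same memory addresses. The input and output tensors, including their shapes, are therefore fixed.
- Data handling: copy new data into the static input with copy_() before each replay. If you need to keep the static output, clone() it, because the next replay overwrites it.
- Use cases: best suited to inference workloads whose computation is identical from call to call.
- Restrictions: Python control flow runs once during capture, so data-dependent branches are frozen into whichever path was taken. Shapes cannot change between replays. Operations that synchronize the CPU with the GPU (such as .item(), .cpu(), or printing a CUDA tensor) are not allowed during capture.
- Shared pools: graphs that share a memory pool must not be replayed concurrently. The PyTorch documentation describes pool sharing as safe when graphs are replayed in the same order they were captured.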
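The copy_()/clone() rule can be checked directly. The sketch below assumes a CUDA device; alias_vs_clone is an illustrative name. It keeps one plain reference to the static output and one clone, then replays again. The plain reference changes; the clone does not.

```python
import torch

def alias_vs_clone():
    """Show that each replay overwrites the static output buffer in place."""
    if not torch.cuda.is_available():
        return None  # CUDA Graphs require a CUDA device
    static_x = torch.zeros(4, device="cuda")

    # Warm up on a side stream before capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        _ = static_x + 1
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_y = static_x + 1   # output buffer belongs to the graph's pool

    static_x.fill_(1.0)
    g.replay()
    alias = static_y              # just another name for the same buffer
    snapshot = static_y.clone()   # independent copy of the current values

    static_x.fill_(10.0)
    g.replay()                    # rewrites static_y's memory in place
    return alias.tolist(), snapshot.tolist()  # tolist() waits for the GPU

result = alias_vs_clone()
print(result)  # with CUDA: ([11.0, 11.0, 11.0, 11.0], [2.0, 2.0, 2.0, 2.0])
```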
Best practices
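```python
import torch

@torch.no_grad()
def optimized_inference_pipeline(model, input_shape, batch_sizes=(1, 4, 16, 32)):
    """Capture one graph per batch size. model and input_shape (per-sample shape) come from the caller."""
    model = model.cuda().eval()
    pool = torch.cuda.graph_pool_handle()   # one private pool shared by all graphs
    graph_dict = {}

    # Capture the largest batch first so smaller graphs can reuse memory it freed
    for bs in sorted(batch_sizes, reverse=True):
        static_input = torch.randn(bs, *input_shape, device='cuda')
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):          # warm up before capture
            for _ in range(3):
                model(static_input)
        torch.cuda.current_stream().wait_stream(s)

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, pool=pool):
            static_output = model(static_input)
        graph_dict[bs] = {'graph': graph, 'input': static_input, 'output': static_output}

    @torch.no_grad()
    def inference(data):
        bs = data.shape[0]
        if bs in graph_dict:
            graph_info = graph_dict[bs]
            graph_info['input'].copy_(data)
            graph_info['graph'].replay()
            # Clone before any other graph in the shared pool replays
            return graph_info['output'].clone()
        return model(data)   # batch sizes without a captured graph run eagerly

    return inference
```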
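The inference function above falls back to eager execution for batch sizes that have no graph. A common alternative is to pad the batch up to the nearest captured size and slice the output back. The bucket arithmetic is sketched below in plain Python; pick_bucket and padding_needed are hypothetical helper names.

```python
import bisect

def pick_bucket(batch_size, captured_sizes):
    """Return the smallest captured batch size >= batch_size, or None if none fits.

    captured_sizes must be sorted in ascending order.
    """
    i = bisect.bisect_left(captured_sizes, batch_size)
    return captured_sizes[i] if i < len(captured_sizes) else None

def padding_needed(batch_size, captured_sizes):
    """Number of dummy rows to append so the batch matches its bucket (None if too large)."""
    bucket = pick_bucket(batch_size, captured_sizes)
    return None if bucket is None else bucket - batch_size

sizes = [1, 4, 16, 32]
print(pick_bucket(3, sizes))      # 4
print(pick_bucket(16, sizes))     # 16
print(pick_bucket(40, sizes))     # None
print(padding_needed(10, sizes))  # 6
```

With PyTorch, you could build the padded batch with torch.cat, appending padding_needed(...) zero rows, before copy_() into that bucket's static input, and then trim the result with [:batch_size]. This assumes the model processes each sample independently. Padding trades some wasted computation on dummy rows for always staying on the graph path.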
This technique is especially well suited to latency-sensitive scenarios such as online serving and real-time inference.