PyTorch模型性能分析与优化 — 第6部分
PyTorch模型性能分析与优化 — 第6部分' can be condensed as 'PyTorch模型性能分析与优化-第6部分'.
如何使用PyTorch Profiler、PyTorch Hooks和TensorBoard识别和分析反向传播中的性能问题
这是我们关于使用PyTorch Profiler和TensorBoard分析和优化PyTorch模型的系列文章中的第六部分。在本文中,我们将解决一种较为复杂的性能问题类型——训练步骤中反向传播的瓶颈。我们将解释这种瓶颈的特殊之处,并提出一种使用PyTorch内置支持的钩子来分析它的方法。非常感谢Yitzhak Levi对本文的贡献。
玩具模型
为了便于讨论,我们使用流行的timm python模块(版本0.9.7)定义了一个简单的Vision Transformer(ViT)分类模型。我们将模型的patch_drop_rate标志设置为0.5,这会导致模型在每个训练步骤中随机丢弃一半的补丁。训练脚本编程以最小化不确定性,使用torch.use_deterministic_algorithms函数和cuBLAS环境变量CUBLAS_WORKSPACE_CONFIG。请参阅下面的代码块以获取完整的模型定义:
import torch, time, osimport torch.optimimport torch.profilerimport torch.utils.datafrom timm.models.vision_transformer import VisionTransformerfrom torch.utils.data import Dataset# 使用GPUdevice = torch.device("cuda:0")# 配置PyTorch使用可复现的算法torch.manual_seed(0)os.environ[ "CUBLAS_WORKSPACE_CONFIG" ] = ":4096:8"torch.use_deterministic_algorithms(True)# 定义ViT支持的分类模型model = VisionTransformer(patch_drop_rate=0.5).cuda(device)# 定义损失函数loss_fn = torch.nn.CrossEntropyLoss()# 定义训练优化器optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)# 使用随机数据class FakeDataset(Dataset): def __len__(self): return 1000000 def __getitem__(self, index): rand_image = torch.randn([3, 224, 224], dtype=torch.float32) label = torch.tensor(data=[index % 1000], dtype=torch.int64) return rand_image, labeltrain_set = FakeDataset()train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, num_workers=8, pin_memory=True)t0 = time.perf_counter()summ = 0count = 0model.train()# 使用分析器对象包装的训练循环with torch.profiler.profile( schedule=torch.profiler.schedule(wait=1, warmup=4, active=3, repeat=1), on_trace_ready=torch.profiler.tensorboard_trace_handler('/tmp/perf')) as prof: for step, data in enumerate(train_loader): inputs = data[0].to(device=device, non_blocking=True) label = data[1].squeeze(-1).to(device=device, non_blocking=True) with torch.profiler.record_function('forward'): outputs = model(inputs) loss = loss_fn(outputs, label) optimizer.zero_grad(set_to_none=True) with torch.profiler.record_function('backward'): loss.backward() with torch.profiler.record_function('optimizer_step'): optimizer.step() prof.step() batch_time = time.perf_counter() - t0 if step > 1: # 跳过第一步 summ += batch_time count += 1 t0 = time.perf_counter() if step > 500: break print(f'average step time: {summ/count}')
我们将在Amazon EC2 g5.2xlarge实例上运行实验(包含NVIDIA A10G GPU和8个vCPU),并使用官方的AWS PyTorch 2.0 Docker镜像。
初始性能结果
在下面的图像中,我们捕捉了在TensorBoard插件Trace View中显示的性能结果:
- 在生成式人工智能时代重新思考质量保证
- “令人难以置信的虚拟化:梅赛德斯-奔驰准备使用NVIDIA Omniverse、MB.OS和生成式人工智能为下一代平台打造数字化生产系统”
- 亚马逊将无人收银技术应用于服装店

在训练步骤的前向传播中,操作被集中在顶部线程中,但在后向传播中出现了性能问题。我们可以看到,单个操作 GatherBackward 占据了跟踪的很大一部分。仔细观察,我们可以看到其中包括 “to”、”copy_” 和 “cudaStreamSynchronize” 等基础操作。正如我们在系列的第二部分中所看到的,这些操作通常表示数据正在从主机复制到设备上,这是我们在训练步骤中希望避免的。
此时你自然会问:为什么会发生这种情况?我们模型定义的哪一部分引起了这个问题?GatherBackward 跟踪表明 torch.gather 操作可能涉及其中,但它是从哪里来的,为什么它会引起同步事件?
在我们之前的帖子中(例如,在这里),我们主张使用带有标签的 torch.profiler.record_function 上下文管理器来确定性能问题的来源。问题在于,性能问题发生在我们无法控制的后向传播中!特别是,我们无法使用上下文管理器将后向传播中的各个操作包装起来。理论上,可以通过深入分析跟踪视图,并将后向传播中的每个段与前向传播中的相应操作进行匹配,来确定有问题的模型操作。然而,这不仅可能非常繁琐,而且还需要对模型训练步骤的所有低级操作有深入了解。使用 torch.profiler.record_function 标签的优点是可以轻松找到我们模型中有问题的部分。理想情况下,即使在后向传播中出现性能问题,我们也希望保留同样的能力。在下一节中,我们将介绍如何使用 PyTorch 钩子来实现这一目标。
使用 PyTorch 后向传播钩子进行性能分析
虽然 PyTorch 不允许您包装单个后向传播操作,但它允许您使用钩子支持来添加自定义功能。PyTorch 支持向 torch.Tensors 和 torch.nn.Modules 注册钩子。虽然我们在本文中提出的技术将依赖于向模块注册后向钩子,但可以类似地使用张量钩子注册来替换或扩充基于模块的方法。
在下面的代码块中,我们定义了一个包装函数,该函数接受一个模块并注册一个 full_backward_hook 和一个 full_backward_pre_hook(实际上只需要一个即可)。每个钩子被编程为使用 torch.profiler.record_function 函数向捕获的性能跟踪添加一条消息。backward_pre_hook 被编程为打印一个 “before” 消息,backward_hook 则是一个 “after” 消息。附加的 details 字符串用于区分同一模块类型的多个实例。
def backward_hook_wrapper(module, details=None): # 定义 register_full_backward_pre_hook 函数 def bwd_pre_hook_print(self, output): message = f'{module.__class__.__qualname__} 的后向传播之前' if details: message = f'{message}: {details}' with torch.profiler.record_function(message): return output # 定义 register_full_backward_hook 函数 def bwd_hook_print(self, input, output): message = f'{module.__class__.__qualname__} 的后向传播之后' if details: message = f'{message}: {details}' with torch.profiler.record_function(message): return input # 注册钩子 module.register_full_backward_pre_hook(bwd_pre_hook_print) module.register_full_backward_hook(bwd_hook_print) return module
使用 backward_hook_wrapper 函数,我们可以开始定位性能问题的来源。我们首先像下面的代码块中所示,将模型和损失函数包装起来:
model = backward_hook_wrapper(model)loss_fn = backward_hook_wrapper(loss_fn)
使用 TensorBoard 插件 Trace View 的搜索框,我们可以确定 “before” 和 “after” 消息的位置,并推断出模型和损失的反向传播开始和结束的位置。这使我们能够得出结论:性能问题发生在模型的后向传播中。下一步是使用 backward_hook_wrapper 函数包装 Vision Tranformer 的内部模块:
model.patch_embed = backward_hook_wrapper(model.patch_embed)model.pos_drop = backward_hook_wrapper(model.pos_drop)model.patch_drop = backward_hook_wrapper(model.patch_drop)model.norm_pre = backward_hook_wrapper(model.norm_pre)model.blocks = backward_hook_wrapper(model.blocks)model.norm = backward_hook_wrapper(model.norm)model.fc_norm = backward_hook_wrapper(model.fc_norm)model.head_drop = backward_hook_wrapper(model.head_drop)
在上面的代码块中,我们指定了每个内部模块。另一种包装模型所有一级模块的方法是迭代其命名子模块:
for submodule in model.named_children(): submodule = backward_hook_wrapper(submodule)
下面的图片捕捉到了“在问题的反向操作之前的PatchDropout”消息出现在有问题的GatherBackward操作之前:

我们的分析表明,性能问题的源头是PathDropout模块。检查模块的前向函数,我们确实可以看到调用了torch.gather。
对于我们的示例模型,我们只需要两次分析迭代就能找到性能问题的源头。实际上,可能需要进行更多次这种方法的迭代。
请注意,PyTorch包含了torch.nn.modules.module.register_module_full_backward_hook函数,可以在训练步骤的所有模块上附加一个钩子,以进行单次调用。尽管在简单的情况下(比如我们的示例),这可能足够了,但它无法区分相同模块类型的不同实例。
现在我们知道了性能问题的源头,可以开始修复它。
优化建议:尽可能使用索引而不是gather
既然我们知道问题的来源是DropPatches模块中的torch.gather操作,我们可以研究什么可能触发了这个耗时的主机-设备同步事件。我们的调查将我们带回到torch.use_deterministic_algorithms函数的文档,该函数告诉我们,当在需要梯度的CUDA张量上调用时,torch.gather会表现出非确定性行为,除非torch.use_deterministic_algorithms被调用时mode设置为True。换句话说,通过配置我们的脚本使用确定性算法,我们修改了torch.gather反向传递的默认行为。事实证明,正是这个改变导致了同步事件的需要。确实,如果我们删除这个配置,性能问题就消失了!问题是,我们能否在不付出性能代价的情况下保持算法的确定性。
在下面的代码块中,我们提出了一种替代的PathDropout模块forward函数的实现,它使用torch.Tensor索引而不是torch.gather来产生相同的输出。被修改的代码行已经被突出显示。
from timm.layers import PatchDropoutclass MyPatchDropout(PatchDropout): def forward(self, x): prefix_tokens = x[:, :self.num_prefix_tokens] x = x[:, self.num_prefix_tokens:] B = x.shape[0] L = x.shape[1] num_keep = max(1, int(L * (1. - self.prob))) keep_indices = torch.argsort(torch.randn(B, L, device=x.device), dim=-1)[:, :num_keep] # 下面的三行代码是从原始代码修改而来的,使用了PyTorch的索引而不是torch.gather stride = L * torch.unsqueeze(torch.arange(B, device=x.device), 1) keep_indices = (stride + keep_indices).flatten() x = x.reshape(B * L, -1)[keep_indices].view(B, num_keep, -1) x = torch.cat((prefix_tokens, x), dim=1) return xmodel.patch_drop = MyPatchDropout( prob = model.patch_drop.prob, num_prefix_tokens = model.patch_drop.num_prefix_tokens)
在下面的图片中,我们捕捉到了以上更改后的跟踪视图:

我们可以清楚地看到,冗长的同步事件不再存在。
对于我们的示例模型来说,我们很幸运,torch.gather操作的使用方式使得可以用PyTorch的索引来替代。当然,并不总是这样;torch.gather的其他用法可能没有等价的基于索引的实现。
结果
在下表中,我们对训练不同场景下的玩具模型的性能结果进行了比较:

在我们的玩具示例中,优化虽然有一些影响,但影响并不大,大约提升了2%的性能。有趣的是,在可重现模式下,torch索引比默认的(非确定性的)torch.gather表现更好。根据这些发现,如果可能的话,使用索引而不是torch.gather可能是一个不错的选择。
总结
尽管PyTorch以易于调试和跟踪而闻名(当然是有道理的),但torch.autograd仍然有些神秘,分析训练步骤的反向传播可能非常困难。为了解决这个挑战,PyTorch提供了在反向传播的不同阶段插入钩子的支持。在本文中,我们展示了如何使用PyTorch的反向传播钩子以及torch.profiler.record_function,通过迭代过程来识别反向传播中性能问题的源头。我们将这种技术应用于一个简单的ViT模型,并了解了一些torch.gather操作的细微差别。
本文中,我们涵盖了一种非常特定的性能瓶颈。请务必查看我们关于VoAGI的其他文章,涵盖了与性能分析和性能优化相关的各种主题。



