PyTorch 2.2大更新!集成FlashAttention-2,性能提升2倍
新智元报道
新智元报道
【新智元导读】新的一年,PyTorch也迎来了重大更新,PyTorch 2.2集成了FlashAttention-2和AOTInductor等新特性,计算性能翻倍。
新的一年,PyTorch也迎来了重大更新!
继去年十月份的PyTorch大会发布了2.1版本之后,全世界各地的521位开发者贡献了3628个提交,由此形成了最新的PyTorch 2.2版本。
新的版本集成了FlashAttention-2,使得scaled_dot_product_attention (SDPA)相较于之前的版本有了约2倍的性能提升。
PyTorch 2.2还引入了一个新的TorchInductor提前扩展,称为 AOTInductor,旨在为非python服务器端编译和部署PyTorch程序。
PyTorch中的torch.distributed支持了一个叫做device_mesh的新抽象,用于初始化和表示ProcessGroups。
另外,PyTorch 2.2提供了一个标准化的、可配置的日志记录机制,——TORCH_LOGS。
PyTorch 2.2还对torch.compile做了许多改进,包括改进了对编译优化器的支持,以及TorchInductor融合和布局优化。
最后值得注意的是,PyTorch将放弃对macOS x86的支持,PyTorch 2.2.x是支持macOS x64的最后一个版本。
PyTorch 2.2新特性
FlashAttention-2
FlashAttention-2通过优化GPU上不同线程块和warps之间的工作分区,来解决占用率低或不必要的共享内存读写。
FlashAttention-2调整了算法以减少非matmul的计算量,同时提升了Attention计算的并行性(即使是单个头,也可以跨不同的线程块,以增加占用率),在每个线程块中,优化warps之间的工作分配,以减少通过共享内存的通信。
PyTorch 2.2将FlashAttention内核更新到了v2版本,不过需要注意的是,之前的Flash Attention内核具有Windows实现,Windows用户可以强制使用sdp_kernel,仅启用Flash Attention的上下文管理器。
而在2.2中,如果必须使用 sdp_kernel 上下文管理器,请使用memory efficient或math内核(在Windows上)。
在FlashAttention-2的加持之下,torch.nn.functional.scaled_dot_product_attention的速度提升了大约2倍,在A100 GPU上达到了理论计算峰值的50%-73%。
AOTInductor
AOTInductor是TorchInductor的扩展,用于处理导出的PyTorch模型,对其进行优化,并生成共享库以及其他相关工件。
这些编译的工件可以部署在非Python环境中,经常用于服务器端的推理。
下面的示例演示了如何调用 aot_compile 将模型转换为共享库。
AOTInductor支持与Inductor相同的后端,包括CUDA、ROCm和CPU。
TORCH_LOGS
PyTorch 2.2提供了一个标准化的、可配置的日志记录机制,可用于分析各种子系统的状态,例如编译和分布式操作
可以通过TORCH_LOGS环境变量启用日志。比如通过在命令行中修改环境变量:
将TorchDynamo的日志级别设置为logging.ERROR,将TorchInductor的日志级别设置为logging.DEBUG。
当然也可以在代码中以API的形式使用:
torch.distributed.device_mesh
PyTorch 2.2引入了一个新的抽象,用于表示分布式并行中涉及的 ProcessGroups,称为torch.distributed.device_mesh。
为分布式训练设置分布式通信器(NCCL)是一件麻烦的事情。用户需要编写不同并行度的工作负载,并为每个并行度手动设置和管理NCCL通信器(ProcessGroup )。
这个过程可能很复杂,容易出错。而DeviceMesh 可以简化此过程,使其更易于管理。
DeviceMesh 是管理 ProcessGroup 的更高级别的抽象。它允许用户毫不费力地创建节点间和节点内进程组,而不必担心如何为不同的子进程组正确设置等级。
例如,数组的其中一个维度可以表示FSDP中的数据并行(data parallelism),而另一个维度可以表示FSDP中的张量并行(tensor parallelism)。
用户还可以通过 DeviceMesh 轻松管理底层process_groups,以实现多维并行。
DeviceMesh在处理多维并行性(如3D并行)时很有用。如上图所示,当你的并行解决方案需要跨主机和每个主机内部进行通信时,可以创建一个2D网格,用于连接每个主机中的设备,并以同构设置将每个设备与其他主机上的对应设备连接起来。
借助 init_device_mesh() ,我们可以在短短两行内完成上面这个2D设置:
而如果不使用DeviceMesh,我们大概需要自己写下面这一堆代码:
当然,如果需要,我们仍然可以访问底层 ProcessGroup:
优化器的改进
大概有以下几点:
编译优化器在所有基准测试中都提高了性能:HuggingFace +18%、TorchBench +19%、TIMM +8% E2E;
编译的优化器增加对cudagraphs的支持;
对测试套件中所有模型进行平均,每个测试套件的基准测试平均编译时间增加约40秒;正在进行的优化可能会将其降低到30秒以下。
性能改进
微信扫码关注该文公众号作者