在Pytorch 2.x中,引入了torch.compile特性,主要包含如下4个部分:
TorchDynamo:基于Python Frame Evaluation Hook技术,实现安全的Pytorch的计算图捕获。
AOTAutograd: AOT生成计算图的反向图。
PrimTorch:规范化2000+ PyTorch Operators为250+ Primitive Operators, 极大降低了开发Pytorch后端的难度。
TorchInductor:一个Deep Learning Compiler,为多种加速器生成高性能代码。对NVIDIA和AMD GPUs, 使用OpenAI Triton编译器作为Backend。
torch.compile编译过程如下:
在图编译视角下,Pytorch的软件栈如下,Triton是Inductor的一个Codegen Backend:
参考资料:
PyTorch 2.0: Our next generation release that is faster, more Pythonic and Dynamic as ever – PyTorch