基于torch_dispatch机制生成Megatron-DeepSpeed调用关系图

想了解Megatron-DeepSpeed训练过程中各模块之间的调用关系。torch_dispatch机制可以拦截每个ATen算子,inspect模块又能获取当前的调用栈(文件、类名、函数、行号)。基于这些信息可以还原出模块间的调用关系,最后用graphviz生成SVG图像。该思路同样可以用来绘制其它PyTorch工程的调用关系图。

1. 为了减小图像宽度,节点标签中的文件路径每行只显示一级目录。
2. 图中没有显示具体的ATen算子,因为加上算子后边会过于杂乱。

一.局部效果图

(原文此处为局部效果图,图片未随文本保留。)

二.运行训练过程,拦截算子,生成调用关系信息

# 前面构建模型的代码省略
from torch.utils._python_dispatch import TorchDispatchMode
import inspect
from dataclasses import dataclass
from typing import Any
import pickle

@dataclass
class _ProfilerState:
    cls: Any
    object: Any = None

class TorchDumpDispatchMode(TorchDispatchMode):
    """Dispatch mode that records the Python call chain behind every intercepted op.

    For each op dispatched while the mode is active (on rank 0 only), the
    current Python stack is captured and every pair of adjacent frames is
    recorded as a directed ``caller->callee`` edge. Node names have the form
    ``"<filename>:[<ClassName>]:<function>"``, where ``ClassName`` is the class
    of the frame's ``self`` local, or empty if the frame has no ``self``.

    When the mode is exited, ``{"nodes": set, "edges": set}`` is pickled to
    ``call_graph_<rank>.pkl`` in the current working directory.

    Args:
        parent: the owning dumper object; stored but not otherwise used.
        include_ops: if True, also add an edge from the innermost recorded
            frame to the op's name (e.g. ``"mm"``). Off by default because
            these edges clutter the rendered graph.
    """

    def __init__(self, parent, include_ops=False):
        super().__init__()
        self.parent = parent
        self.include_ops = include_ops
        self.global_index = 0  # number of ops dispatched through this mode
        self.nodes = set()     # "file:[Class]:function" (plus op names if include_ops)
        self.edges = set()     # "caller_node->callee_node"
        self.rank = None       # last observed rank; None until the first op

    @staticmethod
    def _current_rank():
        """Return the distributed rank, or None if no default process group exists.

        ``torch.distributed.get_rank()`` raises when the default process group
        has not been initialized (single-process runs, or ops dispatched before
        ``init_process_group``), so check first instead of crashing.
        """
        dist = torch.distributed
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank()
        return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Leave the dispatch mode, then write the collected graph to disk.

        Saving here rather than in ``__del__`` makes the write deterministic:
        finalizers may run late, never, or during interpreter shutdown.
        """
        result = super().__exit__(exc_type, exc_val, exc_tb)
        self._dump()
        return result

    def _dump(self):
        """Pickle ``{"nodes": ..., "edges": ...}`` to ``call_graph_<rank>.pkl``."""
        rank = self._current_rank()
        if rank is None:
            # No process group (never created or already torn down): use the
            # last rank seen while recording, or 0 if no op was dispatched.
            rank = self.rank if self.rank is not None else 0
        self.rank = rank
        graph = {"nodes": self.nodes, "edges": self.edges}
        with open(f"call_graph_{rank}.pkl", "wb") as f:
            pickle.dump(graph, f)

    def is_keep(self, node):
        """Filter hook: return False to drop a stack frame before recording.

        ``node`` is an ``inspect.FrameInfo``. Every frame is kept by default.
        For example, return False when ``"wrapper"`` or ``"_call_impl"``
        occurs in ``node.function`` to hide wrapper frames.
        """
        return True

    @staticmethod
    def _node_name(frame_info):
        """Return ``"<filename>:[<ClassName>]:<function>"`` for one stack frame."""
        frame_locals = frame_info.frame.f_locals
        if "self" in frame_locals:
            class_name = frame_locals["self"].__class__.__name__
        else:
            class_name = ""
        return f"{frame_info.filename}:[{class_name}]:{frame_info.function}"

    def _record_stack(self, stacks, func):
        """Add nodes/edges for one captured stack, ordered innermost frame first.

        ``stacks[0]`` is skipped: with the default ``is_keep`` it is the
        ``__torch_dispatch__`` frame itself.
        """
        names = [self._node_name(fi) for fi in stacks[1:]]
        if len(names) >= 2:
            self.nodes.update(names)
            # names[i + 1] is the frame that called names[i].
            for callee, caller in zip(names, names[1:]):
                self.edges.add(f"{caller}->{callee}")
        if self.include_ops and names:
            op_node = func._overloadpacket.__name__
            self.nodes.add(names[0])
            self.nodes.add(op_node)
            self.edges.add(f"{names[0]}->{op_node}")

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Record the Python call chain that triggered ``func``, then run it unchanged."""
        self.global_index += 1
        current = self._current_rank()
        # Before a process group exists, every process reports rank 0.
        self.rank = 0 if current is None else current
        if kwargs is None:
            kwargs = {}
        if self.rank == 0:
            # context=0 skips reading source lines for every frame; only the
            # filename, function name and frame object are used.
            stacks = [fi for fi in inspect.stack(context=0) if self.is_keep(fi)]
            try:
                self._record_stack(stacks, func)
            finally:
                # FrameInfo objects hold live frames, including this one, so
                # ``stacks`` forms a reference cycle that would keep every
                # caller's locals (tensors included) alive until the cyclic GC
                # runs; break it explicitly.
                del stacks
        return func(*args, **kwargs)
 
class TorchDumper:
    """Context manager that keeps a dispatch mode active for a ``with`` block.

    ``schedule`` is the mode class to instantiate (e.g. a TorchDispatchMode
    subclass); it is called with this dumper as its only argument and entered
    on ``__enter__``. Only one dumper may be active at a time.
    """

    _CURRENT_Dumper = None  # the active dumper, or None

    def __init__(self, schedule: Any):
        self.p = _ProfilerState(schedule)

    def __enter__(self):
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if TorchDumper._CURRENT_Dumper is not None:
            raise RuntimeError("another TorchDumper is already active")
        if self.p.object is None:
            mode = self.p.cls(self)
            mode.__enter__()
            self.p.object = mode
        else:
            # An existing object is advanced instead of being re-created.
            self.p.object.step()
        # Register only after the mode was entered successfully, so a failing
        # __enter__ does not leave a stale "active" dumper behind.
        TorchDumper._CURRENT_Dumper = self
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        TorchDumper._CURRENT_Dumper = None
        mode = self.p.object
        if mode is not None:
            try:
                mode.__exit__(exc_type, exc_val, exc_tb)
            finally:
                # Drop our references even if the mode's __exit__ fails: this
                # lets the mode be released (a finalizer on it may persist the
                # collected data) and makes the next __enter__ start fresh.
                self.p.object = None
                del mode

def main():
    """Run pretraining with the call-graph dispatch mode active for its whole duration."""
    dumper = TorchDumper(TorchDumpDispatchMode)
    with dumper:
        # Training entry point; pretrain and its providers come from code
        # outside this snippet.
        pretrain_kwargs = {
            "extra_args_provider": llama_argument_handler,
            "args_defaults": {"tokenizer_type": "GPT2BPETokenizer"},
        }
        pretrain(
            train_valid_test_datasets_provider,
            model_provider,
            forward_step,
            **pretrain_kwargs,
        )


if __name__ == "__main__":
    main()

三.可视化,生成SVG图像

# coding=utf-8

import os
from graphviz import Digraph,Graph
import pickle
import random
from distinctipy import distinctipy

def generate_colors(N):
    '''
    Return N mutually distinguishable colors as "#RRGGBB" strings (uppercase hex).

    Colors come from distinctipy.get_colors(N). Each channel (assumed to be a
    float in [0, 1]) is scaled by 255 and truncated to an int before formatting.
    '''
    return [
        "#" + "".join(f"{int(channel * 255):02X}" for channel in (red, green, blue))
        for red, green, blue in distinctipy.get_colors(N)
    ]

# Install-specific path prefixes stripped from node labels, applied in order.
# A longer prefix must come before any shorter prefix it starts with.
_PATH_PREFIXES = (
    "/home/user/Megatron-DeepSpeed/",
    "/home/anaconda3/envs/dev/lib/python3.10/site-packages/",
    "/home/user/deepspeed/",
    "/home/anaconda3/envs/dev/",
)


def replace_name(name, *, prefixes=_PATH_PREFIXES):
    '''
    Shorten a node name and split it into one label line per path component.

    Args:
        name: node name of the form "<filename>:[<ClassName>]:<function>".
        prefixes: path prefixes removed (every occurrence, in order) before
            splitting; defaults to the original author's install paths and
            can be overridden for other environments.

    Returns:
        None for frames of the dispatch hook itself (names containing
        "__torch_dispatch__"), so callers can drop them. Otherwise the
        shortened name, with every "/" and ":" replaced by the two characters
        backslash-n, which graphviz renders as a line break in labels.
    '''
    if "__torch_dispatch__" in name:
        return None
    for prefix in prefixes:
        name = name.replace(prefix, "")
    return name.replace("/", r"\n").replace(":", r"\n")

# 1. Load the call graph pickled by the dispatch-mode hook.
#    NOTE: pickle.load can execute arbitrary code; only load files produced
#    by the dumper itself, never untrusted input.
rank = 0
with open(f"call_graph_{rank}.pkl", "rb") as f:
    data = pickle.load(f)

# 2. Build the graph and set its attributes.
dot = Digraph()
dot.node_attr = {"shape": "plaintext"}
dot.attr('graph', layout='dot')
dot.graph_attr.update(sep='4.0', ratio='compress')

node_desc_id_map = {}  # shortened node label -> graphviz node id
src_node_color = {}    # node label -> color shared by all edges leaving that node

colors = generate_colors(10)
colors_sz = len(colors)

fontsize = "16"        # node font size
penwidth = "2.0"       # edge width

# Node fill color by framework, first match wins; anything else is gray.
fill_colors = (("megatron", '#73FBFD'), ("deepspeed", '#FA8D89'))
default_fill = '#C0C0C0'

# 3. Add nodes. Sorting makes ids and color assignment reproducible; iterating
#    the raw set would vary from run to run with string-hash randomization.
for idx, raw_name in enumerate(sorted(data["nodes"])):
    v = replace_name(raw_name)
    if v is None or v in node_desc_id_map:
        # Dropped node, or another raw path that shortens to the same label
        # (adding it again would create an orphan duplicate node).
        continue
    node_id = f"{idx}"
    node_desc_id_map[v] = node_id
    fill = next((color for key, color in fill_colors if key in v), default_fill)
    dot.node(node_id, v, style='filled', color=fill, fontsize=fontsize)
    src_node_color[v] = colors[idx % colors_sz]

# 4. Add edges, colored by their source node (sorted for a reproducible layout).
for edge in sorted(data["edges"]):
    from_node, to_node = edge.split("->")
    from_node = replace_name(from_node)
    to_node = replace_name(to_node)
    if from_node and to_node:
        dot.edge(node_desc_id_map[from_node], node_desc_id_map[to_node],
                 color=src_node_color[from_node], penwidth=penwidth)

# 5. Render: writes the DOT source to save_path and the image to save_path + ".svg".
save_path = 'megatron_deepspeed_callgraph'
dot.render(save_path, format='svg', view=False)

# 6. Change the background to gray. The first <polygon> in the SVG is
#    treated as the graph background.
import xml.etree.ElementTree as ET
SVG_NS = "http://www.w3.org/2000/svg"
# Register prefixes so the rewritten file keeps plain <svg>/<g> tags instead
# of ElementTree's generated "ns0:" prefixes.
ET.register_namespace("", SVG_NS)
ET.register_namespace("xlink", "http://www.w3.org/1999/xlink")
svg_tree = ET.parse(f'{save_path}.svg')
root = svg_tree.getroot()
element = root.find(f".//{{{SVG_NS}}}polygon")
if element is None:
    print(f"warning: no <polygon> found in {save_path}.svg; background left unchanged")
else:
    element.set('fill', 'gray')
    svg_tree.write(f'{save_path}.svg')
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Hi20240217

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值