通过 register_forward_pre_hook、register_forward_hook、register_full_backward_pre_hook、register_full_backward_hook 记录 Module 的耗时及输出结果;aten 算子则通过 torch_dispatch 实现。之后对比不同平台输出结果的耗时及相对误差并绘出直方图,并可通过 diff 信息定位从哪里开始出现误差。
一.效果图
二.用法
model=YourModel()
obj=TimelineHook(model) #修改点一
with TorchDebugDumper(): # 修改点二
train(model)
三.代码
import torch
from torch.utils._python_dispatch import TorchDispatchMode
from dataclasses import dataclass
from typing import Any
from datetime import datetime
import numpy as np
import torch.nn as nn
import time
@dataclass
class _ProfilerState:
    """Holds a profiler class together with its lazily created instance."""
    cls: Any  # class to instantiate; presumably a TorchDispatchMode subclass — TODO confirm
    object: Any = None  # cached instance; None until created on demand
def is_tensor(val):
    """Return True if *val* is a torch tensor (including nn.Parameter).

    nn.Parameter is a subclass of torch.Tensor, so a single isinstance
    check covers both; the original tuple check was redundant.
    """
    return isinstance(val, torch.Tensor)
def is_tensor_empty(tensor):
    """Return True when the tensor holds no elements (some dimension is 0)."""
    # numel() is the product of all dimensions, so it is 0 exactly when
    # at least one dimension is 0 — equivalent to scanning size().
    return tensor.numel() == 0
def describe_tensor(tensor):
    """Summarize a tensor as one CSV line: shape, stats, and sampled values.

    Output format: "d0-d1-...,max,min,mean,std,s0,...,s15" where the samples
    are up to 16 values taken at evenly spaced positions of the flattened
    tensor. Returns "" for a tensor that has a zero-sized dimension.

    Fix: the original f-strings were illegally split across lines right
    after the opening brace (a SyntaxError on Python < 3.12); they are
    reassembled onto single lines here.
    """
    # Inline emptiness check (any zero dim) so the function is self-contained.
    if any(dim == 0 for dim in tensor.size()):
        return ""
    shape_str = "-".join(str(x) for x in tensor.shape)
    # detach() first so tensors that require grad convert cleanly; then
    # normalize device and dtype before handing the buffer to numpy.
    tensor_data = tensor.detach().cpu().float().numpy().ravel()
    num_points = min(16, len(tensor_data))
    indices = np.linspace(0, len(tensor_data) - 1, num_points, dtype=int)
    stats = [np.max(tensor_data), np.min(tensor_data),
             np.mean(tensor_data), np.std(tensor_data)]
    stats_str = ",".join(f"{x:.5f}" for x in stats)
    sample_str = ",".join(f"{x:.5f}" for x in tensor_data[indices])
    return f"{shape_str},{stats_str},{sample_str}"
def generate_random_data(shape, max_val=0.04025, min_val=-0.04651, mean=0.0, std=0.00134):
    """Generate normally distributed random data rescaled to [min_val, max_val].

    Args:
        shape: output array shape (any form accepted by np.random.normal).
        max_val / min_val: target range the data is linearly rescaled into
            (defaults keep the original hard-coded behavior).
        mean / std: parameters of the underlying normal distribution.

    Returns:
        np.ndarray of the given shape whose min equals min_val and whose
        max equals max_val (constant array at min_val in the degenerate
        all-equal case, which would otherwise divide by zero).
    """
    data = np.random.normal(mean, std, shape)
    span = data.max() - data.min()
    if span == 0:
        # All samples identical (e.g. std == 0): rescaling is undefined.
        return np.full(shape, min_val)
    return (data - data.min()) / span * (max_val - min_val) + min_val
# Global monotonically increasing record counter, incremented per logged entry.
index_counter = 0
def log_tensor_data(name, tensor,latency_ms=0):
# 打印tensor的日志数据
global index_counter
index_counter += 1
timestamp = datetime.now().strftime("%H%M%S%f")
rank=torch.distributed.get_rank()
# if rank!=3:
# return
with open(f"trace_rank_{
rank}.txt","a+") as f:
if is_tensor(tensor):
f.write(f"{
rank},{
timestamp},{
index_counter},{
name},{
latency_ms},0,{
describe_tensor(tensor)}\n")
elif isinstance(tensor, (tuple, list)):
for idx, t in enumerate(tensor):
if is_tensor(t):
f.write(f"{
rank}