PyTorch Model Precision Comparison Tool

The tool records each Module's latency and outputs via register_forward_pre_hook, register_forward_hook, register_full_backward_pre_hook, and register_full_backward_hook; aten operators are captured through torch_dispatch. It then compares the latencies and relative errors of the outputs across platforms and plots histograms. The diff information lets you pinpoint where the error first appears.
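The Module side of this idea rests on pairing a pre-hook (start timestamp) with a post-hook (latency plus output statistics). A minimal sketch of that mechanism, assuming the helper name `attach_timing_hooks` and the `records` list, which are illustrative and not part of the tool itself:

```python
import time
import torch
import torch.nn as nn

records = []  # (module name, latency in ms, output mean) collected by the hooks

def attach_timing_hooks(model):
    # Pre-hook stores the entry time per module; post-hook computes the
    # elapsed time and a summary statistic of the output on exit.
    starts = {}
    for name, module in model.named_modules():
        def pre_hook(mod, inputs, _name=name):
            starts[_name] = time.time()
        def post_hook(mod, inputs, output, _name=name):
            latency_ms = (time.time() - starts[_name]) * 1000.0
            if isinstance(output, torch.Tensor):
                records.append((_name, latency_ms, output.float().mean().item()))
        module.register_forward_pre_hook(pre_hook)
        module.register_forward_hook(post_hook)

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
attach_timing_hooks(model)
model(torch.randn(2, 4))
```

Because `named_modules()` yields the root module as well, the root's record brackets all of its children, which is what makes per-layer latency attribution possible.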

1. Screenshots

(histograms of per-op latency and relative error across platforms)

2. Usage

model = YourModel()
obj = TimelineHook(model)    # change 1: attach the Module hooks
with TorchDebugDumper():     # change 2: capture aten operators
    train(model)
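The aten-level capture behind `TorchDebugDumper` is built on `TorchDispatchMode`, whose `__torch_dispatch__` sees every aten op before it executes. A minimal sketch of that mechanism, where the `OpLogger` class and `captured` list are hypothetical stand-ins rather than the tool's own implementation:

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

captured = []  # aten op names seen while the mode is active

class OpLogger(TorchDispatchMode):
    # Every aten call inside the `with` block passes through here,
    # so the op name (and its result) can be logged for later diffing.
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        captured.append(str(func))
        return func(*args, **(kwargs or {}))

with OpLogger():
    torch.ones(2, 2) + torch.ones(2, 2)
```

Calling `func(...)` inside `__torch_dispatch__` re-dispatches below the mode, so the computation proceeds normally while the mode only observes.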

3. Code

import torch
from torch.utils._python_dispatch import TorchDispatchMode
from dataclasses import dataclass
from typing import Any
from datetime import datetime
import numpy as np
import torch.nn as nn
import time

@dataclass
class _ProfilerState:
    cls: Any
    object: Any = None

def is_tensor(val):
    # True for torch.Tensor or nn.Parameter
    return isinstance(val, (torch.Tensor, nn.Parameter))

def is_tensor_empty(tensor):
    return any(dim == 0 for dim in tensor.size())

def describe_tensor(tensor):
    # Describe a tensor: shape plus summary statistics and evenly spaced samples
    if is_tensor_empty(tensor):
        return ""
    shape = [str(x) for x in tensor.shape]
    #return f"{shape}"
    tensor_data = tensor.cpu().float().detach().numpy().ravel()
    num_points = min(16, len(tensor_data))
    indices = np.linspace(0, len(tensor_data) - 1, num_points, dtype=int)
    stats = [np.max(tensor_data), np.min(tensor_data), np.mean(tensor_data), np.std(tensor_data)]
    sample_data = tensor_data[indices]
    shape_str = "-".join(shape)
    stats_str = ",".join(f"{x:.5f}" for x in stats)
    sample_str = ",".join(f"{x:.5f}" for x in sample_data)
    return f"{shape_str},{stats_str},{sample_str}"

def generate_random_data(shape):
    # Generate random data of the given shape, rescaled into a fixed [min, max] range
    max_val, min_val, mean, std = 0.04025, -0.04651, 0.0, 0.00134
    data = np.random.normal(mean, std, shape)
    data = (data - data.min()) / (data.max() - data.min()) * (max_val - min_val) + min_val
    return data

index_counter = 0
def log_tensor_data(name, tensor, latency_ms=0):
    # Append the tensor's trace record to the per-rank log file
    global index_counter
    index_counter += 1
    timestamp = datetime.now().strftime("%H%M%S%f")
    rank = torch.distributed.get_rank()
    # if rank!=3:
    #     return
    with open(f"trace_rank_{rank}.txt", "a+") as f:
        if is_tensor(tensor):
            f.write(f"{rank},{timestamp},{index_counter},{name},{latency_ms},0,{describe_tensor(tensor)}\n")
        elif isinstance(tensor, (tuple, list)):
            for idx, t in enumerate(tensor):
                if is_tensor(t):
                    f.write(f"{rank},{timestamp},{index_counter},{name}:{idx},{latency_ms},0,{describe_tensor(t)}\n")
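Once `trace_rank_*.txt` files exist from two platforms, their sampled values can be compared record by record. A hedged sketch of the relative-error check, where the helper `relative_error` and the sample values are illustrative, not the tool's own diff code:

```python
import numpy as np

def relative_error(ref, test, eps=1e-9):
    # Maximum elementwise relative error between two sample vectors;
    # eps guards against division by zero when the reference value is 0.
    ref, test = np.asarray(ref, float), np.asarray(test, float)
    return float(np.max(np.abs(ref - test) / (np.abs(ref) + eps)))

# Hypothetical sampled values of the same op recorded on two platforms.
a = [0.10000, -0.20000, 0.30000]
b = [0.10001, -0.20002, 0.29997]
err = relative_error(a, b)
```

Scanning the traces in order and flagging the first record whose error exceeds a threshold is what lets you localize where the divergence begins.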