【tensorflow】利用ckpt文件计算模型的参数量(parameters calculation)

该博客介绍了一段使用TensorFlow Python API来计算模型参数量的代码。通过加载模型的checkpoint文件,遍历变量并计算其形状的乘积,从而得到模型的总参数数量,并以百万为单位输出。此方法适用于理解和量化深度学习模型的规模。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

  • 计算模型参数量大小,代码如下:
from tensorflow.python import pywrap_tensorflow
import os
import numpy as np
import argparse

def main(args):
    model_dir = args["model_dir_path"]
    detecotr = TOD()
    detecotr.paramstest(model_dir)

class TOD(object):
    def __init__(self):
        self.self = self

    def paramstest(self,model_dir):
        checkpoint_path = os.path.join(model_dir, "model.ckpt")
        reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
        var_to_shape_map = reader.get_variable_to_shape_map()
        total_parameters = 0
        for key in var_to_shape_map:#list the keys of the model
            # print(key)
            # print(reader.get_tensor(key))
            shape = np.shape(reader.get_tensor(key))  #get the shape of the tensor in the model
            shape = list(shape)
            # print(shape)
            # print(len(shape))
            variable_parameters = 1
            for dim in shape:
                # print(dim)
                variable_parameters *= dim
            # print(variable_parameters)
            total_parameters += variable_parameters
        total_parameters = total_parameters / 1000000
        print("=========================================================")
        print("Finish! The parameters of this model!")
        print(str(total_parameters)+ "M parameters.")


if __name__ == '__main__':
    ap = argparse.ArgumentParser()
    ap.add_argument("-model",
                    "--model_dir_path",
                    required=True,
                    help="directory path of model files")

    args = vars(ap.parse_args())
    main(args)
  • 使用时的代码:
python3 params.py -model ckpt文件所在目录的绝对路径
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值