嵌入式小企鹅 2024-11-20 17:07 采纳率: 0%
浏览 14
已结题

深度学习输入输出张量怎么确认?

我在
https://github.com/deezer/spleeter/releases
这里下载了2stems.tar.gz,拿到之后解压之后有checkpoint、.data-00000-of-00001、.index、.meta。我现在需要的是转换模型为 TensorFlow Lite 格式。我要拿到Android这种设备上去。现在就是说好像要先转换为 SavedModel 之后,才能转换为TensorFlow Lite 格式。现在我输入下面的代码

import tensorflow as tf

# 禁用 Eager Execution
tf.compat.v1.disable_eager_execution()

# 加载 .meta 文件并恢复模型
saver = tf.compat.v1.train.import_meta_graph('./model.meta')
with tf.compat.v1.Session() as sess:
    saver.restore(sess, './model')

    # 获取模型图
    graph = tf.compat.v1.get_default_graph()

    # 输入和输出张量名称
    input_tensor_name = 'keras_learning_phase/input'  # 假设这是输入张量的名称
    output_tensor_name = 'Sum/input'  # 假设这是输出张量的名称

    # 创建 Keras 模型
    model = tf.keras.models.Model(inputs=graph.get_tensor_by_name(input_tensor_name + ':0'),
                                  outputs=graph.get_tensor_by_name(output_tensor_name + ':0'))
    
    # 转换为 TensorFlow Lite 模型
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_model = converter.convert()

    # 保存 TensorFlow Lite 模型
    with open('model.tflite', 'wb') as f:
        f.write(tflite_model)

print("模型已成功转换为 TensorFlow Lite 格式!")

报错,主要就是说这两个输入输出张量应该怎么填?

  • 写回答

1条回答 默认 最新

  • 清水白石008 2024-11-20 17:10
    关注
    🕒下午好,题主!目前:🟢[在线]
    📢参考GPT-4o mini 模型、由清水白石008整理提供!
    📣今日金句:“成功不是终点,失败也不是致命的,继续前进的勇气才是最重要的。” — 温斯顿·丘吉尔
    

    为了确定模型的输入和输出张量名称,我们需要检查模型的图结构。你可以使用TensorFlow提供的工具来查看模型中的所有操作,并找出输入和输出张量的名称。以下是一个详细的步骤:

    查看模型图结构

    1. 加载模型并打印所有操作
    import tensorflow as tf
    
    # 禁用 Eager Execution
    tf.compat.v1.disable_eager_execution()
    
    # 加载 .meta 文件并恢复模型
    saver = tf.compat.v1.train.import_meta_graph('./model.meta')
    with tf.compat.v1.Session() as sess:
        saver.restore(sess, './model')
    
        # 获取模型图
        graph = tf.compat.v1.get_default_graph()
    
        # 打印所有操作名称
        for op in graph.get_operations():
            print(op.name)
    

    通过运行上述代码,你可以打印出模型中所有操作的名称。观察这些操作名称,可以帮助你确定模型的输入和输出张量。

    确定输入和输出张量名称

    根据打印出的操作列表,找到合适的输入和输出张量名称,并替换你当前代码中的 input_tensor_nameoutput_tensor_name

    假设你找到了合适的操作名称,以下是修改后的代码:

    import tensorflow as tf
    
    # 禁用 Eager Execution
    tf.compat.v1.disable_eager_execution()
    
    # 加载 .meta 文件并恢复模型
    saver = tf.compat.v1.train.import_meta_graph('./model.meta')
    with tf.compat.v1.Session() as sess:
        saver.restore(sess, './model')
    
        # 获取模型图
        graph = tf.compat.v1.get_default_graph()
    
        # 输入和输出张量名称(根据你的模型图结构修改)
        input_tensor_name = 'input_tensor_name_from_graph'  # 修改为实际输入张量名称
        output_tensor_name = 'output_tensor_name_from_graph'  # 修改为实际输出张量名称
    
        # 创建 Keras 模型
        model = tf.keras.models.Model(inputs=graph.get_tensor_by_name(input_tensor_name + ':0'),
                                      outputs=graph.get_tensor_by_name(output_tensor_name + ':0'))
        
        # 转换为 TensorFlow Lite 模型
        converter = tf.lite.TFLiteConverter.from_keras_model(model)
        tflite_model = converter.convert()
    
        # 保存 TensorFlow Lite 模型
        with open('model.tflite', 'wb') as f:
            f.write(tflite_model)
    
    print("模型已成功转换为 TensorFlow Lite 格式!")
    

    完整步骤

    1. 加载模型并打印所有操作:运行第一段代码,找出所有操作名称。
    2. 确定输入和输出张量名称:根据打印出的操作列表,找到输入和输出张量名称。
    3. 修改代码:将找到的输入和输出张量名称替换到第二段代码中,并运行。

    这将确保你能正确地加载模型并进行TensorFlow Lite转换。

    希望这对你有所帮助!如果你有任何进一步的问题或需要更多的详细指导,请随时告诉我。😊

    评论

报告相同问题?

问题事件

  • 已结题 (查看结题原因) 11月27日
  • 创建了问题 11月20日