pth模型转换成Tflite模型在PlatformIO平台给esp32部署

pthToonnx

example

import torch

import torch.nn as nn

import numpy as np

from lstm_model import SlidingCNNModel

  

def convert_to_onnx():

    # 加载模型参数

    input_size = 16  # 传感器特征数

    hidden_size = 512

    output_size = 6  # 输出维度(2个骨骼 × 3个参数)

    dropout = 0.3

    # 创建模型实例

    model = SlidingCNNModel(

        input_size=input_size,

        hidden_size=hidden_size,

        output_size=output_size,

        dropout=dropout

    )

    # 加载训练好的模型权重

    model.load_state_dict(torch.load('sliding_cnn_model.pth'))

    model.eval()

    # 创建示例输入

    # 注意:输入形状为 (batch_size, sequence_length, input_size)

    sequence_length = 6  # 与训练时使用的序列长度相同

    dummy_input = torch.randn(1, sequence_length, input_size)

    # 导出ONNX模型

    onnx_file_path = 'sliding_cnn_model.onnx'

    torch.onnx.export(

        model,                  # 要导出的模型

        dummy_input,           # 模型输入

        onnx_file_path,        # 保存路径

        export_params=True,    # 存储训练好的参数权重

        opset_version=11,      # ONNX算子集版本

        do_constant_folding=True,  # 是否执行常量折叠优化

        input_names=['input'],     # 输入名

        output_names=['output'],   # 输出名

        dynamic_axes={

            'input': {0: 'batch_size'},  # 只允许batch_size是动态的

            'output': {0: 'batch_size'}

        }

    )

    print(f"模型已成功转换为ONNX格式并保存到: {onnx_file_path}")

    print(f"输入形状: {dummy_input.shape}")

    print(f"模型结构:")

    print(model)

  

if __name__ == "__main__":

    convert_to_onnx()

ONNXtoTFlite

操作步骤

1. Environment Install

  • Install ONNX master branch from source.

  • Install TensorFlow >= 2.8.0, tensorflow-probability and tensorflow-addons. (Note TensorFlow 1.x is no longer supported)

  • Run git clone https://github.com/onnx/onnx-tensorflow.git.

  • Run cd onnx-tensorflow.

  • Run pip install -e ..

  • requirement.txt
    absl-py==2.1.0
    astunparse==1.6.3
    backcall==0.2.0
    cachetools==5.5.2
    certifi @ file:///C:/b/abs_85o_6fm0se/croot/certifi_1671487778835/work/certifi
    charset-normalizer==3.4.2
    cloudpickle==2.2.1
    decorator==5.1.1
    dm-tree==0.1.8
    flatbuffers==1.12
    gast==0.3.3
    google-auth==2.40.2
    google-auth-oauthlib==0.4.6
    google-pasta==0.2.0
    grpcio==1.32.0
    h5py==2.10.0
    idna==3.10
    importlib-metadata==6.7.0
    keras==2.9.0
    Keras-Preprocessing==1.1.2
    libclang==18.1.1
    Markdown==3.4.4
    MarkupSafe==2.1.5
    numpy==1.21.6
    oauthlib==3.2.2
    onnx==1.14.1
    onnx-tf==1.9.0
    opt-einsum==3.3.0
    packaging==24.0
    pickleshare==0.7.5
    protobuf==3.20.2
    pyasn1==0.5.1
    pyasn1-modules==0.3.0
    Pygments==2.17.2
    pywin32==306
    PyYAML==6.0.1
    pyzmq==26.0.3
    requests==2.31.0
    requests-oauthlib==2.0.0
    rsa==4.9.1
    six==1.15.0
    tensorboard==2.9.1
    tensorboard-data-server==0.6.1
    tensorboard-plugin-wit==1.8.1
    tensorflow==2.9.0
    tensorflow-addons==0.19.0
    tensorflow-estimator==2.9.0
    tensorflow-intel==2.11.0
    tensorflow-io-gcs-filesystem==0.31.0
    tensorflow-probability==0.19.0
    termcolor==1.1.0
    tornado==6.2
    traitlets==5.9.0
    typeguard==2.13.3
    typing-extensions==3.7.4.3
    urllib3==2.0.7
    wcwidth==0.2.13
    Werkzeug==2.2.3
    wincertstore==0.2
    wrapt==1.12.1
    zipp==3.15.0
    tensorflow_model_optimization
    
    

2. 加载ONNX模型

        import onnx

        from onnx_tf.backend import prepare

        logger.info(f"加载ONNX模型: {onnx_model_path}")

        onnx_model = onnx.load(onnx_model_path)

3. 获取输入形状

        input_shape = None

        for input_info in onnx_model.graph.input:

            if input_info.name == 'input':

                input_shape = [dim.dim_value for dim in input_info.type.tensor_type.shape.dim]

                break

        if input_shape is None:

            raise ValueError("无法获取模型输入形状")

4. 转化为TensorFlow

  logger.info("转换为TensorFlow模型...")

        tf_rep = prepare(onnx_model)

        tf_rep.export_graph(tf_model_path)

5. 设置转化TFLite模型变量

  logger.info("转换为TFLite模型...")

        converter = tf.lite.TFLiteConverter.from_saved_model(tf_model_path)

6. 设置优化选项

  if enable_optimization:

            logger.info("启用模型优化...")

            # 启用所有优化

            converter.optimizations = [

                tf.lite.Optimize.DEFAULT,

                tf.lite.Optimize.OPTIMIZE_FOR_LATENCY,

                tf.lite.Optimize.OPTIMIZE_FOR_SIZE

            ]

            # 启用XNNPACK委托

            converter.target_spec.supported_ops = [

                tf.lite.OpsSet.TFLITE_BUILTINS,

                tf.lite.OpsSet.SELECT_TF_OPS

            ]

            converter.target_spec.supported_types = [tf.float32]

            # 启用XNNPACK

            converter.target_spec.supported_ops.append(tf.lite.OpsSet.TFLITE_BUILTINS_INT8)

            converter.experimental_new_quantizer = True

            converter.experimental_new_converter = True

7. 设置量化选项

  if enable_quantization:

            logger.info("启用模型量化...")

            # 设置量化参数

            converter.representative_dataset = create_representative_dataset(input_shape, x_path)

            # 使用混合量化

            converter.target_spec.supported_ops = [

                tf.lite.OpsSet.TFLITE_BUILTINS_INT8,

                tf.lite.OpsSet.SELECT_TF_OPS

            ]

            # 保持输入输出为float32

            converter.inference_input_type = tf.float32

            converter.inference_output_type = tf.float32

            # 设置量化参数

            converter.optimizations = [tf.lite.Optimize.DEFAULT]

            # 启用新的量化器

            converter.experimental_new_quantizer = True

            # 设置量化范围

            converter.quantized_input_stats = {

                'input': (0.0, 1.0)  # (mean, std)

            }

            # 启用混合量化

            converter.experimental_new_converter = True

            # 设置量化模式为混合量化

            converter.optimizations = [tf.lite.Optimize.DEFAULT]

            converter.target_spec.supported_ops = [

                tf.lite.OpsSet.TFLITE_BUILTINS_INT8,

                tf.lite.OpsSet.SELECT_TF_OPS

            ]

            # 设置量化参数

            converter.quantized_input_stats = {

                'input': (0.0, 1.0)  # (mean, std)

            }

            # 启用混合精度

            converter.target_spec.supported_types = [tf.float32, tf.int8]

8. 转换模型

  # 设置输入形状

        converter.input_shapes = {'input': [1, 6, 16]}  # 显式设置输入形状

        # 转换模型

        tflite_model = converter.convert()
        # 保存模型

        logger.info(f"保存TFLite模型到: {tflite_model_path}")

        with open(tflite_model_path, 'wb') as f:

            f.write(tflite_model)

9.转化为数组

  1. 下载vim并设置环境变量 https://www.vim.org/download.php

  2. 运行`xxd -i mnist.tflite model_data.h/.c

example

import os

import tensorflow as tf

import numpy as np

import logging

from datetime import datetime

import pandas as pd

from sklearn.metrics import mean_absolute_error

import math

import tensorflow_model_optimization as tfmot

  

# 设置环境变量以解决protobuf版本问题

os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'

  

# 配置日志

logging.basicConfig(

    level=logging.INFO,

    format='%(asctime)s - %(levelname)s - %(message)s'

)

logger = logging.getLogger(__name__)

  

def read_csv_data(csv_path):

    """读取CSV文件的最后6行数据"""

    try:

        # 检查文件是否存在

        if not os.path.exists(csv_path):

            logger.error(f"文件不存在: {csv_path}")

            return None

        # 检查文件是否为空

        if os.path.getsize(csv_path) == 0:

            logger.error("CSV文件为空")

            return None

        # 尝试读取CSV文件,跳过第一行

        try:

            # 指定数据类型为float,跳过第一行

            df = pd.read_csv(csv_path, header=0, sep=',', dtype=float)

        except pd.errors.EmptyDataError:

            logger.error("CSV文件为空")

            return None

        except Exception as e:

            logger.error(f"读取CSV文件时出错: {e}")

            return None

        # 检查是否有数据

        if df.empty:

            logger.error("CSV文件没有数据")

            return None

        # 检查数据维度

        if df.shape[1] != 16:  # 确保有16列

            logger.error(f"CSV文件列数不正确,期望16列,实际{df.shape[1]}列")

            return None

        # 获取最后6行数据并转换为numpy数组

        last_six_rows = df.iloc[-6:].values.astype(np.float32)

        # 检查数据是否有效

        if np.isnan(last_six_rows).any():

            logger.error("数据包含NaN值")

            return None

        return last_six_rows

    except Exception as e:

        logger.error(f"处理CSV文件时出错: {e}")

        return None

  

def create_representative_dataset(input_shape, x_path):

    """

    使用实际数据创建用于量化的代表性数据集

    Args:

        input_shape: 输入形状

        x_path: 输入数据集路径 (.npy文件)

    """

    def representative_dataset():

        try:

            # 加载数据集

            X = np.load(x_path, allow_pickle=True)

            logger.info(f"加载量化数据集,形状: {X.shape}")

            # 使用所有数据进行量化

            for i in range(len(X)):

                data = X[i:i+1].astype(np.float32)  # 保持形状为(1, 6, 16)

                yield [data]

        except Exception as e:

            logger.error(f"创建代表性数据集时出错: {str(e)}")

            # 如果出错,使用随机数据

            logger.warning("使用随机数据作为备选")

            for _ in range(100):

                data = np.random.random((1, 6, 16)).astype(np.float32)

                yield [data]

    return representative_dataset

  

def apply_pruning_to_model(model, pruning_params):

    """

    对模型应用剪枝

    Args:

        model: TensorFlow模型

        pruning_params: 剪枝参数

    """

    try:

        # 定义剪枝参数

        pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(

            initial_sparsity=pruning_params['initial_sparsity'],

            final_sparsity=pruning_params['final_sparsity'],

            begin_step=pruning_params['begin_step'],

            end_step=pruning_params['end_step']

        )

        # 应用剪枝

        pruning_model = tfmot.sparsity.keras.prune_low_magnitude(

            model,

            pruning_schedule=pruning_schedule,

            block_size=pruning_params['block_size'],

            block_pooling_type=pruning_params['block_pooling_type']

        )

        return pruning_model

    except Exception as e:

        logger.error(f"应用剪枝时出错: {str(e)}")

        return model

  

def convert_onnx_to_tflite(onnx_model_path, tf_model_path, tflite_model_path,

                          x_path=None, enable_quantization=True, enable_optimization=True):

    try:

        logger.info("开始转换模型...")

        # 加载ONNX模型

        import onnx

        from onnx_tf.backend import prepare

        logger.info(f"加载ONNX模型: {onnx_model_path}")

        onnx_model = onnx.load(onnx_model_path)

        # 获取输入形状

        input_shape = None

        for input_info in onnx_model.graph.input:

            if input_info.name == 'input':

                input_shape = [dim.dim_value for dim in input_info.type.tensor_type.shape.dim]

                break

        if input_shape is None:

            raise ValueError("无法获取模型输入形状")

        logger.info(f"模型输入形状: {input_shape}")

        # 转换为TensorFlow模型

        logger.info("转换为TensorFlow模型...")

        tf_rep = prepare(onnx_model)

        tf_rep.export_graph(tf_model_path)

        # 转换为TFLite模型

        logger.info("转换为TFLite模型...")

        converter = tf.lite.TFLiteConverter.from_saved_model(tf_model_path)

        # 设置优化选项

        if enable_optimization:

            logger.info("启用模型优化...")

            # 启用所有优化

            converter.optimizations = [

                tf.lite.Optimize.DEFAULT,

                tf.lite.Optimize.OPTIMIZE_FOR_LATENCY,

                tf.lite.Optimize.OPTIMIZE_FOR_SIZE

            ]

            # 启用XNNPACK委托

            converter.target_spec.supported_ops = [

                tf.lite.OpsSet.TFLITE_BUILTINS,

                tf.lite.OpsSet.SELECT_TF_OPS

            ]

            converter.target_spec.supported_types = [tf.float32]

            # 启用XNNPACK

            converter.target_spec.supported_ops.append(tf.lite.OpsSet.TFLITE_BUILTINS_INT8)

            converter.experimental_new_quantizer = True

            converter.experimental_new_converter = True

        # 设置量化选项

        if enable_quantization:

            logger.info("启用模型量化...")

            # 设置量化参数

            converter.representative_dataset = create_representative_dataset(input_shape, x_path)

            # 使用混合量化

            converter.target_spec.supported_ops = [

                tf.lite.OpsSet.TFLITE_BUILTINS_INT8,

                tf.lite.OpsSet.SELECT_TF_OPS

            ]

            # 保持输入输出为float32

            converter.inference_input_type = tf.float32

            converter.inference_output_type = tf.float32

            # 设置量化参数

            converter.optimizations = [tf.lite.Optimize.DEFAULT]

            # 启用新的量化器

            converter.experimental_new_quantizer = True

            # 设置量化范围

            converter.quantized_input_stats = {

                'input': (0.0, 1.0)  # (mean, std)

            }

            # 启用混合量化

            converter.experimental_new_converter = True

            # 设置量化模式为混合量化

            converter.optimizations = [tf.lite.Optimize.DEFAULT]

            converter.target_spec.supported_ops = [

                tf.lite.OpsSet.TFLITE_BUILTINS_INT8,

                tf.lite.OpsSet.SELECT_TF_OPS

            ]

            # 设置量化参数

            converter.quantized_input_stats = {

                'input': (0.0, 1.0)  # (mean, std)

            }

            # 启用混合精度

            converter.target_spec.supported_types = [tf.float32, tf.int8]

            # 添加权重剪枝

            converter.optimizations = [tf.lite.Optimize.DEFAULT]

            converter.target_spec.supported_ops = [

                tf.lite.OpsSet.TFLITE_BUILTINS_INT8,

                tf.lite.OpsSet.SELECT_TF_OPS

            ]

            # 设置权重剪枝参数

            converter.optimizations = [tf.lite.Optimize.DEFAULT]

            converter.target_spec.supported_ops = [

                tf.lite.OpsSet.TFLITE_BUILTINS_INT8,

                tf.lite.OpsSet.SELECT_TF_OPS

            ]

            # 启用权重剪枝

            converter.optimizations = [tf.lite.Optimize.DEFAULT]

            converter.target_spec.supported_ops = [

                tf.lite.OpsSet.TFLITE_BUILTINS_INT8,

                tf.lite.OpsSet.SELECT_TF_OPS

            ]

        # 设置输入形状

        converter.input_shapes = {'input': [1, 6, 16]}  # 显式设置输入形状

        # 转换模型

        tflite_model = converter.convert()

        # 保存模型

        logger.info(f"保存TFLite模型到: {tflite_model_path}")

        with open(tflite_model_path, 'wb') as f:

            f.write(tflite_model)

        # 获取模型大小

        model_size = os.path.getsize(tflite_model_path) / (1024 * 1024)  # 转换为MB

        logger.info(f"模型大小: {model_size:.2f} MB")

        # 验证模型

        logger.info("验证转换后的模型...")

        interpreter = tf.lite.Interpreter(model_content=tflite_model)

        interpreter.allocate_tensors()

        # 获取输入输出细节

        input_details = interpreter.get_input_details()

        output_details = interpreter.get_output_details()

        logger.info("\n模型信息:")

        logger.info(f"输入张量: {input_details[0]['name']}")

        logger.info(f"输入形状: {input_details[0]['shape']}")

        logger.info(f"输入类型: {input_details[0]['dtype']}")

        logger.info(f"输出张量: {output_details[0]['name']}")

        logger.info(f"输出形状: {output_details[0]['shape']}")

        logger.info(f"输出类型: {output_details[0]['dtype']}")

        logger.info("模型转换完成!")

        return True

    except Exception as e:

        logger.error(f"转换过程中出错: {str(e)}")

        return False

  

def test_model_performance(tflite_model_path, x_path, y_path):

    """

    测试TFLite模型的性能,计算MAE和RMSE

    Args:

        tflite_model_path: TFLite模型路径

        x_path: 输入数据集路径 (.npy文件)

        y_path: 输出数据集路径 (.npy文件)

    """

    try:

        logger.info("开始测试模型性能...")

        # 加载模型

        interpreter = tf.lite.Interpreter(model_path=tflite_model_path)

        interpreter.allocate_tensors()

        # 获取输入输出细节

        input_details = interpreter.get_input_details()

        output_details = interpreter.get_output_details()

        # 加载数据集

        logger.info(f"加载输入数据集: {x_path}")

        X = np.load(x_path, allow_pickle=True)

        logger.info(f"加载输出数据集: {y_path}")

        y_data = np.load(y_path, allow_pickle=True)

        logger.info(f"输入数据集形状: {X.shape}")

        logger.info(f"输出数据集形状: {y_data.shape}")

        # 处理输出数据

        if isinstance(y_data[0], dict):

            logger.info("检测到字典格式的输出数据,进行转换...")

            logger.info(f"字典键: {list(y_data[0].keys())}")

            # 处理R_Shoulder和R_Elbow数据

            y_true = []

            for sample in y_data:

                # 获取R_Shoulder和R_Elbow的值

                shoulder_data = sample['R_Shoulder']

                elbow_data = sample['R_Elbow']

                # 确保数据是数值类型

                if isinstance(shoulder_data, (list, tuple)):

                    shoulder_data = np.array(shoulder_data)

                if isinstance(elbow_data, (list, tuple)):

                    elbow_data = np.array(elbow_data)

                # 合并数据

                combined_data = np.concatenate([shoulder_data, elbow_data])

                y_true.append(combined_data)

            y_true = np.array(y_true)

            logger.info(f"处理后的输出数据形状: {y_true.shape}")

        else:

            y_true = y_data

        # 限制样本数为1000

        test_size = min(1000, len(X))

        logger.info(f"使用前{test_size}个样本进行测试")

        # 存储所有预测结果

        y_pred = []

        # 对每个样本进行预测

        for i in range(test_size):

            # 准备输入数据

            input_data = X[i:i+1].astype(np.float32)  # 保持形状为(1, 6, 16)

            # 检查所有数据是否都在(-0.1, 0.1)范围内

            if np.all(np.abs(input_data) < 0.1):

                logger.warning(f"样本 {i} 的所有数据都在(-0.1, 0.1)范围内,跳过预测")

                continue

            # 设置输入张量

            interpreter.set_tensor(input_details[0]['index'], input_data)

            # 运行推理

            interpreter.invoke()

            # 获取输出

            output_data = interpreter.get_tensor(output_details[0]['index'])

            y_pred.append(output_data.flatten())

        # 转换为numpy数组

        y_pred = np.array(y_pred)

        y_true = y_true[:len(y_pred)]  # 只使用对应的真实值

        # 确保预测结果和真实值的形状匹配

        if y_pred.shape != y_true.shape:

            logger.warning(f"预测结果形状 {y_pred.shape} 与真实值形状 {y_true.shape} 不匹配,进行调整...")

            y_pred = y_pred.reshape(y_true.shape)

        # 计算MAE和RMSE

        mae = mean_absolute_error(y_true, y_pred)

        rmse = math.sqrt(np.mean((y_true - y_pred) ** 2))

        # 计算每个样本的误差

        sample_errors = np.mean(np.abs(y_true - y_pred), axis=1)

        # 计算每个骨骼的误差

        bone_errors = np.mean(np.abs(y_true - y_pred), axis=0)

        logger.info("\n模型性能测试结果:")

        logger.info(f"测试样本数: {len(y_pred)}")

        logger.info(f"总体MAE: {mae:.6f}")

        logger.info(f"总体RMSE: {rmse:.6f}")

        logger.info(f"最小样本MAE: {np.min(sample_errors):.6f}")

        logger.info(f"最大样本MAE: {np.max(sample_errors):.6f}")

        logger.info(f"平均样本MAE: {np.mean(sample_errors):.6f}")

        logger.info(f"样本MAE标准差: {np.std(sample_errors):.6f}")

        # 输出每个骨骼的误差

        logger.info("\n每个骨骼的MAE:")

        bone_names = ['R_Shoulder_X', 'R_Shoulder_Y', 'R_Shoulder_Z',

                     'R_Elbow_X', 'R_Elbow_Y', 'R_Elbow_Z']

        for i, (name, error) in enumerate(zip(bone_names, bone_errors)):

            logger.info(f"{name}: {error:.6f}")

        return mae, rmse

    except Exception as e:

        logger.error(f"测试模型性能时出错: {str(e)}")

        return None, None

  

def main():

    # 设置模型路径

    onnx_model_path = "sliding_cnn_model.onnx"

    tf_model_path = "sliding_cnn_model"

    tflite_model_path = "sliding_cnn_model.tflite"

    x_path = r"D:/user/Work/graduate_student/research/software/DL/V1/data/X.npy"

    y_path = r"D:/user/Work/graduate_student/research/software/DL/V1/data/y.npy"

    # 检查ONNX模型是否存在

    if not os.path.exists(onnx_model_path):

        logger.error(f"找不到ONNX模型文件: {onnx_model_path}")

        return

    # 转换模型

    success = convert_onnx_to_tflite(

        onnx_model_path=onnx_model_path,

        tf_model_path=tf_model_path,

        tflite_model_path=tflite_model_path,

        x_path=x_path,  # 使用numpy数据集进行量化

        enable_quantization=True,

        enable_optimization=True

    )

    if success:

        logger.info("模型转换成功!")

        # 测试模型性能

        mae, rmse = test_model_performance(tflite_model_path, x_path, y_path)

        if mae is not None and rmse is not None:

            logger.info("模型性能测试完成!")

    else:

        logger.error("模型转换失败!")

  

if __name__ == "__main__":

    main()

03.在PlatformIO平台给esp32部署TFlite模型

Environment Install

  1. 点击libraries

 

  1. 安装TensorFlowLite_ESP32

 

  1. 添加到工程

 

Code

定义全局变量并分配内存

#define MODEL_INPUT_SIZE 6  // 模型输入窗口大小

#define MODEL_OUTPUT_SIZE 6 // 模型输出大小

  

// TFLite全局变量

const tflite::Model* model = nullptr;

tflite::MicroInterpreter* interpreter = nullptr;

TfLiteTensor* input = nullptr;

TfLiteTensor* output = nullptr;

// 创建错误报告器

static tflite::MicroErrorReporter micro_error_reporter;

tflite::ErrorReporter* error_reporter = &micro_error_reporter;

  

// 为TFLite分配内存 - 减少内存大小

constexpr int kTensorArenaSize = 20 * 1024;  // 减少到50KB

uint8_t tensor_arena[kTensorArenaSize] __attribute__((aligned(16)));  // 16字节对齐

  

// 数据缓存 - 使用较小的缓冲区

float inputBuffer[MODEL_INPUT_SIZE][ADC_CHANNEL_NUM] = {0};

float outputBuffer[MODEL_OUTPUT_SIZE] = {0};

setup

加载模型

 model = tflite::GetModel(sliding_cnn_model_tflite);

    if (model->version() != TFLITE_SCHEMA_VERSION) {

        TF_LITE_REPORT_ERROR(error_reporter, "Model schema mismatch!");

        return false;

    }

创建并配置变量参数

 // 创建操作解析器 - 只添加必要的操作

    static tflite::MicroMutableOpResolver<1> resolver;

  

  

    // 创建解释器

    static tflite::MicroInterpreter static_interpreter(

        model, resolver, tensor_arena, kTensorArenaSize, error_reporter);

    interpreter = &static_interpreter;

  

    // 分配张量

    if (interpreter->AllocateTensors() != kTfLiteOk) {

        TF_LITE_REPORT_ERROR(error_reporter, "AllocateTensors() failed");

        return false;

    }

  

    // 获取输入输出张量

    input = interpreter->input(0);

    output = interpreter->output(0);

loop

执行推理

 // 执行推理

    if (interpreter->Invoke() != kTfLiteOk) {

        TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed!");

        return;

    }

获取输出

for (int i = 0; i < MODEL_OUTPUT_SIZE; i++) {

        outputBuffer[i] = output->data.f[i];

    }

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值