基于AM67A的嵌入式语音识别功能实现(二):模型训练及转换onnx格式

1. 挑选模型

我从GitHub上直接测试了多个模型,最终选用了效果较好的MASR中文语音识别项目,配合AISHELL-1数据集进行部署,主要目标是确保系统能够成功运行。

2. 训练及验证

由于训练过程对电脑配置要求较高,考虑到我的GTX 1060 6G显卡性能有限,完整数据集的训练周期需要一个月。为了提高效率,我将数据集规模缩减至前7000条样本。

# mini_dataset.py
def miniData(input, output):
    i = 0
    with open(input, 'r', encoding='utf-8') as f:
        lines = f.readlines()
        for line in lines:
            i += 1
            if(i > 7000): break
            data.append(line.strip())

    with open(output,'w' , encoding='utf-8') as out:
        for item in data:
            out.write(item + '\n')

data = []
input = 'dev.manifest'
output = 'dev-mini.manifest'
miniData(input, output)

训练参数设置为:

	# train.py
	epochs=1000,
	batch_size=64,
	train_index_path="data_asishell/train-sort-mini.manifest",
	dev_index_path="data_asishell/dev-mini.manifest",
	labels_path="data_asishell/labels.json",
	learning_rate=0.6,
	momentum=0.8,
	max_grad_norm=0.2,
	weight_decay=0,

笔记本连续运行了24小时后,考虑到设备负荷问题停止了训练。最终模型准确率达到70%,作为流程验证的基准模型已满足需求。待后续流程完善后,将进一步优化训练效果。通过测试集的特定样本进行验证:

# test.py
import torch
import sys
import json
import data
from decoder import GreedyDecoder

# 加载标签表
with open("data_asishell/labels.json", encoding="utf-8") as f:
    labels = json.load(f)

# 加载模型(模型是直接保存的 model 对象)
model = torch.load("pretrained/model_18.pth")
model.eval()
model.to("cuda")

# 音频路径
wav_path = "test1.wav" if len(sys.argv) <= 1 else sys.argv[1]

# 加载音频并提取特征
waveform = data.load_audio(wav_path)
spec = data.spectrogram(waveform).unsqueeze(0).to("cuda")  # 加上 batch 维度
print(f"模型输入尺寸:{spec.shape}")  # 例:[1, 80, 215]
input_lengths = torch.IntTensor([spec.size(2)])  # 长度信息

# 推理
with torch.no_grad():
    output, output_lengths = model(spec, input_lengths)
    output = torch.softmax(output, dim=1).transpose(1, 2)  # (B, T, C)

# 解码
decoder = GreedyDecoder(labels)
decoded, _ = decoder.decode(output, output_lengths)

# 打印结果
#print("\n识别结果:")
#print(decoded[0][0])
#print(f"识别结果:{decoded[0][0]}\n真实结果:甚至出现交易几乎停滞的情况")
print(f"识别结果:{decoded[0][0]}\n真实结果:这令被贷款的员工们寝食难安")

最后识别结果为:

识别结果:这零被贷款的员部们请示男案
真实结果:这令被贷款的员工们寝食难安

3. 转换为.onnx格式

# export.py
import torch
import json
from models.conv import GatedConv

# 加载标签,构建模型
with open("data_asishell/labels.json", encoding="utf-8") as f:
    labels = json.load(f)
    vocab_str = "".join(labels)

# 初始化模型并加载权重
model = GatedConv(vocab_str)
model.load_state_dict(torch.load("pretrained/model_18.pth", map_location="cpu").state_dict())
model.eval()

# 构造 dummy 输入(根据你实际打印出来的是 [1, 161, 421])
dummy_spec = torch.randn(1, 161, 421)  # [batch, n_mels, time]
dummy_input_lengths = torch.IntTensor([421])  # 每条语音帧长度

# 导出 ONNX
torch.onnx.export(
    model,
    (dummy_spec, dummy_input_lengths),
    "asr_model.onnx",
    input_names=["spec", "input_lengths"],
    output_names=["logits", "output_lengths"],
    dynamic_axes={
        "spec": {0: "batch", 2: "time"},           # 支持动态 batch 和 time
        "input_lengths": {0: "batch"},
        "logits": {0: "batch", 2: "time"},
        "output_lengths": {0: "batch"}
    },
    opset_version=11,
    do_constant_folding=True
)

print("✅ 成功导出 ONNX 模型:asr_model.onnx")

需要注意输入输出的维度,有的自然语言模型time是可变的

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值