最近用 torch 重构了一个语音关键词检测的 TensorFlow 项目,CPU 推理的时候使用以下关键代码将推理耗时从1s
降到0.2s
dummy = torch.randn(1, 8000)
model = torch.jit.trace(model, dummy)
这两行代码的作用是把原本的 Python-定义模型,转成一个静态的、优化过的 TorchScript 模型,从而减少每次调用时的 Python 解释开销,加速推理。具体来说:
-
dummy = torch.randn(1, 8000)
- 用途:构造一个“伪输入”(dummy input),与模型在实际推理时接收的数据形状完全一致。
- 含义:这里假设你的模型接受形状
[batch_size, time_steps]
的浮点张量,也就是一维信号长度是 8000。torch.randn(1, 8000)
会创建一个值服从标准正态分布的 Tensor,shape 为[1, 8000]
,用来驱动一次前向计算。
-
model = torch.jit.trace(model, dummy)
-
作用:用 TorchScript 的 Tracing 方式,把原本一连串 Python 操作(卷积、线性层、激活等)“录制”下来,编译成一个独立于 Python 的计算图(intermediate representation)。
-
好处:
- 减少解释器开销:每次
.forward()
不再跑 Python 代码,而是执行一个静态图; - 图优化:PyTorch 会对图做一些融合、内联等优化;
- 跨平台部署:生成的 TorchScript 模型可以在 C++、Java、移动端等环境下加载,无需 Python 解释器。
- 减少解释器开销:每次
-
一个比较好的举例:
- 第一行是“干一杯”——造一份跟真实输入同样形状的随机数据;
- 第二行是“放到高压锅里”——把你的 PyTorch 模型蒸成一个静态、优化过的 TorchScript 版本,加速并简化部署。