1 Principle
In deep learning, PTQ (Post-Training Quantization) is a model compression and acceleration technique. Its core idea is to convert a model's high-precision parameters (e.g., 32-bit floating point, FP32) into a low-precision format (e.g., 8-bit integer INT8, or 16-bit floating point FP16) after training has finished, in order to shrink the model, reduce memory usage, and speed up inference, while preserving as much accuracy as possible. Because it requires no retraining, development cost is low and turnaround is short, which makes it one of the core techniques for deploying deep learning models in industry, especially on edge devices and mobile.
Basically, it means quantizing the model after training is done, which skips repeating the most expensive step of all: training the model.
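In other words, each FP32 value is mapped onto one of 256 integer levels through a scale factor. A minimal sketch of the idea (the concrete numbers here are made up; section 2 walks through the real thing):
import numpy as np

x = np.float32(0.7314)            # an fp32 weight
scale = np.float32(2.0 / 256)     # assuming the weights live in [-1, 1)
q = np.int8(np.rint(x / scale))   # 8-bit representation: 94
x_hat = np.float32(q) * scale     # reconstructed fp32 value: ~0.7344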
Below is a comparison of baseline models against their PTQ versions; in every case the accuracy drop is modest.
Besides PTQ there is also QAT (Quantization-Aware Training), which simulates the quantization error during training and updates the parameters accordingly. Its accuracy loss is smaller, and it can fully adapt to different hardware. The downside is cost: the whole model has to be retrained, so the turnaround is long.
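For reference, a QAT pass in Keras typically goes through the tensorflow_model_optimization package. A minimal sketch, assuming a trained Keras model and training data like those built in section 2.1 below:
import tensorflow_model_optimization as tfmot
from tensorflow import keras

# Wrap the model with fake-quantization nodes, then fine-tune so the
# weights learn to compensate for the quantization error.
q_aware_model = tfmot.quantization.keras.quantize_model(model)
q_aware_model.compile(optimizer='adam',
                      loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      metrics=['accuracy'])
q_aware_model.fit(train_images, train_labels, epochs=1)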
Here is a comparison of baseline, PTQ, and QAT. QAT does perform a bit better, but only a bit.
After that, TensorFlow and TensorFlow Lite are compared. In essence, TensorFlow can be deployed on a cluster and retrained at any time; it is an iterative process. TensorFlow Lite, by contrast, is really just an application: its parameters are fixed and cannot be adjusted, and it is very sensitive to size.
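This difference also shows up in the API: a .tflite file is a frozen artifact that can only be fed inputs and asked for outputs. A minimal inference sketch (the model path assumes the MNIST model saved in section 2.1):
import numpy as np
import tensorflow as tf

# Load the frozen .tflite model; there is no training and no parameter update.
interpreter = tf.lite.Interpreter(model_path="/content/mnist_tflite_models/mnist_model.tflite")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]

# Run a single all-zeros image through the network.
dummy = np.zeros(input_details['shape'], dtype=input_details['dtype'])
interpreter.set_tensor(input_details['index'], dummy)
interpreter.invoke()
logits = interpreter.get_tensor(output_details['index'])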
2 Working Through the Code
2.1 Simulating and Performing Quantization
The consolidated code:
# Imports for the NumPy demo and for TensorFlow Lite.
# Raise TensorFlow's log level before importing it.
import logging
logging.getLogger("tensorflow").setLevel(logging.DEBUG)

import pathlib
import pprint
import re
import sys

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras
weights = np.random.randn(256, 256)

def quantizeAndReconstruct(weights):
    """
    @param weights: np.ndarray
    Computes the scale value needed to map fp32 values to int8, and returns
    a weight matrix in fp32 that is representable using only 8 bits.
    """
    # Compute the range of the weights.
    max_weight = np.max(weights)
    min_weight = np.min(weights)
    weight_range = max_weight - min_weight
    max_int8 = 2**8  # 8 bits give 256 representable levels
    # Compute the scale: the fp32 width covered by one integer step.
    scale = weight_range / max_int8
    # Compute the midpoint of the range, used to center the values.
    midpoint = np.mean([max_weight, min_weight])
    # Next, map the real fp32 values to integers using the computed scale.
    # Dividing the centered weight matrix by the scale squeezes its range
    # into (-128, 128); then we simply round to the closest integers.
    centered_weights = weights - midpoint
    quantized_weights = np.rint(centered_weights / scale)
    # Now reconstruct (dequantize) the values back to fp32.
    reconstructed_weights = scale * quantized_weights + midpoint
    return reconstructed_weights

reconstructed_weights = quantizeAndReconstruct(weights)
print("Original weight matrix\n", weights)
print("Weight matrix after reconstruction\n", reconstructed_weights)

errors = reconstructed_weights - weights
max_error = np.max(np.abs(errors))  # largest absolute reconstruction error
print("Max Error : ", max_error)
reconstructed_weights.shape  # shape is unchanged: (256, 256)

# np.unique counts the distinct floating point values left in the matrix;
# after 8-bit quantization there can be at most 257 of them.
np.unique(quantizeAndReconstruct(weights)).shape
# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture
model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(28, 28)),
    keras.layers.Reshape(target_shape=(28, 28, 1)),
    keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation=tf.nn.relu),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.fit(
    train_images,
    train_labels,
    epochs=1,
    validation_data=(test_images, test_labels)
)

# Convert to a float TFLite model and save it.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
tflite_models_dir = pathlib.Path("/content/mnist_tflite_models/")
tflite_models_dir.mkdir(exist_ok=True, parents=True)
tflite_model_file = tflite_models_dir / "mnist_model.tflite"
tflite_model_file.write_bytes(tflite_model)

# Convert the model using DEFAULT optimizations: https://github.com/tensorflow/tensorflow/blob/v2.4.1/tensorflow/lite/python/lite.py#L91-L130
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()
tflite_model_quant_file = tflite_models_dir / "mnist_model_quant.tflite"
tflite_model_quant_file.write_bytes(tflite_quant_model)

!ls -lh /content/mnist_tflite_models/
Let's go over the important parts.
def quantizeAndReconstruct(weights) demonstrates the principle of quantization.
Its core is quantized_weights = np.rint(centered_weights / scale): this is the call that maps the floating-point values onto integers.
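A quick sanity check on why the reported max error stays small: rounding to the nearest integer moves a value by at most half a step, so the reconstruction error is bounded by scale / 2. A sketch, recomputing the scale the same way the function does:
# np.rint is off by at most 0.5 integer steps, so after multiplying back
# by the scale, each reconstructed value is within scale / 2 of the original.
scale = (np.max(weights) - np.min(weights)) / 2**8
assert np.max(np.abs(reconstructed_weights - weights)) <= scale / 2 + 1e-12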
The rest of the code trains a simple MNIST CNN, converts it to TFLite format, and saves two versions: a float model (mnist_model.tflite), and then a quantized model produced with the DEFAULT optimization (mnist_model_quant.tflite; TFLite's default optimization quantizes the weights, while the activations are reportedly left untouched here):
# Convert the model using DEFAULT optimizations: https://github.com/tensorflow/tensorflow/blob/v2.4.1/tensorflow/lite/python/lite.py#L91-L130
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()
tflite_model_quant_file = tflite_models_dir / "mnist_model_quant.tflite"
tflite_model_quant_file.write_bytes(tflite_quant_model)
Here is the optimized model; its size shrinks to roughly 1/4 of the original:
total 540K
-rw-r--r-- 1 root root 112K Aug 31 13:23 mnist_model_quant.tflite
-rw-r--r-- 1 root root 428K Aug 31 13:23 mnist_model.tflite
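As noted above, the DEFAULT path only quantizes the weights (dynamic-range quantization). To quantize activations as well, the converter needs a representative dataset to calibrate their ranges. A hedged sketch (the generator name and the 100-sample count are arbitrary choices):
def representative_dataset():
    # Feed a few real inputs so the converter can observe activation ranges.
    for image in train_images[:100]:
        yield [np.expand_dims(image, axis=0).astype(np.float32)]

converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
tflite_full_int_model = converter.convert()
(tflite_models_dir / "mnist_model_full_int.tflite").write_bytes(tflite_full_int_model)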
2.2 Analyzing the Quantized Model
The code:
# Build and install the Flatbuffer compiler.
!apt-get install -y cmake
%cd /content/
!rm -rf flatbuffers*
!curl -L "https://github.com/google/flatbuffers/archive/v1.12.0.zip" -o flatbuffers.zip
!unzip -q flatbuffers.zip
!mv flatbuffers-1.12.0 flatbuffers
%cd flatbuffers
!cmake -G "Unix Makefiles" -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_FLAGS="-Wno-error=class-memaccess" .
!make -j 8
!cp flatc /usr/local/bin/

# Generate Python parsing code from the TFLite schema.
%cd /content/
!rm -rf tensorflow
!git clone --depth 1 https://github.com/tensorflow/tensorflow
!flatc --python --gen-object-api tensorflow/tensorflow/compiler/mlir/lite/schema/schema_v3.fbs
# Load the generated Model.py as a module.
import importlib.util
spec = importlib.util.spec_from_file_location("Model", "/content/tflite/Model.py")
Model = importlib.util.module_from_spec(spec)
spec.loader.exec_module(Model)

def CamelCaseToSnakeCase(camel_case_input):
    """Converts an identifier in CamelCase to snake_case."""
    s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_input)
    return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()

def FlatbufferToDict(fb, attribute_name=None):
    """Converts a hierarchy of FB objects into a nested dict."""
    if hasattr(fb, "__dict__"):
        result = {}
        for attribute_name in dir(fb):
            attribute = fb.__getattribute__(attribute_name)
            if not callable(attribute) and attribute_name[0] != "_":
                snake_name = CamelCaseToSnakeCase(attribute_name)
                result[snake_name] = FlatbufferToDict(attribute, snake_name)
        return result
    elif isinstance(fb, str):
        return fb
    elif attribute_name == "name" and fb is not None:
        # "name" fields are stored as arrays of byte values; decode to str.
        result = ""
        for entry in fb:
            result += chr(FlatbufferToDict(entry))
        return result
    elif hasattr(fb, "__len__"):
        result = []
        for entry in fb:
            result.append(FlatbufferToDict(entry))
        return result
    else:
        return fb

def CreateDictFromFlatbuffer(buffer_data):
    model_obj = Model.Model.GetRootAsModel(buffer_data, 0)
    model = Model.ModelT.InitFromObj(model_obj)
    return FlatbufferToDict(model)
MODEL_ARCHIVE_NAME = 'inception_v3_2015_2017_11_10.zip'
MODEL_ARCHIVE_URL = 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/' + MODEL_ARCHIVE_NAME
MODEL_FILE_NAME = 'inceptionv3_non_slim_2015.tflite'
!curl -o {MODEL_ARCHIVE_NAME} {MODEL_ARCHIVE_URL}
!unzip {MODEL_ARCHIVE_NAME}

with open(MODEL_FILE_NAME, 'rb') as file:
    model_data = file.read()

model_dict = CreateDictFromFlatbuffer(model_data)
pprint.pprint(model_dict['subgraphs'][0]['tensors'])

param_bytes = bytearray(model_dict['buffers'][212]['data'])
params = np.frombuffer(param_bytes, dtype=np.float32)
params.min()
params.max()

plt.figure(figsize=(8,8))
plt.hist(params, 100)
First, Flatbuffers is downloaded; it is the tool used to parse the tflite model. The build again uses CMake; for background on CMake, see the earlier post: CMake小结_cmake 数组-CSDN博客
Tool download: https://github.com/google/flatbuffers/archive/v1.12.0.zip
Then the deserialization code is generated from the TensorFlow schema, with the command:
!flatc --python --gen-object-api tensorflow/tensorflow/compiler/mlir/lite/schema/schema_v3.fbs
Next, Google's InceptionV3 model is downloaded. The earlier MNIST model is not reused here because it is too simple.
MODEL_ARCHIVE_NAME = 'inception_v3_2015_2017_11_10.zip'
MODEL_ARCHIVE_URL = 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/' + MODEL_ARCHIVE_NAME
MODEL_FILE_NAME = 'inceptionv3_non_slim_2015.tflite'
Then comes the parsing itself:
model_dict = CreateDictFromFlatbuffer(model_data)
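The resulting dict mirrors the TFLite schema, so it can be walked directly. For example, a small sketch listing which buffer backs each tensor (the key names follow what FlatbufferToDict produces from schema_v3; treat the exact keys as assumptions):
# Print each tensor's name and the index of the buffer that holds its data.
for tensor in model_dict['subgraphs'][0]['tensors']:
    print(tensor['name'], '-> buffer', tensor['buffer'])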
At the end, the distribution of the weights is plotted. Almost all values cluster around 0; this tight concentration is also what lets quantization represent them so compactly.
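As a cross-check, the quantizeAndReconstruct function from section 2.1 can be run on this real weight buffer too (assuming it is still in scope):
# Quantize the InceptionV3 weight buffer and measure the damage.
reconstructed = quantizeAndReconstruct(params)
print("distinct values after quantization:", np.unique(reconstructed).shape[0])  # at most 257
print("max abs error:", np.max(np.abs(reconstructed - params)))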
Two other models were also examined the same way:
# Text Classification
!wget https://storage.googleapis.com/download.tensorflow.org/models/tflite/text_classification/text_classification_v2.tflite
# Pose Estimation
!wget https://storage.googleapis.com/download.tensorflow.org/models/tflite/posenet_mobilenet_v1_100_257x257_multi_kpt_stripped.tflite
The results look much the same.
text_classification: [weight histogram figure]
posenet_mobilenet: [weight histogram figure]
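To repeat the analysis on any .tflite file, the steps above can be wrapped in a small helper. A sketch, assuming CreateDictFromFlatbuffer is defined as above and that the largest data-carrying buffer is an fp32 weight tensor:
def plot_largest_weight_buffer(tflite_path):
    """Parse a tflite file and histogram its largest constant buffer."""
    with open(tflite_path, 'rb') as f:
        model_dict = CreateDictFromFlatbuffer(f.read())
    # Pick the biggest buffer that actually carries data.
    largest = max((b['data'] for b in model_dict['buffers'] if b['data'] is not None), key=len)
    params = np.frombuffer(bytearray(largest), dtype=np.float32)
    plt.figure(figsize=(8, 8))
    plt.hist(params, 100)

plot_largest_weight_buffer('text_classification_v2.tflite')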