CNN训练MNIST数据集tenflow2(中)
紧接上文(上),我们已经做好了训练,分别保存了32bit、16bit、8bit的量化模型。本篇的工作是加载模型,查看增加噪声、权重量化时的精度变化。数据流量化的结果将在下篇介绍。
一、加载数据集
# 1. 加载数据集
import tensorflow as tf
import numpy as np
minst = tf.keras.datasets.mnist
img_rows,img_cols = 28,28
(x_train, y_train), (x_test, y_test) = minst.load_data()
x_train = x_train.reshape(x_train.shape[0],img_rows,img_cols,1)
x_test = x_test.reshape(x_test.shape[0],img_rows,img_cols,1)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train = x_train / 255
x_test = x_test / 255
y_train_onehot = tf.keras.utils.to_categorical(y_train)
y_test_onehot = tf.keras.utils.to_categorical(y_test)
二、加载模型和权重,测试准确率
# 2. 加载模型和权重,测试准确率
model = tf.keras.Sequential()
model = tf.keras.models.load_model('models/mnist_tf2_fw.h5')
score = model.evaluate(x_test, y_test_onehot, verbose=0)
print('Test accuracy:', "{:.5f}".format(score[1]))
三、增加噪声,测试准确率
model = tf.keras.models.load_model('models/mnist_tf2_fw.h5')
# 这里0,2,5层分别表示conv1,conv2,fc层的权重
for i in [0,2,5]:
# 通过get_weights 和setweights修改权重
temp = model.layers[i].get_weights()[0]
# 0.15表征噪声的标准差
temp = temp + np.random.rand(*temp.shape)*0.15
model.layers[i].set_weights([temp])
score = model.evaluate(x_test, y_test_onehot, verbose=0)
print('Test accuracy:', "{:.5f}".format(score[1]))
四、权重量化,测试准确率
# 8bit量化
bit = 8
model = tf.keras.models.load_model('models/mnist_tf2_fw.h5')
# 这里0,2,5层分别表示conv1,conv2,fc层的权重
for i in [0,2,5]:
temp = model.layers[i].get_weights()[0]
temp = np.rint(temp * np.power(2,bit))/np.power(2,bit)
model.layers[i].set_weights([temp])
score = model.evaluate(x_test, y_test_onehot, verbose=0)
print('Test accuracy:', "{:.5f}".format(score[1]))
统计结果如下:
噪声标准差 | 0 | 0.04 | 0.06 | 0.08 | 0.09 | 0.10 | 0.16 |
---|---|---|---|---|---|---|---|
准确率 | 98.500 | 96.870 | 92.590 | 87.510 | 78.590 | 70.180 | 52.800 |
量化精度 | 32bit | 16bit | 8bit | 4bit | 3bit | 2bit | 1bit |
---|---|---|---|---|---|---|---|
准确率 | 98.500 | 98.500 | 98.490 | 98.410 | 97.380 | 89.400 | 66.250 |
特别说明的是:这里只是对权重做量化,推理计算过程中的数据还是32bit的浮点数,所以结果上看,3bit量化的精度下降不多,只有约1%。
有意思的一点是:3bit量化的准确率和0.04噪声接近;2bit量化准确率和0.08噪声相近,1bit量化的准确率和0.16的噪声相当。
分析原因:3bit量化相当于最小精度是0.125,可以将其等效为$3\sigma$的偏差,求得标准差为0.041,即为0.04的噪声。