【深度学习】手写神经网络模型保存
文章目录开始保存读取模型具体如下使用完整代码参考pickle简单使用使用Python的shelve库。开始上篇写了如何手写神经网络,现在有如下三大需求:利用GPU加速,自动求导,保存读取模型。这里主要讲讲保存读取模型。GPU加速可以用 基于mxnet进行运算的minpy。保存读取模型这里使用一种较为简单的方法,直接保存训练后的网络对象到文件。学过Java的同学可能知道可以将对象实体固化到文件。这在
·
开始
上篇写了如何手写神经网络,现在有如下三大需求:利用GPU加速、自动求导、保存读取模型。这里主要讲讲保存读取模型。GPU加速可以用基于mxnet进行运算的minpy,或者是CuPy。
保存读取模型
这里使用一种较为简单的方法:直接把训练后的网络对象保存到文件。学过Java的同学可能知道,可以将对象实体固化(序列化)到文件。这在Python中可以使用pickle模块或者shelve库实现。
具体如下
如果想使用shelve库实现,可以参考《如何将Python对象保存在本地文件中?》。
def save(self, path):
    """Serialize this object with pickle and write the bytes to ``path``.

    The object is pickled in memory first, so the target file is only
    opened (and truncated) once serialization has succeeded.
    """
    payload = pickle.dumps(self)
    with open(path, "wb") as out_file:
        out_file.write(payload)
def load(path):
    """Read one pickled object from ``path`` and return it.

    Returns ``None`` (after printing ``IOError``) when the file content
    cannot be unpickled. Errors raised while opening the file are not
    caught and propagate to the caller.

    NOTE: when this function is placed inside a class, decorate it with
    ``@staticmethod`` so that it can also be called on an instance.

    SECURITY: ``pickle.load`` can execute arbitrary code; only load files
    that come from a trusted source.
    """
    obj = None
    with open(path, "rb") as f:
        try:
            obj = pickle.load(f)
        except Exception:
            # Keep the best-effort contract (return None), but no longer
            # swallow KeyboardInterrupt / SystemExit like a bare except.
            print("IOError")
    return obj
使用
# Load the training and test data
trainDataPath = "./trainingDigits"
trainMat, trainLabels = handwritingData(trainDataPath)
testDataPath = "./testDigits"
testMat, testLabels = handwritingData(testDataPath)
# Build the network, set its learning rate and train it
net = Net()
net.setLearnrate(0.01)
net.train(trainMat, trainLabels, Epoch=2000)
# Persist the trained network, then evaluate it on the test set
net.save("hr.model")
net.test(testMat, testLabels)
# Reload the saved network from disk and evaluate the restored copy
newmodel = Net.load("hr.model")
newmodel.test(testMat, testLabels)
完整代码
#!/usr/bin/python3
# coding:utf-8
# @Author: Lin Misaka
# @File: net.py
# @Data: 2020/11/30
# @IDE: PyCharm
from os import listdir
import numpy as np
import matplotlib.pyplot as plt
import pickle
# 函数img2vector将图像转换为向量
def img2vector(filename):
    """Convert a 32x32 text image into a 1x1024 row vector.

    The file is expected to hold at least 32 lines with at least 32 digit
    characters each; character ``j`` of line ``i`` is stored at column
    ``32 * i + j``.

    Args:
        filename: path of the text file to read.

    Returns:
        A numpy array of shape (1, 1024) holding the digits as floats.

    Raises:
        IndexError: if a line has fewer than 32 characters.
        ValueError: if a character is not a digit.
    """
    returnVect = np.zeros((1, 1024))
    # The context manager closes the file even when parsing fails; the
    # original version never closed the handle.
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0, 32 * i + j] = int(lineStr[j])
    return returnVect
# 读取手写字体txt数据
def handwritingData(dataPath):
    """Load every handwriting sample found in a directory.

    Each file name is parsed as ``<label>_<anything>.<ext>``: the text
    before the first ``.`` is taken, and the part before its first ``_``
    is converted to the integer class label. The file content is read
    with ``img2vector``.

    Args:
        dataPath: directory that contains the sample files.

    Returns:
        A tuple ``(matrix, labels)`` where ``matrix`` is a numpy array of
        shape (number_of_files, 1024) and ``labels`` is the list of
        integer labels in the same row order.
    """
    hwLabels = []
    fileList = listdir(dataPath)  # 1. directory content
    m = len(fileList)
    # Use a local array. The original stored this as ``np.Mat``, which
    # attached the data to the numpy module itself.
    dataMat = np.zeros((m, 1024))
    for i, fileNameStr in enumerate(fileList):
        # 2. parse the class digit from the file name
        fileStr = fileNameStr.split('.')[0]  # take off .txt
        hwLabels.append(int(fileStr.split('_')[0]))
        dataMat[i, :] = img2vector(dataPath + '/%s' % fileNameStr)
    return dataMat, hwLabels
# diff = True求导
def Sigmoid(x, diff=False):
    """Logistic sigmoid of ``x``, or its derivative when ``diff`` is True.

    Args:
        x: scalar or numpy array.
        diff: when equal to ``True`` the derivative ``s * (1 - s)`` is
            returned instead of the sigmoid value ``s``.

    Returns:
        The element-wise sigmoid (or its derivative) of ``x``.
    """
    activation = 1 / (1 + np.exp(-x))
    # The equality test is kept on purpose: only True (or 1) selects the
    # derivative, exactly as before.
    if diff == True:
        return activation * (1 - activation)
    return activation
# diff = True求导
def SquareErrorSum(y_hat, y, diff=False):
    """Half sum of squared errors, or its gradient with respect to ``y_hat``.

    Args:
        y_hat: predicted values.
        y: target values.
        diff: when equal to ``True`` return the gradient ``y_hat - y``
            instead of the loss value.

    Returns:
        ``sum(0.5 * (y_hat - y) ** 2)`` or, for ``diff == True``, the
        element-wise difference ``y_hat - y``.
    """
    residual = y_hat - y
    if diff == True:
        return residual
    halved_squares = np.square(residual) * 0.5
    return halved_squares.sum()
class Net():
    """Fully connected 1024-16-16-10 network trained sample by sample.

    Inputs are column vectors of shape (1024, 1); the output is a
    (10, 1) column whose arg-max is the predicted class. All parameters
    are initialised from a standard normal distribution.
    """

    def __init__(self):
        # X Input
        self.X = np.random.randn(1024, 1)
        self.W1 = np.random.randn(16, 1024)
        self.b1 = np.random.randn(16, 1)
        self.W2 = np.random.randn(16, 16)
        self.b2 = np.random.randn(16, 1)
        self.W3 = np.random.randn(10, 16)
        self.b3 = np.random.randn(10, 1)
        self.alpha = 0.01  # learning rate
        self.losslist = []  # one loss value per epoch, used for plotting

    def forward(self, X, y, activate):
        """Run a forward pass and cache the intermediates for ``backward``.

        Args:
            X: input column vector of shape (1024, 1).
            y: target column vector of shape (10, 1).
            activate: activation callable with the signature
                ``activate(z, diff=False)``.

        Returns:
            A tuple ``(loss, y_hat)`` with the loss for this sample and
            the network output.
        """
        self.X = X
        self.z1 = np.dot(self.W1, self.X) + self.b1
        self.a1 = activate(self.z1)
        self.z2 = np.dot(self.W2, self.a1) + self.b2
        self.a2 = activate(self.z2)
        self.z3 = np.dot(self.W3, self.a2) + self.b3
        self.y_hat = activate(self.z3)
        Loss = SquareErrorSum(self.y_hat, y)
        return Loss, self.y_hat

    def backward(self, y, activate):
        """Back-propagate the last forward pass and update all parameters.

        Must be called after ``forward``, because it reads the cached
        ``z*``, ``a*`` and ``y_hat`` values.

        Args:
            y: target column vector used in the preceding forward pass.
            activate: the same activation callable given to ``forward``;
                it is called with ``True`` to obtain the derivative.
        """
        # All deltas are computed with the current weights before any
        # weight is modified.
        self.delta3 = activate(self.z3, True) * SquareErrorSum(self.y_hat, y, True)
        self.delta2 = activate(self.z2, True) * (np.dot(self.W3.T, self.delta3))
        self.delta1 = activate(self.z1, True) * (np.dot(self.W2.T, self.delta2))
        dW3 = np.dot(self.delta3, self.a2.T)
        dW2 = np.dot(self.delta2, self.a1.T)
        dW1 = np.dot(self.delta1, self.X.T)
        # Gradient-descent step; the bias gradients are the deltas.
        self.W3 -= self.alpha * dW3
        self.W2 -= self.alpha * dW2
        self.W1 -= self.alpha * dW1
        self.b3 -= self.alpha * self.delta3
        self.b2 -= self.alpha * self.delta2
        self.b1 -= self.alpha * self.delta1

    def setLearnrate(self, l):
        """Set the learning rate used by ``backward``."""
        self.alpha = l

    def save(self, path):
        """Pickle the whole network object and write it to ``path``.

        The object is serialised in memory first, so the file is only
        opened (and truncated) once pickling has succeeded.
        """
        obj = pickle.dumps(self)
        with open(path, "wb") as f:
            f.write(obj)

    @staticmethod
    def load(path):
        """Load a network previously written by ``save``.

        Declared as a static method so that both ``Net.load(path)`` and
        ``instance.load(path)`` work; without the decorator the second
        form raised ``TypeError``.

        Returns:
            The unpickled object, or ``None`` (after printing
            ``IOError``) when the file content cannot be unpickled.
            Errors raised while opening the file propagate.

        SECURITY: ``pickle.load`` can execute arbitrary code; only load
        model files that come from a trusted source.
        """
        obj = None
        with open(path, "rb") as f:
            try:
                obj = pickle.load(f)
            except Exception:
                # Best effort as before, but KeyboardInterrupt and
                # SystemExit are no longer swallowed.
                print("IOError")
        return obj

    def train(self, trainMat, trainLabels, Epoch=5, bitch=None):
        """Train on every sample once per epoch, then plot the loss curve.

        Args:
            trainMat: array with one 1024-value sample per row.
            trainLabels: integer class label (0-9) for each row.
            Epoch: number of passes over the training data.
            bitch: unused; presumably a misspelling of "batch". Kept so
                that existing callers keep working.

        The value appended to ``self.losslist`` each epoch is the loss of
        the last sample processed in that epoch, not an average.
        """
        for epoch in range(Epoch):
            acc = 0.0
            acc_cnt = 0
            # One reusable 10x1 vector serves as the one-hot label.
            label = np.zeros((10, 1))
            for i in range(len(trainMat)):  # no batching: one sample per step
                X = trainMat[i, :].reshape((1024, 1))  # build the input
                labelidx = trainLabels[i]
                label[labelidx][0] = 1.0
                Loss, y_hat = self.forward(X, label, Sigmoid)  # forward pass
                self.backward(label, Sigmoid)  # backward pass
                label[labelidx][0] = 0.0  # reset to the zero vector
                acc_cnt += int(trainLabels[i] == np.argmax(y_hat))
            acc = acc_cnt / len(trainMat)
            self.losslist.append(Loss)
            print("epoch:%d,loss:%02f,accrucy : %02f%%" % (epoch, Loss, acc*100))
        self.plotLosslist(self.losslist, "Loss:Init->randn,alpha=0.01")

    def plotLosslist(self, Loss, title):
        """Plot the given loss values against their index and show the figure.

        Args:
            Loss: sequence of loss values, one per epoch.
            title: title drawn above the plot.
        """
        font = {'family': 'simsun',
                'weight': 'bold',
                'size': 20,
                }
        m = len(Loss)
        X = range(m)
        # plt.figure(1)
        plt.subplots(nrows=1, ncols=1, figsize=(10, 8))
        plt.subplot(111)
        plt.title(title, font)
        plt.plot(X, Loss)
        plt.xlabel(r'Epoch', font)
        plt.ylabel(u'Loss', font)
        plt.show()

    def test(self, testMat, testLabels, bitch=None):
        """Evaluate the network and print the accuracy.

        Args:
            testMat: array with one 1024-value sample per row.
            testLabels: integer class label (0-9) for each row.
            bitch: number of leading samples to evaluate (presumably a
                misspelling of "batch"); ``None`` means all of them.
        """
        acc = 0.0
        acc_cnt = 0
        # One reusable 10x1 vector serves as the one-hot label.
        label = np.zeros((10, 1))
        if bitch is None:
            bitch = len(testMat)
        for i in range(bitch):
            X = testMat[i, :].reshape((1024, 1))  # build the input
            labelidx = testLabels[i]
            label[labelidx][0] = 1.0
            Loss, y_hat = self.forward(X, label, Sigmoid)  # forward pass
            label[labelidx][0] = 0.0  # reset to the zero vector
            acc_cnt += int(testLabels[i] == np.argmax(y_hat))
        acc = acc_cnt / bitch
        print("test num: %d, accrucy : %05.3f%%"%(bitch,acc*100))
# Load the training and test data
trainDataPath = "./trainingDigits"
trainMat, trainLabels = handwritingData(trainDataPath)
testDataPath = "./testDigits"
testMat, testLabels = handwritingData(testDataPath)
# Build the network, set its learning rate and train it
net = Net()
net.setLearnrate(0.01)
net.train(trainMat, trainLabels, Epoch=2000)
# Persist the trained network, then evaluate it on the test set
net.save("hr.model")
net.test(testMat, testLabels)
# Reload the saved network from disk and evaluate the restored copy
newmodel = Net.load("hr.model")
newmodel.test(testMat, testLabels)
参考
链接失效Call我。
github:https://github.com/MisakaMikoto128
个人网站
pickle简单使用
下面部分转自:Ellison张 - 侵删
import pickle
class A:
    """Small demo record with a ``name`` and a replaceable value ``a``."""

    def __init__(self, name, a):
        self.name, self.a = name, a

    def rewrite(self, args):
        """Replace the stored value ``a`` with ``args``."""
        self.a = args
'''
'''
# Serialize each object with pickle and append the bytes to one file.
# Objects stored this way cannot be updated in place inside the file.
# What is written is the object's content, not its memory address.
# The file is opened in append mode, so every run adds three more records.
obj1 = A("qw", "1")
obj1 = pickle.dumps(obj1)
with open("userinfo", "ab") as f:
    f.write(obj1)
obj2 = A("qa", "2")
obj2 = pickle.dumps(obj2)
with open("userinfo", "ab") as f:
    f.write(obj2)
obj3 = A("qs", "3")
obj3 = pickle.dumps(obj3)
with open("userinfo", "ab") as f:
    f.write(obj3)
# Read the objects back from the file.
# pickle.load() returns one stored object per call and raises EOFError
# once the file is exhausted, which ends the loop. Any other error now
# propagates instead of being silently treated as end-of-file.
with open("userinfo", "rb") as f:
    while True:
        try:
            obj = pickle.load(f)
        except EOFError:
            break
        print(obj.a, obj)
# Second pass over the same file, identical to the first.
with open("userinfo", "rb") as f:
    while True:
        try:
            obj = pickle.load(f)
        except EOFError:
            break
        print(obj.a, obj)
也可以将对象放进列表中,再把整个列表写入文件。
如需修改对象,先修改列表中的对应对象,最后再把列表重新写入文件。
使用Python的shelve库。
其操作和Python内置的数据结构“字典”很类似,不同点是shelve将数据存储在外部文件中,而不是存储在计算机内存中。
import shelve
class A:
    """Small demo record with a ``name`` and a replaceable value ``a``."""

    def __init__(self, name, a):
        self.name, self.a = name, a

    def rewrite(self, args):
        """Replace the stored value ``a`` with ``args``."""
        self.a = args
obj1 = A("qw", "1")
obj2 = A("qa", "2")
obj3 = A("qs", "3")

# Write the three objects to the shelf, each under its own key.
with shelve.open("userinfo1") as db:
    db["qw"] = obj1
    db["qa"] = obj2
    db["qs"] = obj3

# Update: print each stored value, then store a modified copy back.
# Assigning back to db[k] is what persists the change.
with shelve.open("userinfo1") as db:
    for k in db:
        print(db[k].a)
        a = db[k]
        a.a = 12
        db[k] = a

# Reopen the shelf and print the values again to show the update.
with shelve.open("userinfo1") as db:
    for k in db:
        print(db[k].a)
更多推荐
已为社区贡献1条内容
所有评论(0)