开始

上篇写了如何手写神经网络,现在有如下三大需求:利用GPU加速,自动求导,保存读取模型。这里主要讲讲保存读取模型。GPU加速可以用 基于mxnet进行运算的minpy,或者是CuPy。

保存读取模型

这里使用一种较为简单的方法,直接保存训练后的网络对象到文件。学过Java的同学可能知道可以将对象实体固化到文件。这在Python中可以使用Python的pickle模块或者Python的shelve库实现。

具体如下

如何想使用shelve库实现可以参考如何将Python对象保存在本地文件中?

def save(self,path):
    obj = pickle.dumps(self)
    with open(path,"wb") as f:
        f.write(obj)

def load(path):
    obj = None
    with open(path, "rb") as f:
        try:
            obj = pickle.load(f)
        except:
            print("IOError")
    return obj

使用

# 读取训练数据
trainDataPath = "./trainingDigits"
trainMat, trainLabels = handwritingData(trainDataPath)
testDataPath = "./testDigits"
testMat, testLabels = handwritingData(testDataPath)
net = Net()
net.setLearnrate(0.01)
net.train(trainMat, trainLabels, Epoch=2000)
net.save("hr.model")
net.test(testMat, testLabels)

newmodel = Net.load("hr.model")
newmodel.test(testMat, testLabels)

在这里插入图片描述

完整代码

#!/usr/bin/python3
# coding:utf-8
# @Author: Lin Misaka
# @File: net.py
# @Data: 2020/11/30
# @IDE: PyCharm
from os import listdir
import numpy as np
import matplotlib.pyplot as plt
import pickle
# 函数img2vector将图像转换为向量
def img2vector(filename):
    returnVect = np.zeros((1, 1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect[0, 32 * i + j] = int(lineStr[j])
    return returnVect


# 读取手写字体txt数据
def handwritingData(dataPath):
    hwLabels = []
    FileList = listdir(dataPath)  # 1 获取目录内容
    m = len(FileList)
    np.Mat = np.zeros((m, 1024))
    for i in range(m):
        # 2 从文件名解析分类数字
        fileNameStr = FileList[i]
        fileStr = fileNameStr.split('.')[0]  # take off .txt
        classNumStr = int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        np.Mat[i, :] = img2vector(dataPath + '/%s' % fileNameStr)
    return np.Mat, hwLabels



# diff = True求导
def Sigmoid(x, diff=False):
    def sigmoid(x):        # sigmoid函数
        return 1 / (1 + np.exp(-x))
    def dsigmoid(x):
        f = sigmoid(x)
        return f * (1 - f)
    if (diff == True):
        return dsigmoid(x)
    return sigmoid(x)

# diff = True求导
def SquareErrorSum(y_hat, y, diff=False):
    if (diff == True):
        return y_hat - y
    return (np.square(y_hat - y) * 0.5).sum()


class Net():
    def __init__(self):
        # X Input
        self.X =  np.random.randn(1024, 1)
        self.W1 = np.random.randn(16, 1024)
        self.b1 = np.random.randn(16, 1)
        self.W2 = np.random.randn(16, 16)
        self.b2 = np.random.randn(16, 1)
        self.W3 = np.random.randn(10, 16)
        self.b3 = np.random.randn(10, 1)
        self.alpha = 0.01  #学习率
        self.losslist = [] #用于作图

    def forward(self, X, y, activate):
        self.X = X
        self.z1 = np.dot(self.W1, self.X) + self.b1
        self.a1 = activate(self.z1)
        self.z2 = np.dot(self.W2, self.a1) + self.b2
        self.a2 = activate(self.z2)
        self.z3 = np.dot(self.W3, self.a2) + self.b3
        self.y_hat = activate(self.z3)
        Loss = SquareErrorSum(self.y_hat, y)
        return Loss, self.y_hat

    def backward(self, y, activate):
        self.delta3 = activate(self.z3, True) * SquareErrorSum(self.y_hat, y, True)
        self.delta2 = activate(self.z2, True) * (np.dot(self.W3.T, self.delta3))
        self.delta1 = activate(self.z1, True) * (np.dot(self.W2.T, self.delta2))
        dW3 = np.dot(self.delta3, self.a2.T)
        dW2 = np.dot(self.delta2, self.a1.T)
        dW1 = np.dot(self.delta1, self.X.T)
        d3 = self.delta3
        d2 = self.delta2
        d1 = self.delta1
        #update weight
        self.W3 -= self.alpha * dW3
        self.W2 -= self.alpha * dW2
        self.W1 -= self.alpha * dW1
        self.b3 -= self.alpha * d3
        self.b2 -= self.alpha * d2
        self.b1 -= self.alpha * d1

    def setLearnrate(self, l):
        self.alpha = l

    def save(self,path):
        obj = pickle.dumps(self)
        with open(path,"wb") as f:
            f.write(obj)

    def load(path):
        obj = None
        with open(path, "rb") as f:
            try:
                obj = pickle.load(f)
            except:
                print("IOError")
        return obj

    def train(self, trainMat, trainLabels, Epoch=5, bitch=None):
        for epoch in range(Epoch):
            acc = 0.0
            acc_cnt = 0
            label = np.zeros((10, 1))#先生成一个10x1是向量,减少运算。用于生成one_hot格式的label
            for i in range(len(trainMat)):#可以用batch,数据较少,一次训练所有数据集
                X = trainMat[i, :].reshape((1024, 1)) #生成输入

                labelidx = trainLabels[i]
                label[labelidx][0] = 1.0

                Loss, y_hat = self.forward(X, label, Sigmoid)#前向传播
                self.backward(label, Sigmoid)#反向传播

                label[labelidx][0] = 0.0#还原为0向量
                acc_cnt += int(trainLabels[i] == np.argmax(y_hat))

            acc = acc_cnt / len(trainMat)
            self.losslist.append(Loss)
            print("epoch:%d,loss:%02f,accrucy : %02f%%" % (epoch, Loss, acc*100))
        self.plotLosslist(self.losslist, "Loss:Init->randn,alpha=0.01")

    def plotLosslist(self, Loss, title):
        font = {'family': 'simsun',
                'weight': 'bold',
                'size': 20,
                }
        m = len(Loss)
        X = range(m)
        # plt.figure(1)
        plt.subplots(nrows=1, ncols=1, figsize=(10, 8))
        plt.subplot(111)
        plt.title(title, font)
        plt.plot(X, Loss)
        plt.xlabel(r'Epoch', font)
        plt.ylabel(u'Loss', font)
        plt.show()

    def test(self, testMat, testLabels, bitch=None):
        acc = 0.0
        acc_cnt = 0
        label = np.zeros((10, 1))#先生成一个10x1是向量,减少运算。用于生成one_hot格式的label
        if(bitch == None):
            bitch = len(testMat)
        for i in range(bitch):#可以用batch,数据较少,一次训练所有数据集
            X = testMat[i, :].reshape((1024, 1)) #生成输入

            labelidx = testLabels[i]
            label[labelidx][0] = 1.0

            Loss, y_hat = self.forward(X, label, Sigmoid)#前向传播

            label[labelidx][0] = 0.0#还原为0向量
            acc_cnt += int(testLabels[i] == np.argmax(y_hat))
        acc = acc_cnt / bitch
        print("test num: %d, accrucy : %05.3f%%"%(bitch,acc*100))


# 读取训练数据
trainDataPath = "./trainingDigits"
trainMat, trainLabels = handwritingData(trainDataPath)
testDataPath = "./testDigits"
testMat, testLabels = handwritingData(testDataPath)
net = Net()
net.setLearnrate(0.01)
net.train(trainMat, trainLabels, Epoch=2000)
net.save("hr.model")
net.test(testMat, testLabels)

newmodel = Net.load("hr.model")
newmodel.test(testMat, testLabels)

参考

[1] 如何将Python对象保存在本地文件中?

链接失效Call我。
github:https://github.com/MisakaMikoto128
个人网站


pickle简单使用

在这里插入图片描述

下面部分转自:Ellison张 - 侵删

import pickle
class A:
    def __init__(self,name,a):
        self.name=name
        self.a=a
    def rewrite(self,args):
        self.a=args
'''

'''
#将对象使用pickle模块转换成二进制文件然后写入文件中
#但此种方法无法直接更新对象文件的信息
#此时存入文件的应该是对象本身而不是内存地址
obj1=A("qw","1")
obj1=pickle.dumps(obj1)
with open("userinfo","ab")as f:
    f.write(obj1)
obj2=A("qa","2")
obj2=pickle.dumps(obj2)
with open("userinfo","ab")as f:
    f.write(obj2)
obj3=A("qs","3")
obj3=pickle.dumps(obj3)
with open("userinfo","ab")as f:
    f.write(obj3)

#读取文件中的对象文件
#pickle.load()一次只读取一个对象文件
f=open("userinfo","rb")
while 1:
    try:
        obj = pickle.load(f)
        print(obj.a,obj)
    except:
        break
f.close()
f=open("userinfo","rb")
while 1:
    try:
        obj = pickle.load(f)
        print(obj.a,obj)
    except:
        break
f.close()

可以将对象放进列表中再写入文件
如需修改对象时,将列表中的对象修改后再放回列表中最后再重新写入文件中

使用Python的shelve库。

其和Python内置的数据结构“字典”操作很类似不同点是shelve存储在外部文件中而不是存储在计算机内存中。

import shelve
class A:
    def __init__(self,name,a):
        self.name=name
        self.a=a
    def rewrite(self,args):
        self.a=args
obj1=A("qw","1")
obj2=A("qa","2")
obj3=A("qs","3")
#写入文件
db=shelve.open("userinfo1")
db["qw"]=obj1
db["qa"]=obj2
db["qs"]=obj3
db.close()
#更新信息
db=shelve.open("userinfo1")
for k in db:
    print(db[k].a)
    a=db[k]
    a.a=12
    db[k]=a
db.close()
db=shelve.open("userinfo1")
for k in db:
    print(db[k].a)
db.close()

Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐