零基础入门深度学习(2) - 线性单元和梯度下降 python3代码
#!/usr/bin/env python# -*- coding: UTF-8 -*-#原文 https://www.zybuluo.com/hanbingtao/note/448086from perceptron import Percepronimport matplotlib.pyplot as plt#定义激活函数ff = lambda x: xclass Linea...
·
#!/usr/bin/env python # -*- coding: UTF-8 -*- #原文 https://www.zybuluo.com/hanbingtao/note/448086 from perceptron import Percepron import matplotlib.pyplot as plt #定义激活函数f f = lambda x: x class LinearUnit(Percepron): #继承感知机 def __init__(self, input_num): '''初始化线性单元,设置输入参数的个数''' Percepron.__init__(self, input_num, f) def get_training_dataset(): ''' 捏造5个人的收入数据 ''' # 构建训练数据 # 输入向量列表,每一项是工作年限 input_vecs = [[5], [3], [8], [1.4], [10.1]] # 期望的输出列表,月薪,注意要与输入一一对应 labels = [5500, 2300, 7600, 1800, 11400] return input_vecs, labels def train_linear_unit(): ''' 使用数据训练线性单元 ''' # 创建感知器,输入参数的特征数为1(工作年限) lu = LinearUnit(1) # 训练,迭代10轮, 学习速率为0.01 input_vecs, labels = get_training_dataset() lu.train(input_vecs, labels, 10, 0.01) #返回训练好的线性单元 return lu def plot(linear_unit): input_vecs, labels = get_training_dataset() fig = plt.figure() ax = fig.add_subplot(111) ax.scatter(map(lambda x: x[0], input_vecs), labels) weights = linear_unit.weights bias = linear_unit.bias x = range(0,12,1) y = map(lambda x:weights[0] * x + bias, x) ax.plot(x, y) plt.show() if __name__ == '__main__': '''训练线性单元''' linear_unit = train_linear_unit() # 打印训练获得的权重 print("-----000----------") print(linear_unit) print("-----------------111-----") # 测试 print('Work 3.4 years, monthly salary = %.2f' % linear_unit.predict([3.4])) print('Work 15 years, monthly salary = %.2f' % linear_unit.predict([15])) print('Work 1.5 years, monthly salary = %.2f' % linear_unit.predict([1.5])) print('Work 6.3 years, monthly salary = %.2f' % linear_unit.predict([6.3])) #plot(linear_unit) input_vecs, labels = get_training_dataset() fig = plt.figure() ax = fig.add_subplot(111) tt=list(map(lambda x: x[0], input_vecs)) ax.scatter(tt, labels) weights = linear_unit.weights bias = linear_unit.bias x = range(0, 12, 1) y = list(map(lambda x: weights[0] * x + bias, x)) ax.plot(x, y) plt.show()
更多推荐



所有评论(0)