手写字符识别

手写字符的数据集

特征值

训练集的大小

目标值

对应的API

网络设计

准确率的计算

代码:

#导入tensorflow
import tensorflow as tf
#将tensorflow转换成1版本的
tf.compat.v1.disable_eager_execution()
tf = tf.compat.v1

#导入手写字符模块
from tensorflow.examples.tutorials.mnist import input_data

def full_connection():
    '''
    用全连接层来对手写数字进行识别
    :return:
    '''
    # 1 准备数据
    mnist = input_data.read_data_sets("./mnist_data/", one_hot=True)
    #获取特征值和目标值的真实数据(x是特征值,y是目标值),先使用占位符
    x = tf.placeholder(dtype=tf.float32,shape=[None,784])
    y_true = tf.placeholder(dtype=tf.float32,shape=[None,10])
    # 2 构建模型
    #权重
    Weights = tf.Variable(initial_value=tf.random_normal(shape=[784,10]))
    #偏置
    bias = tf.Variable(initial_value=tf.random_normal(shape=[10]))
    #预测值
    y_predict = tf.matmul(x,Weights)+bias
    # 3 构建损失函数
    #用softmax函数
    error = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true,logits=y_predict))
    # 4 优化损失
    #用梯度下降优化损失
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(error)

    # 5 准确率计算,
     #(1)比较输出的结果最大值所在位置和真实值的最大值所在位置
    equal_list = tf.equal(tf.argmax(y_predict,1),tf.argmax(y_true,1))
    #(2)求平均
    accuracy = tf.reduce_mean(tf.cast(equal_list,tf.float32))

    #初始化变量
    init = tf.global_variables_initializer()

    #开启会话
    with tf.Session() as sess:
        sess.run(init)

        #feed占位符
        image,label = mnist.train.next_batch(100)

        print("训练之前,损失为%f," % sess.run(error,feed_dict={x:image,y_true:label}))

        #开始训练
        for i in range(800):
            _,loss,accuracy_value = sess.run([optimizer,error,accuracy],feed_dict={x:image,y_true:label})
            print("第%d次的训练,损失为%f,准确率为%f" %(i+1,loss,accuracy_value))

    return Nones

输出:

 

 

 

 

 

 

 

 

 

Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐