下面是配置文件(configs.py)。建议把配置单独放在一个文件中,方便随时修改。

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2018/4/6 23:05
# @Author  : Zehan Song
# @Site    : 
# @File    : configs.py
# @Software: PyCharm
import warnings
class Config(object):
    """Default hyper-parameters and paths, stored as class attributes.

    The attribute order is kept stable because the class ``__dict__`` is
    iterated elsewhere when the effective configuration is printed.
    """

    # --- network dimensions ---
    INPUT_NODE = 28 * 28   # 784 input features
    OUTPUT_NODE = 10       # width of the output layer
    LAYER1_NODE = 500      # width of the hidden layer

    # --- batching and checkpoint directory ---
    BATCH_SIZE = 64
    SAVE_PATH = './model_path/'

    # --- optimisation schedule ---
    MAX_EPOCH = 100
    LEARNING_BASE = 0.05   # base learning rate
    DECAY_RATE = 0.99      # learning-rate decay rate
    MOVING_AVERAGE_DECAY = 0.99

    # --- checkpoint name and regularisation strength ---
    MODEL_NAME = "two_linear_layers_mnist.ckpt"
    REGULARIZATION_RATE = 0.0001

def Parse(self, kwargs):
    """Apply user-supplied overrides to ``self`` and print the resulting config.

    Args:
        self: Object whose attributes are overridden (the function is written
            to be attached to a class and called as a method).
        kwargs: Dict mapping option names to their new values.

    Every key in ``kwargs`` that names an existing attribute is set on the
    instance. Unknown keys are not set; a warning naming the key is issued
    instead. Afterwards every attribute defined on the class whose name does
    not start with ``"__"`` or ``"Parse"`` is printed with its effective value.
    """
    # print() with a single argument behaves the same on Python 2 and 3;
    # the former `print "..."` statement is a SyntaxError on Python 3.
    print("user config:")
    for key, value in kwargs.items():
        if hasattr(self, key):
            setattr(self, key, value)
        else:
            # Name the offending key so a mistyped option can be spotted.
            warnings.warn("unknown config option ignored: %s" % key, stacklevel=2)
    for name in self.__class__.__dict__:
        # Skip dunder entries and this method itself.
        if name.startswith(("__", "Parse")):
            continue
        print("%s: %s" % (name, getattr(self, name)))

# Attach Parse to the class itself (a class attribute rather than an attribute
# of one object) so that it is bound as a method on every Config instance.
Config.Parse = Parse
# Module-level configuration object.
configs = Config()

接下来是训练模型的代码(best_tensorflow_practice_example.py)。如有疑问,可以在下方评论。

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2018/4/6 23:05
# @Author  : Zehan Song
# @Site    : 
# @File    : best_tensorflow_practice_example.py
# @Software: PyCharm
from configs import configs
import tensorflow as tf
import time
from tqdm import tqdm
import fire
from tensorflow.examples.tutorials.mnist import input_data
# inference
def get_weight_variable(shape, regularizer):
    """Get the ``weights`` variable of the current variable scope.

    Args:
        shape: Shape passed to ``tf.get_variable`` for the weight variable.
        regularizer: Callable applied to the weights, or ``None`` to skip
            regularization.

    Returns:
        The variable returned by ``tf.get_variable``. When ``regularizer`` is
        given, ``regularizer(weights)`` is also added to the ``"losses"``
        collection.
    """
    weights = tf.get_variable(
        name="weights",
        shape=shape,
        initializer=tf.truncated_normal_initializer(stddev=0.1),
    )
    # Compare against None by identity (PEP 8) instead of `!= None`.
    if regularizer is not None:
        tf.add_to_collection("losses", regularizer(weights))
    return weights

def inference(input_tensor, regularizer):
    """Build the two-layer forward pass and return the output of the second layer.

    Args:
        input_tensor: Tensor multiplied with the first layer's weights.
        regularizer: Passed through to ``get_weight_variable`` for both layers.

    Returns:
        ``matmul(relu(matmul(input, W1) + b1), W2) + b2``; no activation is
        applied to the second layer.
    """
    in_dim = configs.INPUT_NODE
    hidden_dim = configs.LAYER1_NODE
    out_dim = configs.OUTPUT_NODE

    # Hidden layer: affine transform followed by ReLU.
    with tf.variable_scope('layer1'):
        w1 = get_weight_variable([in_dim, hidden_dim], regularizer)
        b1 = tf.get_variable(
            "biases", [hidden_dim], initializer=tf.constant_initializer(0.0)
        )
        hidden = tf.nn.relu(tf.matmul(input_tensor, w1) + b1)

    # Output layer: affine transform only.
    with tf.variable_scope('layer2'):
        w2 = get_weight_variable([hidden_dim, out_dim], regularizer)
        b2 = tf.get_variable(
            "biases", [out_dim], initializer=tf.constant_initializer(0.0)
        )
        logits = tf.matmul(hidden, w2) + b2

    return logits

def train(**kwargs):
    """Train the two-layer network on MNIST, then evaluate a saved checkpoint.

    Args:
        **kwargs: Configuration overrides forwarded to ``configs.Parse``
            (for example ``BATCH_SIZE=100``).

    The function builds the graph, runs a fixed number of training steps,
    prints the training loss/accuracy every 50 steps and validates every 100
    steps. Whenever the validation accuracy improves, a checkpoint is written
    under ``configs.SAVE_PATH``. Finally the checkpoint recorded in that
    directory is restored and evaluated on the test set.
    """
    # Local import: the module header does not import os.
    import os

    configs.Parse(kwargs)

    # Placeholders and regularizer.
    x = tf.placeholder(dtype=tf.float32, shape=[None, configs.INPUT_NODE], name='x-input')
    y = tf.placeholder(dtype=tf.float32, shape=[None, configs.OUTPUT_NODE], name='y-input')
    regularizer = tf.contrib.layers.l2_regularizer(configs.REGULARIZATION_RATE)

    # Data and the fixed feeds used for validation and testing.
    mnist = input_data.read_data_sets("./data", one_hot=True)
    validation_feed = {x: mnist.validation.images, y: mnist.validation.labels}
    test_feed = {x: mnist.test.images, y: mnist.test.labels}

    # Forward pass.
    y_ = inference(x, regularizer)

    # Learning-rate schedule driven by the global step, plus the saver.
    global_step = tf.Variable(0, trainable=False)
    lr = tf.train.exponential_decay(
        configs.LEARNING_BASE,
        global_step,
        mnist.train.num_examples / configs.BATCH_SIZE,
        configs.DECAY_RATE,
    )
    saver = tf.train.Saver()

    # Loss (cross entropy plus the collected regularization terms), optimizer
    # and accuracy. Labels are fed one-hot, so argmax recovers class indices.
    cross_entropy_mean = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(y, 1), logits=y_)
    )
    loss = cross_entropy_mean + tf.add_n(tf.get_collection("losses"))
    optimizer = tf.train.GradientDescentOptimizer(lr).minimize(loss, global_step=global_step)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)), tf.float32))

    # Make sure the checkpoint directory exists before anything is saved, and
    # build the path with os.path.join so a SAVE_PATH given without a trailing
    # slash still points inside that directory.
    if configs.SAVE_PATH and not os.path.isdir(configs.SAVE_PATH):
        os.makedirs(configs.SAVE_PATH)
    # NOTE: the "best.cpkt" spelling is kept so existing checkpoints still match.
    best_checkpoint = os.path.join(configs.SAVE_PATH, "best.cpkt")

    # Session options.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # limit occupation of GPU
    config.allow_soft_placement = True  # allow computing among different gpus and cpus

    # Use the session itself as the context manager so that it is closed on
    # exit; `Session(...).as_default()` only installs it as the default
    # session and leaves it open.
    with tf.Session(config=config) as sess:
        # Variables must be initialized before any training step runs.
        sess.run(tf.global_variables_initializer())

        train_op = tf.group([optimizer], name='no_op')
        TRAINING_STEPS = 5000
        best_accuracy = 0
        for i in tqdm(range(TRAINING_STEPS)):
            xs, ys = mnist.train.next_batch(configs.BATCH_SIZE)
            _, loss_value, accuracy_value = sess.run(
                [train_op, loss, accuracy], feed_dict={x: xs, y: ys}
            )
            if i % 50 == 0:
                print("iters[%d/%d],loss:%.6f,accuracy:%.6f"
                      % (i, TRAINING_STEPS, loss_value, accuracy_value))

            # Validation uses the cross entropy only (no regularization loss).
            if i % 100 == 0:
                loss_value, accuracy_value = sess.run(
                    [cross_entropy_mean, accuracy], feed_dict=validation_feed
                )
                print("val_loss:%.6f,val_accuracy:%.6f" % (loss_value, accuracy_value))
                # Keep only checkpoints that improve the validation accuracy.
                if best_accuracy < accuracy_value:
                    best_accuracy = accuracy_value
                    saver.save(sess, best_checkpoint)

        # Test phase: restore the checkpoint recorded in SAVE_PATH, if any.
        ckpt = tf.train.get_checkpoint_state(configs.SAVE_PATH)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            loss_value, accuracy_value = sess.run(
                [cross_entropy_mean, accuracy], feed_dict=test_feed
            )
            print("test accuracy:%.6f test loss:%.6f" % (accuracy_value, loss_value))

if __name__ == "__main__":
    # Expose this module on the command line through python-fire; per the
    # usage example below, `train` is selected as the command and its keyword
    # arguments are given as --NAME=value flags.
    fire.Fire()
# Example: run the following command on your server to start training
# with a batch size of 100.
#python best_tensorflow_practice_example.py train --BATCH_SIZE=100

下面是运行过程截图



Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐