import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn import datasets
from sklearn.model_selection import train_test_split

一、TensorFlow的多层次编程API

（原文此处有一张配图，图片在转载时已缺失）
官方建议：使用 Estimator API 和 Dataset API 训练模型。
Estimator API 作为高阶 API，负责处理初始化、日志记录、模型保存与恢复等细节，并提供许多其他功能，让你能够专注于模型本身。

二、使用预定义Estimator的步骤

1.创建一个或多个输入函数;

2.定义模型的Feature column;

3.实例化 Estimator,指定特征列和各种超参数;

4.在 Estimator 对象上调用一个或多个方法,传递适当的输入函数作为数据的来源;

三、使用预定义Estimator对鸢尾花进行分类

# Load the Iris dataset and hold out 30% of the samples as a test set;
# random_state=0 makes the split reproducible between runs.
iris = datasets.load_iris()
train_x, test_x, train_y, test_y = train_test_split(
    iris.data,
    iris.target,
    test_size=0.3,
    random_state=0,
)

1.创建输入函数

训练集输入函数

def train_input_fn(features, labels, batch_size):
    """Build the training input pipeline.

    Args:
        features: mapping of feature name -> column values (passed through dict()).
        labels: label values aligned with the feature columns.
        batch_size: number of examples per batch.

    Returns:
        A tf.data.Dataset of (features_dict, labels) batches, shuffled with a
        1000-element buffer and repeated indefinitely.
    """
    pairs = (dict(features), labels)
    dataset = tf.data.Dataset.from_tensor_slices(pairs)
    # Shuffle within a 1000-example buffer. repeat() without a count never ends,
    # so the caller must bound training (e.g. via `steps=` in Estimator.train).
    dataset = dataset.shuffle(1000)
    dataset = dataset.repeat()
    return dataset.batch(batch_size)

评估/预测输入函数（用于在测试集上评估模型，以及不带标签的预测）

def eval_input_fn(features, labels, batch_size):
    """Build the input pipeline for evaluation or prediction.

    Args:
        features: mapping of feature name -> column values (passed through dict()).
        labels: label values aligned with ``features``, or None when the
            pipeline feeds prediction (features only).
        batch_size: number of examples per batch; must not be None.

    Returns:
        A tf.data.Dataset yielding batches of ``features`` (when ``labels`` is
        None) or ``(features, labels)``. It is not repeated, so it is consumed
        in a single pass.

    Raises:
        ValueError: if ``batch_size`` is None.
    """
    # Validate up front with a real exception; a bare `assert` is stripped
    # when Python runs with -O, which would let a None batch size through.
    if batch_size is None:
        raise ValueError("batch_size must not be None")
    features = dict(features)
    inputs = features if labels is None else (features, labels)
    dataset = tf.data.Dataset.from_tensor_slices(inputs)
    # No repeat(): the dataset is iterated exactly once, so evaluate()/predict()
    # stop after one pass over the data.
    return dataset.batch(batch_size)

2.定义特征列

# Names of the four input features; one numeric feature column is created per
# name, keyed by that name (the same keys the input-function dicts use).
feature_names = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth']
my_feature_columns = [
    tf.feature_column.numeric_column(key=name) for name in feature_names
]

3.实例化Estimator

# Canned DNN classifier: two hidden layers of 10 units each, 3 output classes.
# No model_dir is passed, so the Estimator writes checkpoints to a temporary
# directory (the /tmp/... paths visible in the training logs).
classifier = tf.estimator.DNNClassifier(
    hidden_units=[10, 10],
    n_classes=3,
    feature_columns=my_feature_columns,
)

4.训练模型

# 输入函数需要字典类型的数据,key是特征名称,value是对应的数据
# The input function expects a dict: feature name -> that feature's column.
# Iterating over train_x.T yields the columns of train_x in order.
tr_x = dict(zip(feature_names, train_x.T))
classifier.train(
    input_fn=lambda: train_input_fn(tr_x, train_y, batch_size=16),
    # train_input_fn repeats forever, so `steps` is what ends training.
    steps=20,
)
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp545nieyg/model.ckpt.
INFO:tensorflow:loss = 40.29106, step = 1
INFO:tensorflow:Saving checkpoints for 20 into /tmp/tmp545nieyg/model.ckpt.
INFO:tensorflow:Loss for final step: 9.951749.

<tensorflow_estimator.python.estimator.canned.dnn.DNNClassifier at 0x7fef605258d0>

5.评估模型

# Same dict layout as for training: feature name -> test-set column.
te_x = dict(zip(feature_names, test_x.T))
eval_result = classifier.evaluate(
    input_fn=lambda: eval_input_fn(te_x, test_y, batch_size=32),
)
# eval_result holds metrics such as 'accuracy' (see the "Saving dict" log line).
print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2019-06-06T16:29:49Z
INFO:tensorflow:Graph was finalized.
WARNING:tensorflow:From /opt/conda/lib/python3.6/site-packages/tensorflow/python/training/saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from /tmp/tmp545nieyg/model.ckpt-20
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Finished evaluation at 2019-06-06-16:29:49
INFO:tensorflow:Saving dict for global step 20: accuracy = 0.9111111, average_loss = 0.49284205, global_step = 20, loss = 11.088946
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 20: /tmp/tmp545nieyg/model.ckpt-20

Test set accuracy: 0.911

6.使用模型进行预测

# Print each test-set prediction with its confidence next to the true label.
template = 'Prediction is "{}" ({:.1f}%), expected "{}"'
predictions = classifier.predict(
    input_fn=lambda: eval_input_fn(te_x, labels=None, batch_size=32),
)
for prediction, true_label in zip(predictions, test_y):
    # Each prediction dict exposes keys including 'logits', 'probabilities',
    # 'class_ids' and 'classes'.
    predicted_id = prediction['class_ids'][0]
    confidence = 100 * prediction['probabilities'][predicted_id]
    print(template.format(
        iris.target_names[predicted_id],
        confidence,
        iris.target_names[true_label],
    ))
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp545nieyg/model.ckpt-20
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
Prediction is "virginica" (62.7%), expected "virginica"
Prediction is "versicolor" (57.1%), expected "versicolor"
Prediction is "setosa" (87.4%), expected "setosa"
Prediction is "virginica" (60.0%), expected "virginica"
Prediction is "setosa" (81.1%), expected "setosa"
Prediction is "virginica" (69.2%), expected "virginica"
Prediction is "setosa" (81.4%), expected "setosa"
Prediction is "versicolor" (53.8%), expected "versicolor"
Prediction is "versicolor" (54.9%), expected "versicolor"
Prediction is "versicolor" (55.0%), expected "versicolor"
Prediction is "virginica" (61.0%), expected "virginica"
Prediction is "versicolor" (51.7%), expected "versicolor"
Prediction is "versicolor" (48.2%), expected "versicolor"
Prediction is "versicolor" (53.4%), expected "versicolor"
Prediction is "versicolor" (47.4%), expected "versicolor"
Prediction is "setosa" (83.8%), expected "setosa"
Prediction is "versicolor" (48.3%), expected "versicolor"
Prediction is "virginica" (46.3%), expected "versicolor"
Prediction is "setosa" (72.9%), expected "setosa"
Prediction is "setosa" (84.1%), expected "setosa"
Prediction is "virginica" (58.9%), expected "virginica"
Prediction is "virginica" (46.6%), expected "versicolor"
Prediction is "setosa" (81.4%), expected "setosa"
Prediction is "setosa" (76.4%), expected "setosa"
Prediction is "virginica" (49.4%), expected "virginica"
Prediction is "setosa" (83.5%), expected "setosa"
Prediction is "setosa" (84.3%), expected "setosa"
Prediction is "versicolor" (52.9%), expected "versicolor"
Prediction is "versicolor" (50.7%), expected "versicolor"
Prediction is "setosa" (79.1%), expected "setosa"
Prediction is "virginica" (56.9%), expected "virginica"
Prediction is "virginica" (48.5%), expected "versicolor"
Prediction is "setosa" (81.4%), expected "setosa"
Prediction is "virginica" (51.0%), expected "virginica"
Prediction is "virginica" (62.4%), expected "virginica"
Prediction is "versicolor" (45.7%), expected "versicolor"
Prediction is "setosa" (82.2%), expected "setosa"
Prediction is "virginica" (55.3%), expected "versicolor"
Prediction is "versicolor" (49.6%), expected "versicolor"
Prediction is "versicolor" (52.7%), expected "versicolor"
Prediction is "virginica" (55.1%), expected "virginica"
Prediction is "setosa" (79.2%), expected "setosa"
Prediction is "virginica" (48.7%), expected "virginica"
Prediction is "setosa" (78.6%), expected "setosa"
Prediction is "setosa" (83.0%), expected "setosa"
Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐