将一张图片转换为 TFRecord 文件

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#author: zhenghan time:2020/7/7

import tensorflow as tf


def write_test(input, output):
    """Serialize a single image into a TFRecord file.

    The file at ``input`` is read and decoded with ``tf.image.decode_jpeg``.
    Its raw pixel bytes, its shape and a fixed name (``b'cat'``) are packed
    into one ``tf.train.Example`` and written to ``output``.

    Args:
        input: Path of the image file to read.
        output: Path of the TFRecord file to create.
    """
    # Build the read-and-decode graph; it only runs inside sess.run() below.
    # NOTE(review): the script entry point passes a ``.png`` path while this
    # uses decode_jpeg -- confirm that is intended for the TF version in use.
    image = tf.read_file(input)
    image = tf.image.decode_jpeg(image)

    # A TFRecordWriter is required to write into a TFRecord file. Using it as
    # a context manager closes the file even when decoding or serializing
    # raises, instead of only on the success path.
    with tf.python_io.TFRecordWriter(output) as writer:
        with tf.Session() as sess:
            image = sess.run(image)
            shape = image.shape
            # Convert the decoded pixels to raw bytes. ``tobytes`` replaces
            # the deprecated ``tostring`` alias (removed in NumPy 2.0).
            image_data = image.tobytes()
            print(type(image))
            print(len(image_data))
            name = bytes('cat', encoding='utf-8')
            print(type(name))
            # Create the Example and fill in one Feature per field.
            example = tf.train.Example(features=tf.train.Features(feature={
                'name': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[name])),
                'shape': tf.train.Feature(
                    int64_list=tf.train.Int64List(
                        value=[shape[0], shape[1], shape[2]])),
                'data': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[image_data])),
            }))
            # Serialize the Example to a byte string and write it out.
            writer.write(example.SerializeToString())


if __name__ == '__main__':
    # Script entry point: convert the sample image into a TFRecord file.
    input_photo, output_file = 'img1.png', 'img1.tfrecord'
    write_test(input_photo, output_file)

将 TFRecord 文件转换为图片

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt


def _parse_record(example_photo):
    """Parse one serialized Example into a dict of feature tensors.

    Args:
        example_photo: A serialized ``tf.train.Example``.

    Returns:
        The result of ``tf.parse_single_example`` for the keys ``'name'``
        (scalar string), ``'shape'`` (three int64 values) and ``'data'``
        (scalar string).
    """
    feature_spec = {
        'name': tf.FixedLenFeature((), tf.string),
        'shape': tf.FixedLenFeature([3], tf.int64),
        'data': tf.FixedLenFeature((), tf.string),
    }
    return tf.parse_single_example(example_photo, features=feature_spec)


def read_test(input_file):
    """Read the first record of a TFRecord file, display it and re-encode it.

    The record is parsed with ``_parse_record``. Its raw bytes are reshaped
    into an image array using the stored shape, shown with matplotlib, then
    JPEG-encoded and written to ``cat_encode.png``.

    Args:
        input_file: Path of the TFRecord file to read.
    """
    # Read the TFRecord file through the dataset API.
    dataset = tf.data.TFRecordDataset(input_file)
    dataset = dataset.map(_parse_record)
    iterator = dataset.make_one_shot_iterator()

    with tf.Session() as sess:
        features = sess.run(iterator.get_next())
        # The name comes back as bytes; it is decoded to str here but not
        # used further in this function.
        name = features['name']
        name = name.decode()
        img_data = features['data']
        shape = features['shape']
        print("==============")
        print(type(shape))
        print(len(img_data))

        # Load the raw pixel data from the bytes object and reshape it; the
        # result is an ndarray. ``frombuffer`` replaces the deprecated
        # binary mode of ``fromstring`` and yields a read-only view, which
        # is enough because the array is only read below.
        img_data = np.frombuffer(img_data, dtype=np.uint8)
        image_data = np.reshape(img_data, shape)

        plt.figure()
        # Display the image.
        plt.imshow(image_data)
        plt.show()

        # Re-encode the data as JPEG and save it. The context manager closes
        # the output file, which the bare GFile(...).write() call never did.
        # NOTE(review): the content is JPEG although the file name ends in
        # ``.png``; the path is kept unchanged for compatibility.
        img = tf.image.encode_jpeg(image_data)
        with tf.gfile.GFile('cat_encode.png', 'wb') as out_file:
            out_file.write(sess.run(img))


if __name__ == '__main__':
    # Script entry point: read the sample TFRecord file back.
    read_test(input_file="img1.tfrecord")
Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐