数据读取

使用 tf.data.Dataset API

创建一个 Dataset

首先从内存中读取数据

import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])

如果想输出 dataset 中的数据

for element in dataset:
  print(element)

> tf.Tensor(1, shape=(), dtype=int32)
> tf.Tensor(2, shape=(), dtype=int32)
> tf.Tensor(3, shape=(), dtype=int32)

如果只想输出数值

for element in dataset:
  print(element.numpy())

> 1
> 2
> 3

再以读取 mnist 数据集为例

train, test = tf.keras.datasets.mnist.load_data()
images, labels = train
dataset = tf.data.Dataset.from_tensor_slices((images, labels))

转换(Transform)到新的 Dataset

使用 Dataset.xx()

Map

map(
    map_func, num_parallel_calls=None, deterministic=None
)

对每个元素+1

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset = dataset.map(lambda x: x + 1)
print(list(dataset.as_numpy_iterator()))

> [2, 3, 4]

Batch

batch(
    batch_size, drop_remainder=False
)

drop_remainder=True 会将最后一个不整的 batch 删掉

Shuffle

shuffle(
    buffer_size, seed=None, reshuffle_each_iteration=None
)

buffer_size 需要大于等于数据大小

reshuffle_each_iteration=True 让每个epoch都不一样

Repeat

repeat(
    count=None
)

count 是重复的次数,空着会无限循环

以上四种操作都用到的写法

dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(1000).batch(32).repeat(2)

参考

tf.data.Dataset

tf.data: Build TensorFlow input pipelines

Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐