【Tensorflow2】数据读取
文章目录数据读取创建一个 Dataset转换(Transform)到新的 DatasetMapBatchShuffleRepeat参考数据读取使用 tf.data.Dataset API创建一个 Dataset首先从内存中读取数据import tensorflow as tfdataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])如果想输出 d
·
数据读取
使用 tf.data.Dataset
API
创建一个 Dataset
首先从内存中读取数据
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
如果想输出 dataset
中的数据
for element in dataset:
print(element)
> tf.Tensor(1, shape=(), dtype=int32)
> tf.Tensor(2, shape=(), dtype=int32)
> tf.Tensor(3, shape=(), dtype=int32)
如果只想输出数值
for element in dataset:
print(element.numpy())
> 1
> 2
> 3
再以读取 mnist
数据集为例
train, test = tf.keras.datasets.mnist.load_data()
images, labels = train
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
转换(Transform)到新的 Dataset
使用 Dataset.xx()
Map
map(
map_func, num_parallel_calls=None, deterministic=None
)
对每个元素+1
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset = dataset.map(lambda x: x + 1)
print(list(dataset.as_numpy_iterator()))
> [2, 3, 4]
Batch
batch(
batch_size, drop_remainder=False
)
drop_remainder=True
会将最后一个不整的 batch 删掉
Shuffle
shuffle(
buffer_size, seed=None, reshuffle_each_iteration=None
)
buffer_size 需要大于等于数据大小
reshuffle_each_iteration=True 让每个epoch都不一样
Repeat
repeat(
count=None
)
count 是重复的次数,空着会无限循环
以上四种操作都用到的写法
dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(1000).batch(32).repeat(2)
参考
更多推荐
所有评论(0)