Tensorflow-data
Tesorflow官方提供了一些数据集来供我们使用。其地址为Tensorflow-data,我们可以在这里查看数据集的详细细节Tensorflow-data-catalog。
其API地址为API。
简单使用
Tensorflow-data有137个API,一般不会全部使用。我们需要了解的第一个API是:
tfds.load
:从本地或网络载入数据集,并生成tf.data.Dataset
对象。
其基本形式为:
1 | tfds.load( |
tfds.load
实际上是一个对内建三个方法的语法糖:
新建一个
DatasetBuilder
:1
builder = tfds.builder(name, data_dir=data_dir, **builder_kwargs)
为这个build下载或准备(本地则不需要下载)数据集:
1
builder.download_and_prepare(**download_and_prepare_kwargs)
载入
tf.data.Dataset
对象:1
2
3
4
5
6
7
8ds = builder.as_dataset(
split=split,
as_supervised=as_supervised,
shuffle_files=shuffle_files,
read_config=read_config,
decoders=decoders,
**as_dataset_kwargs,
)
说完我们再来看每个参数解析:
名字 | 类型 | 说明 | 必须 | 默认值 |
---|---|---|---|---|
name | String | 数据集的名称,我们可以在Tensorflow-data-catalog查看数据集的名字。注意这里的名字是蛇形命名法,即:a_b_c的下划线分割的形式。 | 是 | / |
split | String/tuple | 如何分割数据集,主要有’train’, ‘test’两个是符串以及其tuple构成。如: 1: ‘train’, ‘test’,分别表示只载入训练数据,或只载入测试数据。 2: [‘train’, ‘test’],表示载入训练集和测试集,此时第一个参数会返回一个list,我们可以用(train_data, test_data)来解构list。 3: ‘train[:120]’, ‘train[:75%]’, ‘test[25%:100%]’, ‘train[:4shard]’:如同python的list截取一样的格式,只不过这里可以是绝对值、者百分比值或者分片。片是tensorflow中的一个概念,数据集在下载时即定义好了一共多少个片,我们可以用 info.splits['train'].num_shards 来查看总片数。4: ‘train+test’, ‘train[:25%]+test’: 这种格式会把训练集和测试集合并在一起返回,其中每一部分都可以用列表截断。 值得注意的是:每个数据集支持的split的格式不同,需要我们具体来使用。 |
否 | 数据集定义的划分模式 |
data_dir | String | 写入/读取数据集的地址,我们可以更改这个地址。 | 否 | win下:C:\Users\[user]\tensorflow_datasets |
batch_size | Int | 设置一个batch的大小,这里的batch主要是针对较大的数据集,不可能将所有的数据全部读进内存,我们会一次读取一个batch进内存运算。 | 否 | 不划分batch |
shuffle_files | Boolean | 是否需要打乱输入数据 | 否 | False |
download | Boolean | 是否是从远程下载 | 否 | True |
as_supervised | Boolean | 是否是监督模式,如果是True的话, tf.data.Dataset |
||
decoders | ||||
read_config | ||||
with_info | ||||
builder_kwargs | ||||
download_and_prepare_kwargs | ||||
as_dataset_kwargs | ||||
try_gcs |