5.9.3. Dataset API

5.9.3. Dataset API

仓颉 TensorBoost 支持加载 AI 领域常用的数据集,用户可以直接调用 dataset 包中的类实现数据集的加载。目前支持的数据集如下表所示:

数据集数据集类数据集简介
MNISTMnistDatasetMNIST 是一个大型手写数字图像数据集,拥有 60,000 张训练图像和 10,000 张测试图像,常用于训练各种图像处理系统。
CIFAR-10Cifar10DatasetCIFAR-10 是一个微小图像数据集,包含 10 种类别下的 60,000 张 32×32 大小彩色图像,平均每种类别 6,000 张,其中 5,000 张为训练集,1,000 张为测试集。

仓颉 TensorBoost 也支持加载多种存储格式下的数据集,用户可以直接调用 dataset 包中的类实现数据集的加载,目前支持的数据格式如下表所示:

数据集数据集类数据集简介
TFRecordTFRecordDatasetTFRecord 是 TensorFlow 定义的一种二进制数据文件格式。
MindRecordMindDataDatasetMindRecord 是 mindspore 定义的一种二进制数据文件格式。
ImageFolderImageFolderDatasetImageFolder 是图片文件格式。

数据集

Dataset 基类

仓颉 TensorBoost 提供数据加载与处理的基类 Dataset,Dataset 提供了统一的数据操作与管理方式,不同格式的数据集均继承自 Dataset。直接创建 Dataset 的对象是无意义的,加载数据集需要创建 Dataset 相应的子类。

public open class Dataset {
    public init()
    public init(epoch: Int32, columnNames: Array<String>)
    public func repeat(count: Int32)
    public func batch(batchSize: Int32, dropRemainder: Bool)
    public func shuffle(bufferSize: Int32)
    public func datasetMap(dsOpHandleList: Array<MSDsOperationHandle>, name: String)
    public func getDatasetSize()
    public func getNext(params: Array<Tensor>)
}

初始化参数列表:

名称含义
epoch类型为 Int32,数据集要循环的 epoch 数
columnNames类型为 String 数组,数据集文件的列名

设置重复次数

public func repeat(count: Int32): Unit

设置数据集重复的次数,达到扩充数据量的目的。

参数列表:

名称含义
count类型为 Int32,数据集重复的次数

【注意】 repeat 和 batch 操作的顺序会影响训练 batch 的数量,建议将 repeat 置于 batch 之后。

设置 batch size

public func batch(batchSize: Int32, dropRemainder: Bool): Unit

参数列表:

名称含义
batchSize类型为 Int32,数据集读取时 mini-batch 的大小
dropRemainder类型为 Bool,数据集数量不是 batchSize 的整数时,是否抛弃最后一组数据。

获取数据集大小

public func getDatasetSize(): Int64

输出:

名称含义
datasetSize数据集样本数,类型为 Int64

设置数据乱序读取

public func shuffle(bufferSize: Int32): Unit

参数列表:

名称含义
bufferSize类型为 Int32,数据集乱序读取的数量,当 bufferSize 等于数据集的行数时,整个数据集读取全部为乱序

设置数据集数据增强操作

public func datasetMap(dsOpHandleList: Array<MSDsOperationHandle>, name: String): Unit

参数列表:

名称含义
dsOpHandleList数据增强操作列表,类型为 MSDsOperationHandle 的数组,具体的操作见数据增强
name数据增强操作所作用的列名,类型为 String

获取下一批数据

public func getNext(params: Array<Tensor>)

参数列表:

名称含义
params网络参数列表,类型为 Tensor 的数组,需要和数据集中保存的数据类别个数相同, 并且 shape 和 dtype 也要相同

子类实现

MnistDataset

public class MnistDataset <: Dataset {
    public init(dataPath: String, sampler!: BuildInSampler = RandomSampler(), epoch!: Int32 = 1)
}

参数列表:

名称含义
dataPath类型为 String,数据集文件路径
sampler类型为 BuildInSampler,采样器,可指定 RandomSampler 随机采样或 SequentialSampler 顺序采样
epoch类型为 Int32,数据集要循环的 epoch 数

Cifar10Dataset

public class Cifar10Dataset <: Dataset {
    public init(dataPath: String, epoch!: Int32 = 1, usage!: String = "all")
}

参数列表:

名称含义
dataPath类型为 String,数据集文件路径
epoch类型为 Int32,数据集要循环的 epoch 数
usage类型为 String,数据集的用法,取值包括 train、test 和 all,默认值为 all

TFRecordDataset

public class TFRecordDataset <: Dataset {
    public init(dataPath: String, schemaPath: String, columnsList: Array<String>, shuffle!: Bool = true, epoch!: Int32 = 1)
}

参数列表:

名称含义
dataPath类型为 String,数据集文件路径
schemaPath类型为 String,TFRecord 的 shema 文件路径
columnsList类型为 String 数组,数据集文件的列名
shuffle类型为 Bool,是否将打乱读取顺序
epoch类型为 Int32,数据集要循环的 epoch 数

MindDataDataset

public class MindDataDataset <: Dataset {
    public init(dataPath: String, columnsList: Array<String>, sampler!: BuildInSampler = RandomSampler(), epoch!: Int32 = 1)
}

参数列表:

名称含义
dataPath类型为 String,数据集文件路径
columnsList类型为 String 数组,数据集文件的列名
sampler类型为 BuildInSampler,采样器,可指定 RandomSampler 随机采样或 SequentialSampler 顺序采样
epoch类型为 Int32,数据集要循环的 epoch 数

ImageFolderDataset

public class ImageFolderDataset <: Dataset {
    public init(dataPath: String, decode: Bool, numShards: Int32, shardId: Int32, shuffle: Bool, epoch!: Int32 = 1)
}

参数列表:

名称含义
dataPath类型为 String,数据集文件路径
decode类型为 Bool,读取后是否对图片进行解码
numShards类型为 Int32,数据分片的数量
shardId类型为 Int32,数据分片的 ID,必须在 numShards 的范围内
shuffle类型为 Bool,是否将打乱读取顺序
epoch类型为 Int32,数据集要循环的 epoch 数

示例代码

如下代码展示了如何读取 MNIST 数据集,并将数据集进行 shuffle 处理,然后将样本两两组成一个批次:

from CangjieTB import dataset.*
from CangjieTB import common.*
from CangjieTB import ops.*

main(): Int64 {
    let dataPath: String = "./data/mnist/train"
    let ds = MnistDataset(dataPath)
    let batchSize: Int32 = 2
    let bufferSize: Int32 = 1000
    var step: Int64 = 0
    
    ds.shuffle(bufferSize)
    
    ds.batch(batchSize, true)
    
    var rescale = rescale(1.0 / 255.0, 0.0)
    ds.datasetMap([rescale], "image")
    
    var input: Tensor = parameter(zerosTensor([Int64(batchSize), 28, 28], dtype: FLOAT32), "data")
    var label: Tensor = parameter(zerosTensor([Int64(batchSize)], dtype: INT32), "label")
    while (ds.getNext([input, label]) && step < 5) {
        print("---------------\n")
        print("input", input)
        print("label", label)
        step += 1
    }
    return 0
}

数据增强

仓颉 TensorBoost 目前支持的常用数据增强方法如下表所示:

vision

方法说明
randomCropDecodeResize图像裁剪。
randomHorizontalFlip按照指定概率对图像进行水平翻转。
randomColorAdjust调整图像的亮度,对比度,饱和度和色相。
normalize根据均值和方差对图像正态化。
hwc2chw转换图像的数据排列方式。
rescale可以将图像数据映射到固定的范围。
resize图像缩放。
randomCrop图像裁剪。
centerCrop从中心进行裁剪。
randomResizedCrop随机长宽比裁剪。

transforms

函数介绍

randomCropDecodeResize

public func randomCropDecodeResize(size: Array<Int32>, scale: Array<Float32>, ratio: Array<Float32>): MSDsOperationHandle

一个处理 JPEG 图像的高效函数。在随机位置裁剪输入的图像,将裁剪后的图像解码成 RGB 格式,并调整解码图像的尺寸大小。

参数列表:

名称含义
size输出图像的大小。
scale要裁剪的原始范围 [最小值,最大值](默认值 = (0.08, 1.0))。
ratio要裁剪的纵横比范围 [最小值,最大值](默认值 = (3 / 4, 4 / 3))。

randomHorizontalFlip

public func randomHorizontalFlip(prob: Float32): MSDsOperationHandle

对输入图像进行随机水平翻转。

参数说明 参数列表:

randomColorAdjust

public func randomColorAdjust(brightness: Array<Float32>, contrast: Array<Float32>, saturation: Array<Float32>, hue: Array<Float32>): MSDsOperationHandle

调整图像的亮度,对比度,饱和度和色相。

参数列表:

名称含义
brightness亮度。
contrast对比度。
saturation饱和度。
hue色相。

normalize

public func normalize(meanIn: Array<Float32>, stdIn: Array<Float32>): MSDsOperationHandle

根据提供的均值和标准差对输入图像进行归一化。

参数列表:

名称含义
meanIn调整后图像的均值。
stdIn调整后图像的方差。

hwc2chw

public func hwc2chw(): MSDsOperationHandle

数据操作,转换图像的数据排列方式,数据的形状从 高 $\times$ 宽 $\times$ 通道 (HWC) 变为 通道 $\times$ 高 $\times$ 宽 (CHW)。

rescale

public func rescale(rescale: Float32, shift: Float32): MSDsOperationHandle

数据操作,可以将数据映射到固定的范围。

参数列表:

名称含义
rescale比例因子。
shift偏移因子。

resize

public func resize(size: Array<Int32>, interpolation!: Int32 = 0): MSDsOperationHandle

对输入图像进行缩放。

参数列表:

名称含义
size缩放的目标大小。
interpolation缩放时采用的插值方式。

randomCrop

public func randomCrop(size: Array<Int32>, padding: Array<Int32>, padIfNeeded!: Bool = true, fillValue!: Array<UInt8> = [0, 0, 0], paddingMode!: Int32 = 0)

对输入图像随机位置进行指定大小的裁剪。

参数列表:

名称含义
size输出图像的大小。
padding填充的像素数量。
padIfNeeded原图小于裁剪尺寸时,是否需要填充。
fillValue在常量填充模式时使用的填充值。
paddingMode填充模式。

centerCrop

public func centerCrop(size: Array<Int32>): MSDsOperationHandle

对输入图像从中心进行裁剪。

参数列表:

randomResizedCrop

public func randomResizedCrop(size: Array<Int32>, scale: Array<Float32>, ratio: Array<Float32>): MSDsOperationHandle

对输入图像进行随机裁剪,最后将图像调整到设定好的尺寸。

参数列表:

名称含义
size输出图像的大小。
scale要裁剪的原始范围 [最小值,最大值]。
ratio要裁剪的纵横比范围 [最小值,最大值]。

typeCast

public func typeCast(dataType: String): MSDsOperationHandle

数据操作,转换数据类型。

参数列表:

示例代码

以下代码展示了对数据集进行 resizerescalerandomCrop 操作,然后通过 datasetMap 设置数据处理的管道,生成新数据。


from CangjieTB import dataset.*
from CangjieTB import common.*
from CangjieTB import ops.*

main(): Int64 {
    let dataPath: String = "./data/mnist/train"
    var step: Int64 = 0
    
    let mnistDs = MnistDataset(dataPath)
    
    var resize = resize(Array<Int32>([16, 16]))
    var rescale = rescale(1.0 / 255.0, 0.0)
    var randomCropOp = randomCrop(Array<Int32>([8, 8]), Array<Int32>([0, 0, 0, 0]))
    
    mnistDs.datasetMap([resize, rescale, randomCropOp], "image")
    
    let input = parameter(initialize(Array<Int64>([8, 8]), initType: InitType.ZERO, dtype: FLOAT32), "data")
    let label: Tensor = parameter(zerosTensor(Array<Int64>([1]), dtype: INT32), "label")
    while (mnistDs.getNext([input, label]) && step < 5) {
        print("---------------\n")
        print("input: ", input)
        print("label: ", label)
        step += 1
    }
    return 0
}

采样器

对于数据集对象, 可以通过设置采样器来指定随机还是顺序采样.

采样器类采样器简介
RandomSampler给定数据集对象后,按随机顺序进行采样。
SequentialSampler给定数据集对象后,按顺序从前到后进行采样。

采样器类

RandomSampler

public class RandomSampler <: BuildInSampler {
    public init(replacement!: Bool = false, numSamples!: Int64 = 0)
}

参数列表:

名称含义
replacement类型为 Bool,是否将样本放回。
numSamples类型为 Int64,采样个数。

SequentialSampler

public class SequentialSampler <: BuildInSampler {
    public init(startIndex!: Int64 = 0, numSamples!: Int64 = 0)
}

参数列表:

名称含义
startIndex类型为 Int64,采样起点。
numSamples类型为 Int64,采样个数。

示例代码

以下代码展示了对数据集的前 6 个数据进行采样并打乱顺序读取:

from CangjieTB import dataset.*
from CangjieTB import common.*
from CangjieTB import ops.*

main(): Int64 {
    var samplerRD = RandomSampler(numSamples: 6)
    let dataPath: String = "./data/mnist/train"
    let ds = MnistDataset(dataPath, sampler: samplerRD)
    
    var rescale = rescale(1.0 / 255.0, 0.0)
    ds.datasetMap([rescale], "image")
    
    var input: Tensor = parameter(zerosTensor(Array<Int64>([1, 28, 28]), dtype: FLOAT32), "data")
    var label: Tensor = parameter(zerosTensor(Array<Int64>([1]), dtype: INT32), "label")
    while (ds.getNext([input, label])) {
        print("---------------\n")
        print("input", input)
        print("label", label)
    }
    return 0
}
Rate this post
发表回复 0