5.9.3. Dataset API
5.9.3. Dataset API
仓颉 TensorBoost 支持加载 AI 领域常用的数据集,用户可以直接调用 dataset 包中的类实现数据集的加载。目前支持的数据集如下表所示:
数据集 | 数据集类 | 数据集简介 |
---|---|---|
MNIST | MnistDataset | MNIST 是一个大型手写数字图像数据集,拥有 60,000 张训练图像和 10,000 张测试图像,常用于训练各种图像处理系统。 |
CIFAR-10 | Cifar10Dataset | CIFAR-10 是一个微小图像数据集,包含 10 种类别下的 60,000 张 32×32 大小彩色图像,平均每种类别 6,000 张,其中 5,000 张为训练集,1,000 张为测试集。 |
仓颉 TensorBoost 也支持加载多种存储格式下的数据集,用户可以直接调用 dataset 包中的类实现数据集的加载,目前支持的数据格式如下表所示:
数据集 | 数据集类 | 数据集简介 |
---|---|---|
TFRecord | TFRecordDataset | TFRecord 是 TensorFlow 定义的一种二进制数据文件格式。 |
MindRecord | MindDataDataset | MindRecord 是 mindspore 定义的一种二进制数据文件格式。 |
ImageFolder | ImageFolderDataset | ImageFolder 是图片文件格式。 |
数据集
Dataset 基类
仓颉 TensorBoost 提供数据加载与处理的基类 Dataset,Dataset 提供了统一的数据操作与管理方式,不同格式的数据集均继承自 Dataset。直接创建 Dataset 的对象是无意义的,加载数据集需要创建 Dataset 相应的子类。
public open class Dataset {
public init()
public init(epoch: Int32, columnNames: Array<String>)
public func repeat(count: Int32)
public func batch(batchSize: Int32, dropRemainder: Bool)
public func shuffle(bufferSize: Int32)
public func datasetMap(dsOpHandleList: Array<MSDsOperationHandle>, name: String)
public func getDatasetSize()
public func getNext(params: Array<Tensor>)
}
初始化参数列表:
名称 | 含义 |
---|---|
epoch | 类型为 Int32,数据集要循环的 epoch 数 |
columnNames | 类型为 String 数组,数据集文件的列名 |
设置重复次数
public func repeat(count: Int32): Unit
设置数据集重复的次数,达到扩充数据量的目的。
参数列表:
名称 | 含义 |
---|---|
count | 类型为 Int32,数据集重复的次数 |
【注意】 repeat 和 batch 操作的顺序会影响训练 batch 的数量,建议将 repeat 置于 batch 之后。
设置 batch size
public func batch(batchSize: Int32, dropRemainder: Bool): Unit
参数列表:
名称 | 含义 |
---|---|
batchSize | 类型为 Int32,数据集读取时 mini-batch 的大小 |
dropRemainder | 类型为 Bool,数据集数量不是 batchSize 的整数时,是否抛弃最后一组数据。 |
获取数据集大小
public func getDatasetSize(): Int64
输出:
名称 | 含义 |
---|---|
datasetSize | 数据集样本数,类型为 Int64 |
设置数据乱序读取
public func shuffle(bufferSize: Int32): Unit
参数列表:
名称 | 含义 |
---|---|
bufferSize | 类型为 Int32,数据集乱序读取的数量,当 bufferSize 等于数据集的行数时,整个数据集读取全部为乱序 |
设置数据集数据增强操作
public func datasetMap(dsOpHandleList: Array<MSDsOperationHandle>, name: String): Unit
参数列表:
名称 | 含义 |
---|---|
dsOpHandleList | 数据增强操作列表,类型为 MSDsOperationHandle 的数组,具体的操作见数据增强 |
name | 数据增强操作所作用的列名,类型为 String |
获取下一批数据
public func getNext(params: Array<Tensor>)
参数列表:
名称 | 含义 |
---|---|
params | 网络参数列表,类型为 Tensor 的数组,需要和数据集中保存的数据类别个数相同, 并且 shape 和 dtype 也要相同 |
子类实现
MnistDataset
public class MnistDataset <: Dataset {
public init(dataPath: String, sampler!: BuildInSampler = RandomSampler(), epoch!: Int32 = 1)
}
参数列表:
名称 | 含义 |
---|---|
dataPath | 类型为 String,数据集文件路径 |
sampler | 类型为 BuildInSampler,采样器,可指定 RandomSampler 随机采样或 SequentialSampler 顺序采样 |
epoch | 类型为 Int32,数据集要循环的 epoch 数 |
Cifar10Dataset
public class Cifar10Dataset <: Dataset {
public init(dataPath: String, epoch!: Int32 = 1, usage!: String = "all")
}
参数列表:
名称 | 含义 |
---|---|
dataPath | 类型为 String,数据集文件路径 |
epoch | 类型为 Int32,数据集要循环的 epoch 数 |
usage | 类型为 String,数据集的用法,取值包括 train、test 和 all,默认值为 all |
TFRecordDataset
public class TFRecordDataset <: Dataset {
public init(dataPath: String, schemaPath: String, columnsList: Array<String>, shuffle!: Bool = true, epoch!: Int32 = 1)
}
参数列表:
名称 | 含义 |
---|---|
dataPath | 类型为 String,数据集文件路径 |
schemaPath | 类型为 String,TFRecord 的 shema 文件路径 |
columnsList | 类型为 String 数组,数据集文件的列名 |
shuffle | 类型为 Bool,是否将打乱读取顺序 |
epoch | 类型为 Int32,数据集要循环的 epoch 数 |
MindDataDataset
public class MindDataDataset <: Dataset {
public init(dataPath: String, columnsList: Array<String>, sampler!: BuildInSampler = RandomSampler(), epoch!: Int32 = 1)
}
参数列表:
名称 | 含义 |
---|---|
dataPath | 类型为 String,数据集文件路径 |
columnsList | 类型为 String 数组,数据集文件的列名 |
sampler | 类型为 BuildInSampler,采样器,可指定 RandomSampler 随机采样或 SequentialSampler 顺序采样 |
epoch | 类型为 Int32,数据集要循环的 epoch 数 |
ImageFolderDataset
public class ImageFolderDataset <: Dataset {
public init(dataPath: String, decode: Bool, numShards: Int32, shardId: Int32, shuffle: Bool, epoch!: Int32 = 1)
}
参数列表:
名称 | 含义 |
---|---|
dataPath | 类型为 String,数据集文件路径 |
decode | 类型为 Bool,读取后是否对图片进行解码 |
numShards | 类型为 Int32,数据分片的数量 |
shardId | 类型为 Int32,数据分片的 ID,必须在 numShards 的范围内 |
shuffle | 类型为 Bool,是否将打乱读取顺序 |
epoch | 类型为 Int32,数据集要循环的 epoch 数 |
示例代码
如下代码展示了如何读取 MNIST 数据集,并将数据集进行 shuffle 处理,然后将样本两两组成一个批次:
from CangjieTB import dataset.*
from CangjieTB import common.*
from CangjieTB import ops.*
main(): Int64 {
let dataPath: String = "./data/mnist/train"
let ds = MnistDataset(dataPath)
let batchSize: Int32 = 2
let bufferSize: Int32 = 1000
var step: Int64 = 0
ds.shuffle(bufferSize)
ds.batch(batchSize, true)
var rescale = rescale(1.0 / 255.0, 0.0)
ds.datasetMap([rescale], "image")
var input: Tensor = parameter(zerosTensor([Int64(batchSize), 28, 28], dtype: FLOAT32), "data")
var label: Tensor = parameter(zerosTensor([Int64(batchSize)], dtype: INT32), "label")
while (ds.getNext([input, label]) && step < 5) {
print("---------------\n")
print("input", input)
print("label", label)
step += 1
}
return 0
}
数据增强
仓颉 TensorBoost 目前支持的常用数据增强方法如下表所示:
vision
方法 | 说明 |
---|---|
randomCropDecodeResize | 图像裁剪。 |
randomHorizontalFlip | 按照指定概率对图像进行水平翻转。 |
randomColorAdjust | 调整图像的亮度,对比度,饱和度和色相。 |
normalize | 根据均值和方差对图像正态化。 |
hwc2chw | 转换图像的数据排列方式。 |
rescale | 可以将图像数据映射到固定的范围。 |
resize | 图像缩放。 |
randomCrop | 图像裁剪。 |
centerCrop | 从中心进行裁剪。 |
randomResizedCrop | 随机长宽比裁剪。 |
transforms
函数介绍
randomCropDecodeResize
public func randomCropDecodeResize(size: Array<Int32>, scale: Array<Float32>, ratio: Array<Float32>): MSDsOperationHandle
一个处理 JPEG 图像的高效函数。在随机位置裁剪输入的图像,将裁剪后的图像解码成 RGB 格式,并调整解码图像的尺寸大小。
参数列表:
名称 | 含义 |
---|---|
size | 输出图像的大小。 |
scale | 要裁剪的原始范围 [最小值,最大值](默认值 = (0.08, 1.0))。 |
ratio | 要裁剪的纵横比范围 [最小值,最大值](默认值 = (3 / 4, 4 / 3))。 |
randomHorizontalFlip
public func randomHorizontalFlip(prob: Float32): MSDsOperationHandle
对输入图像进行随机水平翻转。
参数说明 参数列表:
randomColorAdjust
public func randomColorAdjust(brightness: Array<Float32>, contrast: Array<Float32>, saturation: Array<Float32>, hue: Array<Float32>): MSDsOperationHandle
调整图像的亮度,对比度,饱和度和色相。
参数列表:
名称 | 含义 |
---|---|
brightness | 亮度。 |
contrast | 对比度。 |
saturation | 饱和度。 |
hue | 色相。 |
normalize
public func normalize(meanIn: Array<Float32>, stdIn: Array<Float32>): MSDsOperationHandle
根据提供的均值和标准差对输入图像进行归一化。
参数列表:
名称 | 含义 |
---|---|
meanIn | 调整后图像的均值。 |
stdIn | 调整后图像的方差。 |
hwc2chw
public func hwc2chw(): MSDsOperationHandle
数据操作,转换图像的数据排列方式,数据的形状从 高 $\times$ 宽 $\times$ 通道 (HWC) 变为 通道 $\times$ 高 $\times$ 宽 (CHW)。
rescale
public func rescale(rescale: Float32, shift: Float32): MSDsOperationHandle
数据操作,可以将数据映射到固定的范围。
参数列表:
名称 | 含义 |
---|---|
rescale | 比例因子。 |
shift | 偏移因子。 |
resize
public func resize(size: Array<Int32>, interpolation!: Int32 = 0): MSDsOperationHandle
对输入图像进行缩放。
参数列表:
名称 | 含义 |
---|---|
size | 缩放的目标大小。 |
interpolation | 缩放时采用的插值方式。 |
randomCrop
public func randomCrop(size: Array<Int32>, padding: Array<Int32>, padIfNeeded!: Bool = true, fillValue!: Array<UInt8> = [0, 0, 0], paddingMode!: Int32 = 0)
对输入图像随机位置进行指定大小的裁剪。
参数列表:
名称 | 含义 |
---|---|
size | 输出图像的大小。 |
padding | 填充的像素数量。 |
padIfNeeded | 原图小于裁剪尺寸时,是否需要填充。 |
fillValue | 在常量填充模式时使用的填充值。 |
paddingMode | 填充模式。 |
centerCrop
public func centerCrop(size: Array<Int32>): MSDsOperationHandle
对输入图像从中心进行裁剪。
参数列表:
randomResizedCrop
public func randomResizedCrop(size: Array<Int32>, scale: Array<Float32>, ratio: Array<Float32>): MSDsOperationHandle
对输入图像进行随机裁剪,最后将图像调整到设定好的尺寸。
参数列表:
名称 | 含义 |
---|---|
size | 输出图像的大小。 |
scale | 要裁剪的原始范围 [最小值,最大值]。 |
ratio | 要裁剪的纵横比范围 [最小值,最大值]。 |
typeCast
public func typeCast(dataType: String): MSDsOperationHandle
数据操作,转换数据类型。
参数列表:
示例代码
以下代码展示了对数据集进行 resize
,rescale
和 randomCrop
操作,然后通过 datasetMap
设置数据处理的管道,生成新数据。
from CangjieTB import dataset.*
from CangjieTB import common.*
from CangjieTB import ops.*
main(): Int64 {
let dataPath: String = "./data/mnist/train"
var step: Int64 = 0
let mnistDs = MnistDataset(dataPath)
var resize = resize(Array<Int32>([16, 16]))
var rescale = rescale(1.0 / 255.0, 0.0)
var randomCropOp = randomCrop(Array<Int32>([8, 8]), Array<Int32>([0, 0, 0, 0]))
mnistDs.datasetMap([resize, rescale, randomCropOp], "image")
let input = parameter(initialize(Array<Int64>([8, 8]), initType: InitType.ZERO, dtype: FLOAT32), "data")
let label: Tensor = parameter(zerosTensor(Array<Int64>([1]), dtype: INT32), "label")
while (mnistDs.getNext([input, label]) && step < 5) {
print("---------------\n")
print("input: ", input)
print("label: ", label)
step += 1
}
return 0
}
采样器
对于数据集对象, 可以通过设置采样器来指定随机还是顺序采样.
采样器类 | 采样器简介 |
---|---|
RandomSampler | 给定数据集对象后,按随机顺序进行采样。 |
SequentialSampler | 给定数据集对象后,按顺序从前到后进行采样。 |
采样器类
RandomSampler
public class RandomSampler <: BuildInSampler {
public init(replacement!: Bool = false, numSamples!: Int64 = 0)
}
参数列表:
名称 | 含义 |
---|---|
replacement | 类型为 Bool,是否将样本放回。 |
numSamples | 类型为 Int64,采样个数。 |
SequentialSampler
public class SequentialSampler <: BuildInSampler {
public init(startIndex!: Int64 = 0, numSamples!: Int64 = 0)
}
参数列表:
名称 | 含义 |
---|---|
startIndex | 类型为 Int64,采样起点。 |
numSamples | 类型为 Int64,采样个数。 |
示例代码
以下代码展示了对数据集的前 6 个数据进行采样并打乱顺序读取:
from CangjieTB import dataset.*
from CangjieTB import common.*
from CangjieTB import ops.*
main(): Int64 {
var samplerRD = RandomSampler(numSamples: 6)
let dataPath: String = "./data/mnist/train"
let ds = MnistDataset(dataPath, sampler: samplerRD)
var rescale = rescale(1.0 / 255.0, 0.0)
ds.datasetMap([rescale], "image")
var input: Tensor = parameter(zerosTensor(Array<Int64>([1, 28, 28]), dtype: FLOAT32), "data")
var label: Tensor = parameter(zerosTensor(Array<Int64>([1]), dtype: INT32), "label")
while (ds.getNext([input, label])) {
print("---------------\n")
print("input", input)
print("label", label)
}
return 0
}