gluonts.dataset.loader 模块#

class gluonts.dataset.loader.Batch(*, batch_size: int)[source]#

基类:gluonts.transform._base.Transformation, pydantic.v1.main.BaseModel

batch_size: int#
gluonts.dataset.loader.InferenceDataLoader(dataset: gluonts.dataset.Dataset, *, transform: gluonts.transform._base.Transformation = <gluonts.transform._base.Identity object>, batch_size: int, stack_fn: typing.Callable)[source]#

构建用于推断目的的批处理迭代器。

参数
  • dataset – 要迭代的数据。

  • transform – 在数据迭代时延迟应用的变换。此变换在“推断模式”下应用(is_train=False)。

  • batch_size – 每个批处理中包含的条目数。

  • stack_fn – 用于将数据条目堆叠到批处理中的函数。这可用于设置数组应位于的特定数组类型或计算设备(CPU、GPU)。

返回

一个可迭代的批处理序列。

返回类型

Iterable[DataBatch]

class gluonts.dataset.loader.Stack[source]#

基类:gluonts.transform._base.Transformation, pydantic.v1.main.BaseModel

gluonts.dataset.loader.TrainDataLoader(dataset: gluonts.dataset.Dataset, *, transform: gluonts.transform._base.Transformation = <gluonts.transform._base.Identity object>, batch_size: int, stack_fn: typing.Callable, num_batches_per_epoch: typing.Optional[int] = None, shuffle_buffer_length: typing.Optional[int] = None)[source]#

构建用于训练目的的批处理迭代器。

此函数包装 DataLoader 以提供特定于训练的行为和选项,如下所示

1. 提供的数集是循环迭代的,以便可以在单个 epoch 中多次遍历它。2. 必须提供一个变换,该变换在数据集迭代时延迟应用;这对于例如从数据集中的每个时间序列中切出固定长度的随机实例非常有用。3. 生成的批处理可以按伪随机顺序迭代。

返回的对象是一个有状态的迭代器,其长度要么是 num_batches_per_epoch(如果不为 None),要么是无限的(否则)。

参数
  • dataset – 要迭代的数据。

  • transform – 在数据迭代时延迟应用的变换。此变换在“训练模式”下应用(is_train=True)。

  • batch_size – 每个批处理中包含的条目数。

  • stack_fn – 用于将数据条目堆叠到批处理中的函数。这可用于设置数组应位于的特定数组类型或计算设备(CPU、GPU)。

  • num_batches_per_epoch – 迭代器的长度。如果为 None,则迭代器是无限的。

  • shuffle_buffer_length – 用于洗牌的缓冲区大小。默认值:None,在这种情况下不进行洗牌。

返回

一个批处理迭代器。

返回类型

Iterator[DataBatch]

gluonts.dataset.loader.ValidationDataLoader(dataset: gluonts.dataset.Dataset, *, transform: gluonts.transform._base.Transformation = <gluonts.transform._base.Identity object>, batch_size: int, stack_fn: typing.Callable)[source]#

构建用于验证目的的批处理迭代器。

参数
  • dataset – 要迭代的数据。

  • transform – 在数据迭代时延迟应用的变换。此变换在“训练模式”下应用(is_train=True)。

  • batch_size – 每个批处理中包含的条目数。

  • stack_fn – 用于将数据条目堆叠到批处理中的函数。这可用于设置数组应位于的特定数组类型或计算设备(CPU、GPU)。

返回

一个可迭代的批处理序列。

返回类型

Iterable[DataBatch]

gluonts.dataset.loader.as_stacked_batches(dataset: gluonts.dataset.Dataset, *, batch_size: int, output_type: Optional[Callable] = None, num_batches_per_epoch: Optional[int] = None, shuffle_buffer_length: Optional[int] = None, field_names: Optional[list] = None)[source]#

准备以批处理形式传递给网络的数据。

输入数据被收集到大小为 batch_size 的批处理中,然后列堆叠在一起。此外,如果提供了 output_type,结果会包裹在 output_type 中。

如果提供了 num_batches_per_epoch,则只会返回指定数量的批处理。这在提供循环数据集进行训练时特别有用。

为了对数据进行伪随机洗牌,可以设置 shuffle_buffer_length 以先将输入收集到缓冲区中,然后从中随机采样。

设置 field_names 将只考虑输入数据中的这些列,并丢弃所有其他值。