gluonts.mx.model.estimator 模块#

class gluonts.mx.model.estimator.GluonEstimator(*, trainer: gluonts.mx.trainer._base.Trainer, batch_size: int = 32, lead_time: int = 0, dtype: typing.Type = <class 'numpy.float32'>)[source]#

基类: gluonts.model.estimator.Estimator

一个 Estimator 类型,包含用于创建基于 Gluon 的模型的工具。

要扩展此类,需要实现五个方法:create_transformationcreate_training_networkcreate_predictorcreate_training_data_loadercreate_validation_data_loader

create_predictor(transformation: gluonts.transform._base.Transformation, trained_network: mxnet.gluon.block.HybridBlock) gluonts.model.predictor.Predictor[source]#

创建并返回一个预测器对象。

参数
  • transformation – 应用于数据在进入模型之前的转换。

  • module – 一个已训练的 HybridBlock 对象。

返回

一个包装了用于推理的 HybridBlock 的预测器。

返回类型

Predictor

create_training_data_loader(data: gluonts.dataset.Dataset, **kwargs) Iterable[Dict[str, Any]][source]#

创建一个用于训练的数据加载器。

参数

data – 用于创建数据加载器的数据集。

返回

数据加载器,即数据批次的迭代器。

返回类型

DataLoader

create_training_network() mxnet.gluon.block.HybridBlock[source]#

创建并返回用于训练的网络(即,计算损失)。

返回

给定输入数据计算损失的网络。

返回类型

HybridBlock

create_transformation() gluonts.transform._base.Transformation[source]#

创建并返回训练和推理所需的转换。

返回

将在训练和推理时逐条应用于数据集的转换。

返回类型

Transformation

create_validation_data_loader(data: gluonts.dataset.Dataset, **kwargs) Iterable[Dict[str, Any]][source]#

创建一个用于验证的数据加载器。

参数

data – 用于创建数据加载器的数据集。

返回

数据加载器,即数据批次的迭代器。

返回类型

DataLoader

classmethod from_hyperparameters(**hyperparameters) gluonts.mx.model.estimator.GluonEstimator[source]#
lead_time: int#
prediction_length: int#
train(training_data: gluonts.dataset.Dataset, validation_data: Optional[gluonts.dataset.Dataset] = None, shuffle_buffer_length: Optional[int] = None, cache_data: bool = False, **kwargs) gluonts.model.predictor.Predictor[source]#

在给定数据上训练估计器。

参数
  • training_data – 用于训练模型的数据集。

  • validation_data – 训练期间用于验证模型的数据集。

返回

包含已训练模型的预测器。

返回类型

Predictor

train_from(predictor: gluonts.mx.model.predictor.GluonPredictor, training_data: gluonts.dataset.Dataset, validation_data: Optional[gluonts.dataset.Dataset] = None, shuffle_buffer_length: Optional[int] = None, cache_data: bool = False) gluonts.model.predictor.Predictor[source]#
train_model(training_data: gluonts.dataset.Dataset, validation_data: Optional[gluonts.dataset.Dataset] = None, from_predictor: Optional[gluonts.mx.model.predictor.GluonPredictor] = None, shuffle_buffer_length: Optional[int] = None, cache_data: bool = False) gluonts.mx.model.estimator.TrainOutput[source]#
class gluonts.mx.model.estimator.TrainOutput(transformation, trained_net, predictor)[source]#

基类: tuple

predictor: gluonts.model.predictor.Predictor#

字段编号 2 的别名

trained_net: mxnet.gluon.block.HybridBlock#

字段编号 1 的别名

transformation: gluonts.transform._base.Transformation#

字段编号 0 的别名