gluonts.torch.model.estimator 模块#

class gluonts.torch.model.estimator.PyTorchLightningEstimator(trainer_kwargs: Dict[str, Any], lead_time: int = 0)[source]#

基类: gluonts.model.estimator.Estimator

一个提供用于创建基于 PyTorch-Lightning 的模型工具的 Estimator 类型。

要扩展此类,需要实现以下方法:create_transformationcreate_training_networkcreate_predictorcreate_training_data_loadercreate_validation_data_loader

create_lightning_module() lightning.pytorch.core.module.LightningModule[source]#

创建并返回用于训练(即计算损失)的网络。

返回

根据输入数据计算损失的网络。

返回类型

pl.LightningModule

create_predictor(transformation: gluonts.transform._base.Transformation, module) gluonts.torch.model.predictor.PyTorchPredictor[source]#

创建并返回一个预测器对象。

参数
  • transformation – 在数据进入模型之前要应用的转换。

  • module – 一个已训练的 pl.LightningModule 对象。

返回

一个封装了用于推断的 nn.Module 的预测器。

返回类型

预测器

create_training_data_loader(data: gluonts.dataset.Dataset, module, **kwargs) Iterable[source]#

创建一个用于训练目的的数据加载器。

参数
  • data – 用于创建数据加载器的数据集。

  • module – 将接收来自数据加载器的批次数据的 pl.LightningModule 对象。

返回

数据加载器,即一个可迭代对象,用于遍历数据批次。

返回类型

Iterable

create_transformation() gluonts.transform._base.Transformation[source]#

创建并返回训练和推断所需的转换。

返回

将在训练和推断时,按条目应用于数据集的转换。

返回类型

转换

create_validation_data_loader(data: gluonts.dataset.Dataset, module, **kwargs) Iterable[source]#

创建一个用于验证目的的数据加载器。

参数
  • data – 用于创建数据加载器的数据集。

  • module – 将接收来自数据加载器的批次数据的 pl.LightningModule 对象。

返回

数据加载器,即一个可迭代对象,用于遍历数据批次。

返回类型

Iterable

lead_time: int#
prediction_length: int#
train(training_data: gluonts.dataset.Dataset, validation_data: Optional[gluonts.dataset.Dataset] = None, shuffle_buffer_length: Optional[int] = None, cache_data: bool = False, ckpt_path: Optional[str] = None, **kwargs) gluonts.torch.model.predictor.PyTorchPredictor[source]#

使用给定的数据训练此估计器。

参数
  • training_data – 用于训练模型的数据集。

  • validation_data – 在训练期间用于验证模型的数据集。

返回

包含已训练模型的预测器。

返回类型

预测器

train_from(predictor: gluonts.model.predictor.Predictor, training_data: gluonts.dataset.Dataset, validation_data: Optional[gluonts.dataset.Dataset] = None, shuffle_buffer_length: Optional[int] = None, cache_data: bool = False, ckpt_path: Optional[str] = None) gluonts.torch.model.predictor.PyTorchPredictor[source]#
train_model(training_data: gluonts.dataset.Dataset, validation_data: Optional[gluonts.dataset.Dataset] = None, from_predictor: Optional[gluonts.torch.model.predictor.PyTorchPredictor] = None, shuffle_buffer_length: Optional[int] = None, cache_data: bool = False, ckpt_path: Optional[str] = None, **kwargs) gluonts.torch.model.estimator.TrainOutput[source]#
class gluonts.torch.model.estimator.TrainOutput(transformation, trained_net, trainer, predictor)[source]#

基类: tuple

predictor: gluonts.torch.model.predictor.PyTorchPredictor#

字段编号 3 的别名

trained_net: torch.nn.modules.module.Module#

字段编号 1 的别名

trainer: lightning.pytorch.trainer.trainer.Trainer#

字段编号 2 的别名

transformation: gluonts.transform._base.Transformation#

字段编号 0 的别名