gluonts.torch.model.wavenet.estimator 模块#

class gluonts.torch.model.wavenet.estimator.WaveNetEstimator(freq: str, prediction_length: int, num_bins: int =1024, num_residual_channels: int =24, num_skip_channels: int =32, dilation_depth: Optional[int] =None, num_stacks: int =1, temperature: float =1.0, num_feat_dynamic_real: int =0, num_feat_static_cat: int =0, num_feat_static_real: int =0, cardinality: List[int] =[1], seasonality: Optional[int] =None, embedding_dimension: int =5, use_log_scale_feature: bool =True, time_features: Optional[List[Callable[[pandas.core.indexes.period.PeriodIndex], numpy.ndarray]]] =None, lr: float =0.001, weight_decay: float =1e-08, train_sampler: Optional[gluonts.transform.sampler.InstanceSampler] =None, validation_sampler: Optional[gluonts.transform.sampler.InstanceSampler] =None, batch_size: int =32, num_batches_per_epoch: int =50, num_parallel_samples: int =100, negative_data: bool =False, trainer_kwargs: Optional[Dict[str, Any]] =None)[source]#

基类: gluonts.torch.model.estimator.PyTorchLightningEstimator

create_lightning_module() lightning.pytorch.core.module.LightningModule[source]#

创建并返回用于训练(即计算损失)的网络。

返回值

根据输入数据计算损失的网络。

返回值类型

pl.LightningModule

create_predictor(transformation: gluonts.transform._base.Transformation, module: gluonts.torch.model.wavenet.lightning_module.WaveNetLightningModule) gluonts.torch.model.predictor.PyTorchPredictor[source]#

创建并返回一个预测器对象。

参数
  • transformation – 应用于数据进入模型之前的转换。

  • module – 一个训练好的 pl.LightningModule 对象。

返回值

一个用于推理的包裹 nn.Module 的预测器。

返回值类型

预测器

create_training_data_loader(data: gluonts.dataset.Dataset, module: gluonts.torch.model.wavenet.lightning_module.WaveNetLightningModule, shuffle_buffer_length: Optional[int] =None, **kwargs) Iterable[source]#

创建用于训练的数据加载器。

参数
  • data – 用于创建数据加载器的数据集。

  • module – 将接收数据加载器批量数据的 pl.LightningModule 对象。

返回值

数据加载器,即数据批次的迭代器。

返回值类型

可迭代对象

create_transformation() gluonts.transform._base.Transformation[source]#

创建并返回训练和推理所需的转换。

返回值

将在训练和推理时对数据集逐条应用的转换。

返回值类型

转换

create_validation_data_loader(data: gluonts.dataset.Dataset, module: gluonts.torch.model.wavenet.lightning_module.WaveNetLightningModule, **kwargs) Iterable[source]#

创建用于验证的数据加载器。

参数
  • data – 用于创建数据加载器的数据集。

  • module – 将接收数据加载器批量数据的 pl.LightningModule 对象。

返回值

数据加载器,即数据批次的迭代器。

返回值类型

可迭代对象

lead_time: int#
prediction_length: int#