gluonts.torch.model.wavenet.estimator 模块#
- class gluonts.torch.model.wavenet.estimator.WaveNetEstimator(freq: str, prediction_length: int, num_bins: int =1024, num_residual_channels: int =24, num_skip_channels: int =32, dilation_depth: Optional[int] =None, num_stacks: int =1, temperature: float =1.0, num_feat_dynamic_real: int =0, num_feat_static_cat: int =0, num_feat_static_real: int =0, cardinality: List[int] =[1], seasonality: Optional[int] =None, embedding_dimension: int =5, use_log_scale_feature: bool =True, time_features: Optional[List[Callable[[pandas.core.indexes.period.PeriodIndex], numpy.ndarray]]] =None, lr: float =0.001, weight_decay: float =1e-08, train_sampler: Optional[gluonts.transform.sampler.InstanceSampler] =None, validation_sampler: Optional[gluonts.transform.sampler.InstanceSampler] =None, batch_size: int =32, num_batches_per_epoch: int =50, num_parallel_samples: int =100, negative_data: bool =False, trainer_kwargs: Optional[Dict[str, Any]] =None)[source]#
基类:
gluonts.torch.model.estimator.PyTorchLightningEstimator
- create_lightning_module() lightning.pytorch.core.module.LightningModule [source]#
创建并返回用于训练(即计算损失)的网络。
- 返回值
根据输入数据计算损失的网络。
- 返回值类型
pl.LightningModule
- create_predictor(transformation: gluonts.transform._base.Transformation, module: gluonts.torch.model.wavenet.lightning_module.WaveNetLightningModule) gluonts.torch.model.predictor.PyTorchPredictor [source]#
创建并返回一个预测器对象。
- 参数
transformation – 应用于数据进入模型之前的转换。
module – 一个训练好的 pl.LightningModule 对象。
- 返回值
一个用于推理的包裹 nn.Module 的预测器。
- 返回值类型
- create_training_data_loader(data: gluonts.dataset.Dataset, module: gluonts.torch.model.wavenet.lightning_module.WaveNetLightningModule, shuffle_buffer_length: Optional[int] =None, **kwargs) Iterable [source]#
创建用于训练的数据加载器。
- 参数
data – 用于创建数据加载器的数据集。
module – 将接收数据加载器批量数据的 pl.LightningModule 对象。
- 返回值
数据加载器,即数据批次的迭代器。
- 返回值类型
可迭代对象
- create_transformation() gluonts.transform._base.Transformation [source]#
创建并返回训练和推理所需的转换。
- 返回值
将在训练和推理时对数据集逐条应用的转换。
- 返回值类型
- create_validation_data_loader(data: gluonts.dataset.Dataset, module: gluonts.torch.model.wavenet.lightning_module.WaveNetLightningModule, **kwargs) Iterable [source]#
创建用于验证的数据加载器。
- 参数
data – 用于创建数据加载器的数据集。
module – 将接收数据加载器批量数据的 pl.LightningModule 对象。
- 返回值
数据加载器,即数据批次的迭代器。
- 返回值类型
可迭代对象
- lead_time: int#
- prediction_length: int#