gluonts.torch.model.wavenet.estimator 模块#
- class gluonts.torch.model.wavenet.estimator.WaveNetEstimator(freq: str, prediction_length: int, num_bins: int =1024, num_residual_channels: int =24, num_skip_channels: int =32, dilation_depth: Optional[int] =None, num_stacks: int =1, temperature: float =1.0, num_feat_dynamic_real: int =0, num_feat_static_cat: int =0, num_feat_static_real: int =0, cardinality: List[int] =[1], seasonality: Optional[int] =None, embedding_dimension: int =5, use_log_scale_feature: bool =True, time_features: Optional[List[Callable[[pandas.core.indexes.period.PeriodIndex], numpy.ndarray]]] =None, lr: float =0.001, weight_decay: float =1e-08, train_sampler: Optional[gluonts.transform.sampler.InstanceSampler] =None, validation_sampler: Optional[gluonts.transform.sampler.InstanceSampler] =None, batch_size: int =32, num_batches_per_epoch: int =50, num_parallel_samples: int =100, negative_data: bool =False, trainer_kwargs: Optional[Dict[str, Any]] =None)[source]#
- 基类: - gluonts.torch.model.estimator.PyTorchLightningEstimator- create_lightning_module() lightning.pytorch.core.module.LightningModule[source]#
- 创建并返回用于训练(即计算损失)的网络。 - 返回值
- 根据输入数据计算损失的网络。 
- 返回值类型
- pl.LightningModule 
 
 - create_predictor(transformation: gluonts.transform._base.Transformation, module: gluonts.torch.model.wavenet.lightning_module.WaveNetLightningModule) gluonts.torch.model.predictor.PyTorchPredictor[source]#
- 创建并返回一个预测器对象。 - 参数
- transformation – 应用于数据进入模型之前的转换。 
- module – 一个训练好的 pl.LightningModule 对象。 
 
- 返回值
- 一个用于推理的包裹 nn.Module 的预测器。 
- 返回值类型
 
 - create_training_data_loader(data: gluonts.dataset.Dataset, module: gluonts.torch.model.wavenet.lightning_module.WaveNetLightningModule, shuffle_buffer_length: Optional[int] =None, **kwargs) Iterable[source]#
- 创建用于训练的数据加载器。 - 参数
- data – 用于创建数据加载器的数据集。 
- module – 将接收数据加载器批量数据的 pl.LightningModule 对象。 
 
- 返回值
- 数据加载器,即数据批次的迭代器。 
- 返回值类型
- 可迭代对象 
 
 - create_transformation() gluonts.transform._base.Transformation[source]#
- 创建并返回训练和推理所需的转换。 - 返回值
- 将在训练和推理时对数据集逐条应用的转换。 
- 返回值类型
 
 - create_validation_data_loader(data: gluonts.dataset.Dataset, module: gluonts.torch.model.wavenet.lightning_module.WaveNetLightningModule, **kwargs) Iterable[source]#
- 创建用于验证的数据加载器。 - 参数
- data – 用于创建数据加载器的数据集。 
- module – 将接收数据加载器批量数据的 pl.LightningModule 对象。 
 
- 返回值
- 数据加载器,即数据批次的迭代器。 
- 返回值类型
- 可迭代对象 
 
 - lead_time: int#
 - prediction_length: int#