gluonts.torch.model.wavenet.lightning_module 模块#

class gluonts.torch.model.wavenet.lightning_module.WaveNetLightningModule(model_kwargs: dict, lr: float = 0.001, weight_decay: float = 1e-08)[source]#

基类: lightning.pytorch.core.module.LightningModule

WaveNet 的 LightningModule 封装器。

参数
  • model_kwargs – 传递给 WaveNet 的关键字参数。

  • lr – 学习率,默认为 1e-3

  • optional – 学习率,默认为 1e-3

  • weight_decay – 权重衰减,默认为 1e-8

  • optional – 权重衰减,默认为 1e-8

configure_optimizers()[source]#

返回要使用的优化器。

forward(*args, **kwargs)[source]#

torch.nn.Module.forward() 相同。

参数
  • *args – 你决定传递给 forward 方法的任何参数。

  • **kwargs – 关键字参数也是可能的。

返回

你的模型的输出

training_step(batch, batch_idx: int)[source]#

执行训练步骤。

validation_step(batch, batch_idx: int)[source]#

执行验证步骤。