gluonts.torch.model.patch_tst.lightning_module module#

class gluonts.torch.model.patch_tst.lightning_module.PatchTSTLightningModule(model_kwargs: dict, lr: float = 0.001, weight_decay: float = 1e-08)[源码]#

基类: lightning.pytorch.core.module.LightningModule

一个可用于使用 PyTorch Lightning 训练 PatchTSTModelpl.LightningModule 类。

这是围绕一个(包装过的)PatchTSTModel 对象的薄层,用于暴露评估训练和验证损失的方法。

参数
  • model_kwargs – 用于构建要训练的 PatchTSTModel 的关键字参数。

  • loss – 用于训练的损失函数。

  • lr – 学习率。

  • weight_decay – 权重衰减正则化参数。

configure_optimizers()[源码]#

返回要使用的优化器。

forward(*args, **kwargs)[源码]#

torch.nn.Module.forward() 相同。

参数
  • *args – 您决定传递给 forward 方法的任何参数。

  • **kwargs – 也支持关键字参数。

返回

您的模型的输出

training_step(batch, batch_idx: int)[源码]#

执行训练步骤。

validation_step(batch, batch_idx: int)[源码]#

执行验证步骤。