gluonts.torch.model.patch_tst.lightning_module module#
- class gluonts.torch.model.patch_tst.lightning_module.PatchTSTLightningModule(model_kwargs: dict, lr: float = 0.001, weight_decay: float = 1e-08)[源码]#
基类:
lightning.pytorch.core.module.LightningModule
一个可用于使用 PyTorch Lightning 训练
PatchTSTModel
的pl.LightningModule
类。这是围绕一个(包装过的)
PatchTSTModel
对象的薄层,用于暴露评估训练和验证损失的方法。- 参数
model_kwargs – 用于构建要训练的
PatchTSTModel
的关键字参数。loss – 用于训练的损失函数。
lr – 学习率。
weight_decay – 权重衰减正则化参数。