gluonts.torch.model.tft.lightning_module 模块#
- class gluonts.torch.model.tft.lightning_module.TemporalFusionTransformerLightningModule(model_kwargs: dict, lr: float = 0.001, patience: int =10, weight_decay: float = 0.0)[source]#
- 基类: - lightning.pytorch.core.module.LightningModule- 一个 - pl.LightningModule类,可用于使用 PyTorch Lightning 训练- TemporalFusionTransformerModel。- 这是 (封装的) - TemporalFusionTransformerModel对象的一个薄层,它暴露了评估训练和验证损失的方法。- 参数
- model_kwargs – 用于构建要训练的 - TemporalFusionTransformerModel的关键字参数。
- lr – 学习率。 
- weight_decay – 权重衰减正则化参数。 
- patience – 学习率调度器的耐心参数。