gluonts.torch.model.i_transformer.lightning_module 模块#

class gluonts.torch.model.i_transformer.lightning_module.ITransformerLightningModule(model_kwargs: dict, num_parallel_samples: int = 100, lr: float = 0.001, weight_decay: float = 1e-08)[source]#

一个 pl.LightningModule 类,可用于使用 PyTorch Lightning 训练 ITransformerModel

这是 ITransformerModel 对象(已封装)的一个薄层,暴露了评估训练和验证损失的方法。

参数

model_kwargs – 用于构建要训练的 ITransformerModel 的关键字参数。
  • num_parallel_samples – 推断期间每个时间序列要采样的评估样本数。

  • lr – 学习率。

  • weight_decay – 权重衰减正则化参数。

  • configure_optimizers()[source]#

返回要使用的优化器。

forward(*args, **kwargs)[source]#

torch.nn.Module.forward() 相同。

*args – 决定传入 forward 方法的任意参数。

model_kwargs – 用于构建要训练的 ITransformerModel 的关键字参数。
  • **kwargs – 关键字参数也是可以的。

  • 返回

模型的输出

training_step(batch, batch_idx: int)[source]#

执行训练步骤。

validation_step(batch, batch_idx: int)[source]#

执行验证步骤。

上一页
gluonts.torch.model.i_transformer.estimator 模块