gluonts.torch.model.tft.module module
- class gluonts.torch.model.tft.module.TemporalFusionTransformerModel(context_length: int, prediction_length: int, d_feat_static_real: Optional[List[int]] = None, c_feat_static_cat: Optional[List[int]] = None, d_feat_dynamic_real: Optional[List[int]] = None, c_feat_dynamic_cat: Optional[List[int]] = None, d_past_feat_dynamic_real: Optional[List[int]] = None, c_past_feat_dynamic_cat: Optional[List[int]] = None, num_heads: int = 4, d_hidden: int = 32, d_var: int = 32, dropout_rate: float = 0.1, distr_output: Optional[gluonts.torch.distributions.output.Output] = None)
Bases: torch.nn.modules.module.Module
Temporal Fusion Transformer neural network.
Partially based on the implementation in github.com/kashif/pytorch-transformer-ts.
The inputs feat_static_real, feat_static_cat and feat_dynamic_real are mandatory. The inputs feat_dynamic_cat, past_feat_dynamic_real and past_feat_dynamic_cat are optional.
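The mandatory/optional split above can be checked before a batch reaches the model. The following is a minimal plain-Python sketch; the helper `check_tft_inputs` and the dictionary-of-inputs layout are hypothetical and not part of the GluonTS API. It only encodes the feature inputs named in the docstring (`forward` additionally takes `past_target` and `past_observed_values`, which are not checked here).

```python
from typing import Dict, Optional

# Feature input names as listed in the class docstring.
REQUIRED_INPUTS = ("feat_static_real", "feat_static_cat", "feat_dynamic_real")
OPTIONAL_INPUTS = ("feat_dynamic_cat", "past_feat_dynamic_real", "past_feat_dynamic_cat")


def check_tft_inputs(batch: Dict[str, Optional[object]]) -> Dict[str, bool]:
    """Hypothetical helper: raise if a mandatory input is missing (or None)
    and report which optional inputs were supplied."""
    missing = [name for name in REQUIRED_INPUTS if batch.get(name) is None]
    if missing:
        raise ValueError(f"missing mandatory inputs: {missing}")
    return {name: batch.get(name) is not None for name in OPTIONAL_INPUTS}


# Toy batch: nested lists stand in for tensors.
batch = {
    "feat_static_real": [[0.5]],
    "feat_static_cat": [[3]],
    "feat_dynamic_real": [[[1.0]] * 4],
    "past_feat_dynamic_real": [[[0.2]] * 2],
}
print(check_tft_inputs(batch))
```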
- describe_inputs(batch_size=1) → gluonts.model.inputs.InputSpec
- feat_dynamic_embed: Optional[FeatureEmbedder]
- feat_dynamic_proj: Optional[FeatureProjector]
- feat_static_embed: Optional[FeatureEmbedder]
- feat_static_proj: Optional[FeatureProjector]
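The `feat_*_embed` attributes are typed as `FeatureEmbedder` and the `feat_*_proj` attributes as `FeatureProjector`, which suggests that categorical inputs (sized by the `c_*` cardinalities) are embedded while real-valued inputs (sized by the `d_*` dimensions) are linearly projected. The sketch below illustrates that general idea in plain Python, under the assumption that every feature is mapped to a vector of size `d_var`; `make_embedder` and `make_projector` are hypothetical helpers, not the GluonTS modules, whose internals are not described here.

```python
import random
from typing import List


def make_embedder(cardinalities: List[int], d_var: int, seed: int = 0):
    """One lookup table per categorical feature: category id -> d_var-vector."""
    rng = random.Random(seed)
    tables = [
        [[rng.gauss(0.0, 1.0) for _ in range(d_var)] for _ in range(card)]
        for card in cardinalities
    ]

    def embed(cat_ids: List[int]) -> List[List[float]]:
        # Returns one d_var-dimensional vector per categorical feature.
        return [tables[i][c] for i, c in enumerate(cat_ids)]

    return embed


def make_projector(dims: List[int], d_var: int, seed: int = 1):
    """One bias-free linear map per real-valued feature: R^dim -> R^d_var."""
    rng = random.Random(seed)
    weights = [
        [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(d_var)]
        for dim in dims
    ]

    def project(features: List[List[float]]) -> List[List[float]]:
        # Matrix-vector product for each feature with its own weight matrix.
        return [
            [sum(w * x for w, x in zip(row, feat)) for row in weights[i]]
            for i, feat in enumerate(features)
        ]

    return project


embed = make_embedder(cardinalities=[3, 5], d_var=4)
project = make_projector(dims=[1, 2], d_var=4)
print(len(embed([2, 0])), len(project([[0.5], [1.0, -1.0]])))  # one vector per feature
```

Mapping every variable to the same width is what lets a later layer weigh variables against each other; whether `d_var` plays exactly this role in GluonTS is an assumption.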
- forward(past_target: torch.Tensor, past_observed_values: torch.Tensor, feat_static_real: Optional[torch.Tensor], feat_static_cat: Optional[torch.Tensor], feat_dynamic_real: Optional[torch.Tensor], feat_dynamic_cat: Optional[torch.Tensor] = None, past_feat_dynamic_real: Optional[torch.Tensor] = None, past_feat_dynamic_cat: Optional[torch.Tensor] = None) → Tuple[Tuple[torch.Tensor, ...], torch.Tensor, torch.Tensor]
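Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this function, since the former takes care of running the registered hooks while the latter silently ignores them.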
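The note above is about dispatch: calling the instance runs registered hooks around `forward`, while calling `forward` directly does not. The toy class below imitates that split in plain Python so it runs without PyTorch. `ToyModule` is hypothetical and far simpler than `torch.nn.Module` (which also has pre-hooks, backward hooks and more); only the `(module, input, output)` hook arguments follow the PyTorch convention.

```python
from typing import Callable, List


class ToyModule:
    """Toy stand-in for the call/forward split; not torch.nn.Module itself."""

    def __init__(self) -> None:
        self._forward_hooks: List[Callable] = []

    def register_forward_hook(self, hook: Callable) -> None:
        self._forward_hooks.append(hook)

    def forward(self, x: float) -> float:
        return 2 * x

    def __call__(self, x: float) -> float:
        out = self.forward(x)
        # Hooks run only when the instance itself is called.
        for hook in self._forward_hooks:
            hook(self, (x,), out)
        return out


calls = []
m = ToyModule()
m.register_forward_hook(lambda mod, inputs, output: calls.append(output))
m(3.0)          # runs forward, then the hook
m.forward(3.0)  # runs forward only; the hook is skipped
print(calls)    # [6.0]
```

With the real model this means invoking `model(past_target=..., past_observed_values=..., ...)` rather than `model.forward(...)`.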
- loss(past_target: torch.Tensor, past_observed_values: torch.Tensor, future_target: torch.Tensor, future_observed_values: torch.Tensor, feat_static_real: torch.Tensor, feat_static_cat: torch.Tensor, feat_dynamic_real: torch.Tensor, feat_dynamic_cat: Optional[torch.Tensor] = None, past_feat_dynamic_real: Optional[torch.Tensor] = None, past_feat_dynamic_cat: Optional[torch.Tensor] = None) → torch.Tensor
- past_feat_dynamic_embed: Optional[FeatureEmbedder]
- past_feat_dynamic_proj: Optional[FeatureProjector]
- training: bool
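`training` is the standard `torch.nn.Module` flag that `model.train()` sets to `True` and `model.eval()` sets to `False`. Layers such as dropout (configured here through `dropout_rate`) consult it: PyTorch's `nn.Dropout` zeroes each activation with probability `p` and scales the survivors by `1/(1-p)`, but only in training mode. The plain-Python function below sketches that inverted-dropout behaviour; it is an illustration, not the code the TFT module runs.

```python
import random
from typing import List


def dropout(xs: List[float], p: float, training: bool, rng: random.Random) -> List[float]:
    """Inverted dropout: a no-op unless training is True."""
    if not training or p == 0.0:
        return list(xs)
    keep = 1.0 - p
    # Keep each value with probability `keep` and rescale it so the
    # expected value of every element is unchanged.
    return [x / keep if rng.random() < keep else 0.0 for x in xs]


rng = random.Random(0)
xs = [1.0, 2.0, 3.0, 4.0]
print(dropout(xs, p=0.1, training=False, rng=rng))  # unchanged
print(dropout(xs, p=0.1, training=True, rng=rng))   # survivors scaled by 1/0.9
```

Calling `model.eval()` before forecasting therefore disables dropout in the model's layers.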