gluonts.torch.model.tft.module module

class gluonts.torch.model.tft.module.TemporalFusionTransformerModel(context_length: int, prediction_length: int, d_feat_static_real: Optional[List[int]] = None, c_feat_static_cat: Optional[List[int]] = None, d_feat_dynamic_real: Optional[List[int]] = None, c_feat_dynamic_cat: Optional[List[int]] = None, d_past_feat_dynamic_real: Optional[List[int]] = None, c_past_feat_dynamic_cat: Optional[List[int]] = None, num_heads: int = 4, d_hidden: int = 32, d_var: int = 32, dropout_rate: float = 0.1, distr_output: Optional[gluonts.torch.distributions.output.Output] = None)

Bases: torch.nn.modules.module.Module

Temporal Fusion Transformer neural network.

Partially based on the implementation in github.com/kashif/pytorch-transformer-ts.

The inputs feat_static_real, feat_static_cat and feat_dynamic_real are mandatory. The inputs feat_dynamic_cat, past_feat_dynamic_real and past_feat_dynamic_cat are optional.
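A small validation helper can illustrate the split between mandatory and optional feature inputs. This is a hypothetical sketch in plain Python. The names split_tft_inputs, MANDATORY_INPUTS and OPTIONAL_INPUTS are illustrative and are not part of GluonTS, and the values are placeholder lists rather than tensors. The sketch covers only the feature inputs named above, not past_target or past_observed_values.

```python
from typing import Any, Dict, Mapping, Optional

# Feature input names as listed in the documentation (hypothetical constants).
MANDATORY_INPUTS = ("feat_static_real", "feat_static_cat", "feat_dynamic_real")
OPTIONAL_INPUTS = ("feat_dynamic_cat", "past_feat_dynamic_real", "past_feat_dynamic_cat")


def split_tft_inputs(inputs: Mapping[str, Any]) -> Dict[str, Optional[Any]]:
    """Hypothetical helper: require the mandatory feature inputs and
    fill any missing optional feature input with None."""
    missing = [name for name in MANDATORY_INPUTS if name not in inputs]
    if missing:
        raise KeyError(f"missing mandatory inputs: {missing}")
    resolved: Dict[str, Optional[Any]] = {name: inputs[name] for name in MANDATORY_INPUTS}
    for name in OPTIONAL_INPUTS:
        # Optional inputs default to None when they are not supplied.
        resolved[name] = inputs.get(name)
    return resolved


example = split_tft_inputs(
    {"feat_static_real": [0.5], "feat_static_cat": [3], "feat_dynamic_real": [[1.0, 2.0]]}
)
print(example["feat_dynamic_cat"])  # None
```

Defaulting the optional entries to None matches the forward signature, where those arguments default to None.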

describe_inputs(batch_size=1) → gluonts.model.inputs.InputSpec
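This is a minimal sketch of what an input description parameterized by batch_size could contain, assuming batch_size is the leading dimension of every input. It is hypothetical throughout. sketch_input_spec is not a GluonTS function, and it returns plain (shape, dtype-name) tuples instead of a gluonts.model.inputs.InputSpec. The shapes below are also assumptions that have not been checked against the GluonTS source: past inputs span context_length steps, known dynamic features span context_length + prediction_length steps, the d_* lists are summed into a feature dimension, and there is one column per entry of c_feat_static_cat.

```python
from typing import Dict, List, Tuple

# Hypothetical stand-in for an input spec: name -> (shape, dtype name).
Spec = Dict[str, Tuple[Tuple[int, ...], str]]


def sketch_input_spec(
    batch_size: int,
    context_length: int,
    prediction_length: int,
    d_feat_static_real: List[int],
    c_feat_static_cat: List[int],
    d_feat_dynamic_real: List[int],
) -> Spec:
    """Hypothetical: all shapes here are illustrative assumptions."""
    total_len = context_length + prediction_length  # assumed span of known dynamic features
    return {
        "past_target": ((batch_size, context_length), "float32"),
        "past_observed_values": ((batch_size, context_length), "float32"),
        "feat_static_real": ((batch_size, sum(d_feat_static_real)), "float32"),
        "feat_static_cat": ((batch_size, len(c_feat_static_cat)), "int64"),
        "feat_dynamic_real": ((batch_size, total_len, sum(d_feat_dynamic_real)), "float32"),
    }


spec = sketch_input_spec(
    batch_size=2,
    context_length=24,
    prediction_length=12,
    d_feat_static_real=[1],
    c_feat_static_cat=[5, 7],
    d_feat_dynamic_real=[2, 1],
)
print(spec["feat_dynamic_real"])  # ((2, 36, 3), 'float32')
```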
feat_dynamic_embed: Optional[FeatureEmbedder]
feat_dynamic_proj: Optional[FeatureProjector]
feat_static_embed: Optional[FeatureEmbedder]
feat_static_proj: Optional[FeatureProjector]
forward(past_target: torch.Tensor, past_observed_values: torch.Tensor, feat_static_real: Optional[torch.Tensor], feat_static_cat: Optional[torch.Tensor], feat_dynamic_real: Optional[torch.Tensor], feat_dynamic_cat: Optional[torch.Tensor] = None, past_feat_dynamic_real: Optional[torch.Tensor] = None, past_feat_dynamic_cat: Optional[torch.Tensor] = None) → Tuple[Tuple[torch.Tensor, ...], torch.Tensor, torch.Tensor]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

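Although the forward-pass logic must be defined in this function, you should call the Module instance rather than this function directly: the former runs the registered hooks, while the latter silently ignores them.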
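The difference described in the note can be shown with a toy class. The class imitates, in simplified form, how calling a module instance wraps forward and runs forward hooks. ToyModule and Doubler are hypothetical names, and this is not PyTorch's implementation. torch.nn.Module does provide a register_forward_hook method whose hooks are called as hook(module, args, output). Unlike that method, the toy version ignores any value the hook returns.

```python
from typing import Any, Callable, List


class ToyModule:
    """Simplified imitation of Module.__call__ wrapping forward();
    not PyTorch's actual implementation."""

    def __init__(self) -> None:
        self._forward_hooks: List[Callable[[Any, Any, Any], None]] = []

    def register_forward_hook(self, hook: Callable[[Any, Any, Any], None]) -> None:
        self._forward_hooks.append(hook)

    def forward(self, x: Any) -> Any:
        raise NotImplementedError  # subclasses override this

    def __call__(self, x: Any) -> Any:
        out = self.forward(x)
        for hook in self._forward_hooks:
            hook(self, (x,), out)  # hooks run only when the instance is called
        return out


class Doubler(ToyModule):
    def forward(self, x: float) -> float:
        return 2 * x


calls: List[float] = []
m = Doubler()
m.register_forward_hook(lambda mod, args, out: calls.append(out))
m(3.0)          # runs forward() and then the hook
m.forward(4.0)  # runs forward() only; the hook is silently skipped
print(calls)    # [6.0]
```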

input_types() → Dict[str, torch.dtype]
loss(past_target: torch.Tensor, past_observed_values: torch.Tensor, future_target: torch.Tensor, future_observed_values: torch.Tensor, feat_static_real: torch.Tensor, feat_static_cat: torch.Tensor, feat_dynamic_real: torch.Tensor, feat_dynamic_cat: Optional[torch.Tensor] = None, past_feat_dynamic_real: Optional[torch.Tensor] = None, past_feat_dynamic_cat: Optional[torch.Tensor] = None) → torch.Tensor
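Because loss receives future_observed_values alongside future_target, unobserved future steps can be excluded from the training signal. The sketch below assumes exactly that. It weights per-step losses by 0/1 observed flags and averages them over the observed steps. The per-step loss is the quantile (pinball) loss, swapped in for illustration only, because this page does not state which loss loss() computes. The functions pinball and masked_mean_loss are hypothetical and operate on plain lists rather than tensors.

```python
from typing import Sequence


def pinball(y: float, q_pred: float, q: float) -> float:
    """Quantile (pinball) loss of one prediction at quantile level q."""
    diff = y - q_pred
    return max(q * diff, (q - 1) * diff)


def masked_mean_loss(
    future_target: Sequence[float],
    future_observed_values: Sequence[float],
    quantile_preds: Sequence[Sequence[float]],  # indexed as [time step][quantile]
    quantiles: Sequence[float],
) -> float:
    """Hypothetical: mean pinball loss over steps whose observed flag is 1."""
    total, weight = 0.0, 0.0
    for y, obs, preds in zip(future_target, future_observed_values, quantile_preds):
        step = sum(pinball(y, p, q) for p, q in zip(preds, quantiles)) / len(quantiles)
        total += obs * step  # unobserved steps (obs == 0) contribute nothing
        weight += obs
    return total / weight if weight > 0 else 0.0  # avoid dividing by zero


loss = masked_mean_loss(
    future_target=[10.0, 12.0, 0.0],
    future_observed_values=[1.0, 1.0, 0.0],  # the last step is missing
    quantile_preds=[[9.0, 10.0, 11.0], [11.0, 12.0, 13.0], [5.0, 5.0, 5.0]],
    quantiles=[0.1, 0.5, 0.9],
)
print(round(loss, 4))  # 0.0667
```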
past_feat_dynamic_embed: Optional[FeatureEmbedder]
past_feat_dynamic_proj: Optional[FeatureProjector]
training: bool