gluonts.mx.block.rnn 模块#

class gluonts.mx.block.rnn.RNN(mode: str, num_hidden: int, num_layers: int, bidirectional: bool = False, **kwargs)[source]#

继承自: mxnet.gluon.block.HybridBlock

定义一个 RNN 模块。

参数
  • mode – RNN 的类型。可以是以下之一:rnn_relu(使用 relu 激活函数的 RNN)、rnn_tanh(使用 tanh 激活函数的 RNN)、lstm 或 gru。

  • num_hidden – 每个隐藏层的单元数量。

  • num_layers – 隐藏层的数量。

  • bidirectional – 切换是否使用双向 RNN 作为编码器。

hybrid_forward(F, inputs: Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]) Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol][source]#
参数
  • F – 在 MXNet 中可以引用 Symbol API 或 NDArray API 的模块。

  • inputs – 输入张量,形状为 (batch_size, num_timesteps, num_dimensions)

返回值

RNN 输出,形状为 (batch_size, num_timesteps, num_dimensions)

返回值类型

张量