gluonts.mx.block.rnn 模块#
- class gluonts.mx.block.rnn.RNN(mode: str, num_hidden: int, num_layers: int, bidirectional: bool = False, **kwargs)[source]#
继承自:
mxnet.gluon.block.HybridBlock
定义一个 RNN 模块。
- 参数
mode – RNN 的类型。可以是以下之一:rnn_relu(使用 relu 激活函数的 RNN)、rnn_tanh(使用 tanh 激活函数的 RNN)、lstm 或 gru。
num_hidden – 每个隐藏层的单元数量。
num_layers – 隐藏层的数量。
bidirectional – 切换是否使用双向 RNN 作为编码器。
- hybrid_forward(F, inputs: Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]) Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol] [source]#
- 参数
F – 在 MXNet 中可以引用 Symbol API 或 NDArray API 的模块。
inputs – 输入张量,形状为 (batch_size, num_timesteps, num_dimensions)
- 返回值
RNN 输出,形状为 (batch_size, num_timesteps, num_dimensions)
- 返回值类型
张量