gluonts.mx.block.rnn 模块#
- class gluonts.mx.block.rnn.RNN(mode: str, num_hidden: int, num_layers: int, bidirectional: bool = False, **kwargs)[source]#
- 继承自: - mxnet.gluon.block.HybridBlock- 定义一个 RNN 模块。 - 参数
- mode – RNN 的类型。可以是以下之一:rnn_relu(使用 relu 激活函数的 RNN)、rnn_tanh(使用 tanh 激活函数的 RNN)、lstm 或 gru。 
- num_hidden – 每个隐藏层的单元数量。 
- num_layers – 隐藏层的数量。 
- bidirectional – 切换是否使用双向 RNN 作为编码器。 
 
 - hybrid_forward(F, inputs: Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]) Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol][source]#
- 参数
- F – 在 MXNet 中可以引用 Symbol API 或 NDArray API 的模块。 
- inputs – 输入张量,形状为 (batch_size, num_timesteps, num_dimensions) 
 
- 返回值
- RNN 输出,形状为 (batch_size, num_timesteps, num_dimensions) 
- 返回值类型
- 张量