gluonts.mx.block.dropout 模块#
- class gluonts.mx.block.dropout.RNNZoneoutCell(base_cell: mxnet.gluon.rnn.rnn_cell.RecurrentCell, zoneout_outputs: float = 0.0, zoneout_states: float = 0.0)[source]#
基类:
mxnet.gluon.rnn.rnn_cell.ModifierCell
在基础单元上应用 Zoneout。实现遵循 [KMK16]。
与 mx.gluon.rnn.ZoneoutCell 相比,此实现对输出和 states[0] 使用相同的掩码,因为对于 RNN 单元,states[0] 与输出相同,除了 ResidualCell,其中 states[0] = input + ouptput
- 参数
base_cell – 应用变分 Zoneout 的单元。
zoneout_outputs – 输出的 dropout 率。如果为 0 则不应用 dropout。
zoneout_states – 第一个状态通道上状态输入的 dropout 率。如果为 0 则不应用 dropout。
- hybrid_forward(F, inputs: Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol], states: Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]) Tuple[Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol], Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]] [source]#
覆盖此 Block 以构建符号图。
- 参数
x (Symbol 或 NDArray) – 第一个输入张量。
*args (Symbol 列表 或 NDArray 列表) – 附加输入张量。
- class gluonts.mx.block.dropout.VariationalZoneoutCell(base_cell: mxnet.gluon.rnn.rnn_cell.RecurrentCell, zoneout_outputs: float = 0.0, zoneout_states: float = 0.0)[source]#
基类:
mxnet.gluon.rnn.rnn_cell.ModifierCell
在基础单元上应用变分 Zoneout。实现遵循。
[GG16]。变分 Zoneout 在不同时间步使用相同的掩码。它可以应用于 RNN 输出和状态。它们的掩码不共享。
掩码在第一次向前步进时初始化,并且在调用 .reset() 之前将保持不变。因此,如果手动使用单元并进行步进而不调用 .unroll(),则应在每个序列之后调用 .reset()。
- 参数
base_cell – 应用变分 Zoneout 的单元。
zoneout_outputs – 输出的 dropout 率。如果为 0 则不应用 dropout。
zoneout_states – 第一个状态通道上状态输入的 dropout 率。如果为 0 则不应用 dropout。
- hybrid_forward(F, inputs: Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol], states: Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]) Tuple[Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol], Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]] [source]#
覆盖此 Block 以构建符号图。
- 参数
x (Symbol 或 NDArray) – 第一个输入张量。
*args (Symbol 列表 或 NDArray 列表) – 附加输入张量。