gluonts.mx.distribution.iresnet 模块#

class gluonts.mx.distribution.iresnet.InvertibleResnetHybridBlock(event_shape, hidden_units: int = 16, num_hidden_layers: int = 1, num_inv_iters: int = 10, ignore_logdet: bool = False, activation: str = 'lipswish', num_power_iter: int = 1, flatten: bool = False, coeff: float = 0.9, use_caching: bool = True, *args, **kwargs)[源代码]#

基于 [BJC19],除了 f 和 f_inv 的顺序交换了。

property event_dim: int#

property event_shape#
f(x: Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]) Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol][源代码]#
iResnet 的正向变换。

参数

x – 观测值

返回值

变换后的张量 ` ext{iResnet}(x)`

返回类型

张量

f_inv(y: Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]) Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol][源代码]#

iResnet 的逆变换

y – 输入张量

x – 观测值

变换后的张量 ` ext{iResnet}^{-1}(y)`

变换后的张量 ` ext{iResnet}(x)`

log_abs_det_jac(x: Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol], y: Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]) Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol][源代码]#

张量

f_inv(y: Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]) Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol][源代码]#

与 iResnet 变换对应的雅可比行列式绝对值的对数。

x – 正向变换的输入或逆向变换的输出

x – 观测值
  • y – 正向变换的输出或逆向变换的输入

  • 当 x 作为输入或 y 作为输出时计算的雅可比

变换后的张量 ` ext{iResnet}(x)`

gluonts.mx.distribution.iresnet.iresnet(num_blocks: int, **block_kwargs) gluonts.mx.distribution.bijection.ComposedBijectionHybridBlock[源代码]#

张量

f_inv(y: Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]) Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol][源代码]#

num_blocks – iResnet 块的数量
x – 观测值
  • block_kwargs – 初始化每个块对象时传入的关键字参数

  • gluonts.mx.distribution.iresnet.log_abs_det(A: Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]) Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol][源代码]#

矩阵 A 的绝对值的对数 :param A: 用于计算其对数绝对值的张量矩阵

行列式

张量

f_inv(y: Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]) Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol][源代码]#