gluonts.mx.distribution.distribution_output 模块#

class gluonts.mx.distribution.distribution_output.ArgProj(args_dim: typing.Dict[str, int], domain_map: typing.Callable[[...], typing.Tuple[typing.Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]]], dtype: typing.Type = <class 'numpy.float32'>, prefix: typing.Optional[str] = None, **kwargs)[source]#

继承自: mxnet.gluon.block.HybridBlock

一个可用于将密集层投影到分布参数的块。

参数
  • dim_args – 一个字典,其键为字符串,值为整数,表示将传递给域映射的每个参数的维度,名称用作参数前缀。

  • domain_map – 一个函数,返回包含一个张量、一个函数或一个 HybridBlock 的元组。这将使用 num_args 参数调用,并且应该返回一个输出元组,该元组将在调用分布构造函数时使用。

hybrid_forward(F, x: Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol], **kwargs) Tuple[Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]][source]#

覆盖以构建此 Block 的符号图。

参数
  • x (SymbolNDArray) – 第一个输入张量。

  • *args (Symbol 列表NDArray 列表) – 附加输入张量。

class gluonts.mx.distribution.distribution_output.DistributionOutput[source]#

继承自: gluonts.mx.distribution.distribution_output.Output

根据网络输出构建分布的类。

args_dim: Dict[str, int]#
distr_cls: type#
distribution(distr_args, loc: Optional[Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]] = None, scale: Optional[Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]] = None) gluonts.mx.distribution.distribution.Distribution[source]#

根据构造函数参数集合以及可选的尺度张量,构造关联的分布。

参数
  • distr_args – 底层 Distribution 类型的构造函数参数。

  • loc – 可选张量,形状与结果分布的 batch_shape+event_shape 相同。

  • scale – 可选张量,形状与结果分布的 batch_shape+event_shape 相同。

domain_map(F, *args: Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol])[source]#

将参数转换为正确的形状和域。

域取决于分布类型,而正确的形状是通过重塑尾部轴获得的,以便返回的张量定义具有正确 event_shape 的分布。

property event_dim: int#

事件维度数量,即此对象构建的分布的 event_shape 元组的长度。

property event_shape: Tuple#

此对象构建的分布所考虑的每个独立事件的形状。

property value_in_support: float#

一个浮点数,在计算相应分布的对数损失时具有有效的数值;默认为 0.0。

在填充数据序列时将使用此值。

class gluonts.mx.distribution.distribution_output.Output[source]#

继承自: object

连接网络到某些输出的类。

args_dim: Dict[str, int]#
domain_map(F, *args: Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol])[source]#
property dtype#
classmethod eps()[source]#
get_args_proj(prefix: Optional[str] = None) mxnet.gluon.block.HybridBlock[source]#