gluonts.torch.distributions.truncated_normal 模块#
- class gluonts.torch.distributions.truncated_normal.TruncatedNormal(loc: torch.Tensor, scale: torch.Tensor, min: Union[torch.Tensor, float] = - 1.0, max: Union[torch.Tensor, float] = 1.0, upscale: Union[torch.Tensor, float] = 5.0, tanh_loc: bool = False)[source]#
基类:
torch.distributions.distribution.Distribution
实现了一个带有位置缩放的截断正态分布。
位置缩放可以防止位置参数“离0太远”,因为这最终会导致数值不稳定的样本和糟糕的梯度计算(例如梯度爆炸)。实际上,位置参数的计算方式如下:
\[loc = tanh(loc / upscale) * upscale.\]可以通过关闭 tanh_loc 参数来禁用此行为(见下文)。
- 参数
loc – 正态分布的位置参数
scale – 正态分布的 sigma 参数(方差的平方根)
min – 分布的最小值。默认值 = -1.0
max – 分布的最大值。默认值 = 1.0
upscale – 缩放因子。默认值 = 5.0
tanh_loc – 如果为
True
,则使用上述公式进行位置缩放,否则保留原始值。默认值为False
参考资料
说明
- 此实现主要基于以下资料:
- arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=1e-06)}#
- entropy#
返回分布的熵,按 batch_shape 批量计算。
- 返回
形状为 batch_shape 的 Tensor。
- eps = 1e-06#
- has_rsample = True#
- mean#
返回分布的均值。
- rsample(sample_shape=None)[source]#
生成形状为 sample_shape 的可重参数化样本,如果分布参数是批量的,则生成形状为 sample_shape 的可重参数化样本批次。
- support#
返回一个表示此分布支持范围的
Constraint
对象。
- variance#
返回分布的方差。
- class gluonts.torch.distributions.truncated_normal.TruncatedNormalOutput(min: float = - 1.0, max: float = 1.0, upscale: float = 5.0, tanh_loc: bool = False)[source]#
基类:
gluonts.torch.distributions.distribution_output.DistributionOutput
- distribution(distr_args, loc: Optional[torch.Tensor] = None, scale: Optional[torch.Tensor] = None) torch.distributions.distribution.Distribution [source]#
根据构造函数的参数集合以及可选的 scale tensor,构造相关的分布。
- 参数
distr_args – 底层 Distribution 类型的构造函数参数。
loc – 可选的 tensor,其形状与生成的分布的 batch_shape+event_shape 相同。
scale – 可选的 tensor,其形状与生成的分布的 batch_shape+event_shape 相同。
- classmethod domain_map(loc: torch.Tensor, scale: torch.Tensor)[source]#
将参数转换为正确的形状和域。
域取决于分布类型,而正确的形状是通过重塑尾部轴来获得的,以使返回的 tensors 定义具有正确 event_shape 的分布。
- event_shape: Tuple#
与输出对象兼容的每个单独事件的形状。