gluonts.torch.distributions.truncated_normal 模块#
- class gluonts.torch.distributions.truncated_normal.TruncatedNormal(loc: torch.Tensor, scale: torch.Tensor, min: Union[torch.Tensor, float] = - 1.0, max: Union[torch.Tensor, float] = 1.0, upscale: Union[torch.Tensor, float] = 5.0, tanh_loc: bool = False)[source]#
- 基类: - torch.distributions.distribution.Distribution- 实现了一个带有位置缩放的截断正态分布。 - 位置缩放可以防止位置参数“离0太远”,因为这最终会导致数值不稳定的样本和糟糕的梯度计算(例如梯度爆炸)。实际上,位置参数的计算方式如下: \[loc = tanh(loc / upscale) * upscale.\]- 可以通过关闭 tanh_loc 参数来禁用此行为(见下文)。 - 参数
- loc – 正态分布的位置参数 
- scale – 正态分布的 sigma 参数(方差的平方根) 
- min – 分布的最小值。默认值 = -1.0 
- max – 分布的最大值。默认值 = 1.0 
- upscale – 缩放因子。默认值 = 5.0 
- tanh_loc – 如果为 - True,则使用上述公式进行位置缩放,否则保留原始值。默认值为- False
 
 - 参考资料 - 说明 - 此实现主要基于以下资料:
 - arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=1e-06)}#
 - entropy#
- 返回分布的熵,按 batch_shape 批量计算。 - 返回
- 形状为 batch_shape 的 Tensor。 
 
 - eps = 1e-06#
 - has_rsample = True#
 - mean#
- 返回分布的均值。 
 - rsample(sample_shape=None)[source]#
- 生成形状为 sample_shape 的可重参数化样本,如果分布参数是批量的,则生成形状为 sample_shape 的可重参数化样本批次。 
 - support#
- 返回一个表示此分布支持范围的 - Constraint对象。
 - variance#
- 返回分布的方差。 
 
- class gluonts.torch.distributions.truncated_normal.TruncatedNormalOutput(min: float = - 1.0, max: float = 1.0, upscale: float = 5.0, tanh_loc: bool = False)[source]#
- 基类: - gluonts.torch.distributions.distribution_output.DistributionOutput- distribution(distr_args, loc: Optional[torch.Tensor] = None, scale: Optional[torch.Tensor] = None) torch.distributions.distribution.Distribution[source]#
- 根据构造函数的参数集合以及可选的 scale tensor,构造相关的分布。 - 参数
- distr_args – 底层 Distribution 类型的构造函数参数。 
- loc – 可选的 tensor,其形状与生成的分布的 batch_shape+event_shape 相同。 
- scale – 可选的 tensor,其形状与生成的分布的 batch_shape+event_shape 相同。 
 
 
 - classmethod domain_map(loc: torch.Tensor, scale: torch.Tensor)[source]#
- 将参数转换为正确的形状和域。 - 域取决于分布类型,而正确的形状是通过重塑尾部轴来获得的,以使返回的 tensors 定义具有正确 event_shape 的分布。 
 - event_shape: Tuple#
- 与输出对象兼容的每个单独事件的形状。