gluonts.torch.distributions.discrete_distribution 模块#
- class gluonts.torch.distributions.discrete_distribution.DiscreteDistribution(values: torch.Tensor, probs: torch.Tensor, validate_args: Optional[bool] = None)[source]#
基类:
torch.distributions.distribution.Distribution
实现离散分布,其中底层随机变量从有限集合 values 中取值,并具有相应的概率。
注意:values 中可以有重复项,在这种情况下,重复项的概率质量将被累加。
一个自然的损失函数,特别是因为新的观测值不一定来自有限集合 values,是排序概率得分(RPS)。
- 因此,为了与其他模型的术语保持一致,log_prob 实现为负 RPS。
static adjust_probs(values_sorted, probs_sorted)[source]#
将所有重复值的概率质量放在一个位置(重复项的最后一个索引)。
- 假设:values_sorted 已排序!
参数
values_sorted –
- probs_sorted –
- 返回分布的均值。