gluonts.torch.distributions.discrete_distribution 模块#

class gluonts.torch.distributions.discrete_distribution.DiscreteDistribution(values: torch.Tensor, probs: torch.Tensor, validate_args: Optional[bool] = None)[source]#

基类: torch.distributions.distribution.Distribution

实现离散分布,其中底层随机变量从有限集合 values 中取值,并具有相应的概率。

注意:values 中可以有重复项,在这种情况下,重复项的概率质量将被累加。

一个自然的损失函数,特别是因为新的观测值不一定来自有限集合 values,是排序概率得分(RPS)。

因此,为了与其他模型的术语保持一致,log_prob 实现为负 RPS。

static adjust_probs(values_sorted, probs_sorted)[source]#

将所有重复值的概率质量放在一个位置(重复项的最后一个索引)。

假设:values_sorted 已排序!
  • 参数

  • values_sorted

probs_sorted

返回

log_prob(obs: torch.Tensor)[source]#

假设:values_sorted 已排序!

返回在 obs 处评估的概率密度/质量函数的对数。

obs (Tensor) –

mean()[source]#

返回分布的均值。
quantile_losses(obs: torch.Tensor, quantiles: torch.Tensor, levels: torch.Tensor)[source]#

rps(obs: torch.Tensor, check_for_duplicates: bool = True)[source]#

实现排序概率得分,它是所有可能分位数的分位数损失之和。

假设:values_sorted 已排序!
  • 在这里,分位数数量是有限的,等于 (obs 的每个批次元素的) 唯一值数量。

  • obs

check_for_duplicates

sample(sample_shape=torch.Size([]))[source]#