gluonts.ext.rotbaum 包#
- class gluonts.ext.rotbaum.TreeEstimator(**kwargs)[source]#
Bases:
gluonts.ext.rotbaum._estimator.ThirdPartyEstimator
- lead_time: int#
- prediction_length: int#
- class gluonts.ext.rotbaum.TreePredictor(freq: str, prediction_length: int, n_ignore_last: int = 0, lead_time: int = 0, max_n_datapts: int = 1000000, min_bin_size: int = 100, context_length: Optional[int] = None, use_feat_static_real: bool = False, use_past_feat_dynamic_real: bool = False, use_feat_dynamic_real: bool = False, use_feat_dynamic_cat: bool = False, cardinality: Union[List[int], gluonts.ext.rotbaum._preprocess.CardinalityLabel] = 'auto', one_hot_encode: bool = False, model_params: Optional[dict] = None, max_workers: Optional[int] = None, method: str = 'QRX', quantiles=None, subtract_mean: bool = True, count_nans: bool = False, model=None, seed=None)[source]#
Bases:
gluonts.model.predictor.RepresentablePredictor
一个预测器,它为预测范围内的每个时间步使用一个 QRX 模型。
(换句话说,总共有 prediction_length 个模型被训练。特别地,这个预测器不学习多元分布。) 这些模型的列表保存在 self.model_list 下。
- classmethod deserialize(path: pathlib.Path, **kwargs: Any) gluonts.ext.rotbaum._predictor.TreePredictor [source]#
此函数加载并返回序列化模型。
它使用序列化参数加载预测器类。然后通过读取 pickle 文件加载训练好的模型列表。
- explain(importance_type: str = 'gain', percentage: bool = True) gluonts.ext.rotbaum._types.ExplanationResult [source]#
此函数仅适用于
self.method == "QuantileRegression"
,并使用 lightgbm 的特征重要性功能。它计算预测范围内的分位数和时间戳的平均特征重要性;然后将这些平均值加到与“target”、“feat_static_real”、“feat_static_cat”、“past_feat_dynamic_real”、“feat_dynamic_real”、“feat_dynamic_cat”相关的所有特征坐标上。- 参数
importance_type (str) – “gain” 或 “split”。由于预测预测范围内更远时间戳的模型表现可能较差,因此最好给予这些模型的权重低于预测预测范围更近时间戳的模型。“split” 会给予相等权重,因此不太理想;而 “gain” 会自然地给予表现较差的模型更低的权重。
percentage (bool) – 结果是否应为百分比格式且总和为 1。默认值为 True
- 返回类型
ExplanationResult
- predict(dataset: gluonts.dataset.Dataset, num_samples: Optional[int] = None) Iterator[gluonts.model.forecast.Forecast] [source]#
返回一个字典,将每个分位数映射到浮点数列表,这些浮点数是按 (时间步长, 时间序列) 字典顺序遍历时该分位数对应的预测值。
因此:首先它会给出所有时间序列第一个时间步长的分位数预测,然后是所有时间序列第二个时间步长的预测,依此类推。