gluonts.ext.rotbaum 包#

class gluonts.ext.rotbaum.TreeEstimator(**kwargs)[source]#

Bases: gluonts.ext.rotbaum._estimator.ThirdPartyEstimator

lead_time: int#
prediction_length: int#
class gluonts.ext.rotbaum.TreePredictor(freq: str, prediction_length: int, n_ignore_last: int = 0, lead_time: int = 0, max_n_datapts: int = 1000000, min_bin_size: int = 100, context_length: Optional[int] = None, use_feat_static_real: bool = False, use_past_feat_dynamic_real: bool = False, use_feat_dynamic_real: bool = False, use_feat_dynamic_cat: bool = False, cardinality: Union[List[int], gluonts.ext.rotbaum._preprocess.CardinalityLabel] = 'auto', one_hot_encode: bool = False, model_params: Optional[dict] = None, max_workers: Optional[int] = None, method: str = 'QRX', quantiles=None, subtract_mean: bool = True, count_nans: bool = False, model=None, seed=None)[source]#

Bases: gluonts.model.predictor.RepresentablePredictor

一个预测器,它为预测范围内的每个时间步使用一个 QRX 模型。

(换句话说,总共有 prediction_length 个模型被训练。特别地,这个预测器不学习多元分布。) 这些模型的列表保存在 self.model_list 下。

classmethod deserialize(path: pathlib.Path, **kwargs: Any) gluonts.ext.rotbaum._predictor.TreePredictor[source]#

此函数加载并返回序列化模型。

它使用序列化参数加载预测器类。然后通过读取 pickle 文件加载训练好的模型列表。

explain(importance_type: str = 'gain', percentage: bool = True) gluonts.ext.rotbaum._types.ExplanationResult[source]#

此函数仅适用于 self.method == "QuantileRegression",并使用 lightgbm 的特征重要性功能。它计算预测范围内的分位数和时间戳的平均特征重要性;然后将这些平均值加到与“target”、“feat_static_real”、“feat_static_cat”、“past_feat_dynamic_real”、“feat_dynamic_real”、“feat_dynamic_cat”相关的所有特征坐标上。

参数
  • importance_type (str) – “gain” 或 “split”。由于预测预测范围内更远时间戳的模型表现可能较差,因此最好给予这些模型的权重低于预测预测范围更近时间戳的模型。“split” 会给予相等权重,因此不太理想;而 “gain” 会自然地给予表现较差的模型更低的权重。

  • percentage (bool) – 结果是否应为百分比格式且总和为 1。默认值为 True

返回类型

ExplanationResult

predict(dataset: gluonts.dataset.Dataset, num_samples: Optional[int] = None) Iterator[gluonts.model.forecast.Forecast][source]#

返回一个字典,将每个分位数映射到浮点数列表,这些浮点数是按 (时间步长, 时间序列) 字典顺序遍历时该分位数对应的预测值。

因此:首先它会给出所有时间序列第一个时间步长的分位数预测,然后是所有时间序列第二个时间步长的预测,依此类推。

serialize(path: pathlib.Path) None[source]#

此函数调用父类的 serialize() 方法以序列化类名、版本信息和构造函数参数。

它通过序列化 TreePredictor 生成的模型列表来持久化树预测器。

train(training_data, train_QRX_only_using_timestep: int = - 1)[source]#