gluonts.mx.trainer.model_averaging module#

class gluonts.mx.trainer.model_averaging.AveragingStrategy(num_models: int = 5, metric: str = 'score', maximize: bool = False)[source]#

基类: object

apply(model_path: str) str[source]#

根据选定的模型策略和度量标准,对序列化模型的参数进行平均。重要提示:根据度量标准,用户可能希望最小化或最大化。必须适当选择 maximize 标志来反映这一点。

参数

model_path – 模型目录的路径。

返回类型

包含平均模型的文件的路径。

average(param_paths: List[str], weights: List[float]) Dict[source]#

对 .params 文件路径列表中的参数进行平均。

参数
  • param_paths – 参数文件路径列表。

  • weights – 参数平均的权重列表。

返回类型

平均参数字典。

static average_arrays(arrays: List[mxnet.ndarray.ndarray.NDArray], weights: List[float]) mxnet.ndarray.ndarray.NDArray[source]#

接收一个形状相同的数组列表,并计算按元素的加权平均值。

参数
  • arrays – 将进行平均的、形状相同的 NDArrays 列表。

  • weights – 参数平均的权重列表。

返回类型

与 arrays[0] 位于相同上下文中的 NDArrays 平均值。

static get_checkpoint_information(model_path: str) List[Dict][source]#
参数

model_path – 模型目录的路径。

返回

  • 检查点信息字典列表(度量标准, epoch_no,

  • 检查点路径)。

select_checkpoints(checkpoints: List[Dict]) Tuple[List[str], List[float]][source]#

选择检查点并计算所选检查点的权重。

参数

checkpoints – 检查点信息字典列表。

返回

  • 所选检查点路径列表以及对应的

  • 权重列表。

class gluonts.mx.trainer.model_averaging.ModelAveraging(avg_strategy: gluonts.mx.trainer.model_averaging.AveragingStrategy)[source]#

基类: gluonts.mx.trainer.callback.Callback

实现模型平均策略的回调函数。根据所选的 avg_strategy,选择损失值最佳的检查点,并计算模型平均或加权模型平均。

参数

avg_strategy – AveragingStrategy,来自 gluonts.mx.trainer.model_averaging 的 SelectNBestSoftmax 或 SelectNBestMean 之一。

on_train_end(training_network: mxnet.gluon.block.HybridBlock, temporary_dir: str, ctx: Optional[mxnet.context.Context] = None) None[source]#

训练结束时调用的钩子。这是最后一个被调用的钩子。

参数
  • training_network – 已训练的网络。

  • temporary_dir – 训练过程中记录模型参数的目录。

  • ctx – 使用的 MXNet 上下文。

class gluonts.mx.trainer.model_averaging.SelectNBestMean(num_models: int = 5, metric: str = 'score', maximize: bool = False)[source]#

基类: gluonts.mx.trainer.model_averaging.AveragingStrategy

select_checkpoints(checkpoints: List[Dict]) Tuple[List[str], List[float]][source]#

选择度量值最佳的检查点。所有检查点的权重相等,即 w_i = 1/N。

参数

checkpoints – 检查点信息字典列表。

返回

  • 所选检查点路径列表以及对应的

  • 权重列表。

class gluonts.mx.trainer.model_averaging.SelectNBestSoftmax(num_models: int = 5, metric: str = 'score', maximize: bool = False)[source]#

基类: gluonts.mx.trainer.model_averaging.AveragingStrategy

select_checkpoints(checkpoints: List[Dict]) Tuple[List[str], List[float]][source]#

选择度量值最佳的检查点。权重是度量值的 softmax,即 maximize=True 时 w_i = exp(v_i) / sum(exp(v_j));maximize=False 时 w_i = exp(-v_i) / sum(exp(-v_j))。

参数

checkpoints – 检查点信息字典列表。

返回

  • 所选检查点路径列表以及对应的

  • 权重列表。

gluonts.mx.trainer.model_averaging.save_epoch_info(tmp_path: str, epoch_info: dict) None[source]#

将当前 epoch 信息写入模型路径中的 json 文件。

参数
  • tmp_path – 保存 epoch 信息的临时基础路径。

  • epoch_info – 包含参数路径、epoch 编号和跟踪度量值的 epoch 信息字典。

返回类型