
Trainer callbacks#

This notebook illustrates how one can control the training procedure of MXNet-based models by providing callbacks to the Trainer class. A callback is a function which gets called at one or more specific hook points during training. You can use predefined GluonTS callbacks such as TrainingHistory, ModelAveraging, or TerminateOnNaN, or you can implement your own callback.
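
For instance, the predefined TerminateOnNaN callback aborts training as soon as the loss becomes NaN. A minimal sketch, assuming its default argument-free constructor:

from gluonts.mx import Trainer
from gluonts.mx.trainer.callback import TerminateOnNaN

# Stop training as soon as a NaN loss value is encountered.
trainer = Trainer(epochs=10, callbacks=[TerminateOnNaN()])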

[1]:
from gluonts.dataset.repository import get_dataset

dataset = get_dataset("m4_hourly")
prediction_length = dataset.metadata.prediction_length
freq = dataset.metadata.freq

Using a single callback#

To use callbacks, simply pass them as a list when constructing the Trainer. In the following example, we use the TrainingHistory callback to record the loss values measured during training.

[2]:
from gluonts.mx import SimpleFeedForwardEstimator, Trainer
from gluonts.mx.trainer.callback import TrainingHistory

# defining a callback, which will log the training loss for each epoch
history = TrainingHistory()

trainer = Trainer(epochs=3, callbacks=[history])
estimator = SimpleFeedForwardEstimator(
    prediction_length=prediction_length, trainer=trainer
)

predictor = estimator.train(dataset.train, num_workers=None)
100%|██████████| 50/50 [00:00<00:00, 126.64it/s, epoch=1/3, avg_epoch_loss=5.55]
100%|██████████| 50/50 [00:00<00:00, 140.81it/s, epoch=2/3, avg_epoch_loss=4.7]
100%|██████████| 50/50 [00:00<00:00, 137.59it/s, epoch=3/3, avg_epoch_loss=4.54]

Print the training loss for each epoch:

[3]:
print(history.loss_history)
[5.546479229927063, 4.702160387039185, 4.540805015563965]
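
The recorded history can also be visualized. A minimal sketch using matplotlib (not imported above):

import matplotlib.pyplot as plt

# Plot the average training loss recorded by the TrainingHistory callback.
plt.plot(history.loss_history, marker="o")
plt.xlabel("epoch")
plt.ylabel("average training loss")
plt.show()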

Using multiple callbacks#

To continue training from a given predictor, you can use the WarmStart callback. When you want to use more than one callback, simply provide a list containing multiple callback objects:

[4]:
from gluonts.mx.trainer.callback import WarmStart

warm_start = WarmStart(predictor=predictor)

trainer = Trainer(epochs=3, callbacks=[history, warm_start])

estimator = SimpleFeedForwardEstimator(
    prediction_length=prediction_length, trainer=trainer
)

predictor = estimator.train(dataset.train, num_workers=None)
100%|██████████| 50/50 [00:00<00:00, 133.93it/s, epoch=1/3, avg_epoch_loss=4.44]
100%|██████████| 50/50 [00:00<00:00, 138.14it/s, epoch=2/3, avg_epoch_loss=4.4]
100%|██████████| 50/50 [00:00<00:00, 142.33it/s, epoch=3/3, avg_epoch_loss=4.43]
[5]:
print(
    history.loss_history
)  # The training loss history of all 3+3 epochs we trained the model for
[5.546479229927063, 4.702160387039185, 4.540805015563965, 4.439644269943237, 4.402952268123626, 4.425053224563599]

Default callbacks#

In addition to the callbacks you specify, the Trainer class uses the two default callbacks ModelAveraging and LearningRateReduction. You can turn them off by setting add_default_callbacks=False when initializing the Trainer.

[6]:
trainer = Trainer(
    epochs=20, callbacks=[history]
)  # use the TrainingHistory Callback and the default callbacks.
trainer = Trainer(
    epochs=20, callbacks=[history], add_default_callbacks=False
)  # use only the TrainingHistory Callback
trainer = Trainer(epochs=20, add_default_callbacks=False)  # use no callback at all
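
If you disable the default callbacks but still want one of them, you can add it back explicitly. The import paths and constructor arguments below are assumptions that may differ across GluonTS versions, so treat this as a sketch rather than the canonical API:

# Sketch: re-adding the default callbacks by hand. Constructor arguments
# here are assumptions; check the signatures in your GluonTS version.
from gluonts.mx.trainer.learning_rate_scheduler import LearningRateReduction
from gluonts.mx.trainer.model_averaging import ModelAveraging, SelectNBestMean

trainer = Trainer(
    epochs=20,
    callbacks=[
        history,
        LearningRateReduction(objective="min", patience=10),
        ModelAveraging(avg_strategy=SelectNBestMean(num_models=1)),
    ],
    add_default_callbacks=False,
)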

Custom callbacks#

To implement your own callback, write a class that inherits from gluonts.mx.trainer.Callback and override one or more of the hooks. Have a look at the abstract Callback class: the hooks take different arguments which you can use. Hook methods that return a boolean stop the training if they return False.
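
As a minimal sketch of this mechanism, the following hypothetical callback stops training after a fixed number of epochs by returning False from on_epoch_end (the hook signature follows the full example below; treating epoch_no as zero-based is an assumption):

from gluonts.mx.trainer.callback import Callback


class EpochCap(Callback):
    """Hypothetical callback: stop training after `max_epochs` epochs."""

    def __init__(self, max_epochs: int):
        self.max_epochs = max_epochs

    def on_epoch_end(
        self, epoch_no, epoch_loss, training_network, trainer, best_epoch_info, ctx
    ) -> bool:
        # Returning False from a boolean hook stops the training loop.
        return epoch_no + 1 < self.max_epochs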

Here is an example of a custom callback implementation that terminates training early based on the value of some metric (e.g. RMSE). It only implements the hook method on_epoch_end, which gets called after all batches of one epoch have been processed.

[7]:
from typing import List

import numpy as np
import mxnet as mx

from gluonts.evaluation import Evaluator
from gluonts.dataset.common import Dataset
from gluonts.mx import copy_parameters, GluonPredictor
from gluonts.mx.trainer.callback import Callback


class MetricInferenceEarlyStopping(Callback):
    """
    Early stopping mechanism based on the prediction network.
    Can be used to base early stopping directly on a metric of interest,
    instead of on the training/validation loss.
    In the same way as test datasets are used during model evaluation,
    the time series of the validation_dataset can overlap with the train
    dataset time series, except for a prediction_length part at the end
    of each time series.

    Parameters
    ----------
    validation_dataset
        An out-of-sample dataset which is used to monitor metrics
    predictor
        A gluon predictor, with a prediction network that matches the training network
    evaluator
        The Evaluator used to calculate the validation metrics.
    metric
        The metric on which to base the early stopping.
    patience
        Number of epochs to continue training after the metric has stopped
        improving by more than min_delta.
    min_delta
        Minimum change in the monitored metric that counts as an improvement.
    verbose
        Controls whether the validation metric is printed after each epoch.
    minimize_metric
        The metric objective.
    restore_best_network
        Controls whether the best model, as assessed by the validation
        metric, is restored after training.
    num_samples
        The number of samples drawn to calculate the inference metrics.
    """

    def __init__(
        self,
        validation_dataset: Dataset,
        predictor: GluonPredictor,
        evaluator: Evaluator = Evaluator(num_workers=None),
        metric: str = "MSE",
        patience: int = 10,
        min_delta: float = 0.0,
        verbose: bool = True,
        minimize_metric: bool = True,
        restore_best_network: bool = True,
        num_samples: int = 100,
    ):
        assert patience >= 0, "EarlyStopping Callback patience needs to be >= 0"
        assert min_delta >= 0, "EarlyStopping Callback min_delta needs to be >= 0.0"
        assert num_samples >= 1, "EarlyStopping Callback num_samples needs to be >= 1"

        self.validation_dataset = list(validation_dataset)
        self.predictor = predictor
        self.evaluator = evaluator
        self.metric = metric
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.restore_best_network = restore_best_network
        self.num_samples = num_samples

        if minimize_metric:
            self.best_metric_value = np.inf
            self.is_better = np.less
        else:
            self.best_metric_value = -np.inf
            self.is_better = np.greater

        self.validation_metric_history: List[float] = []
        self.best_network = None
        self.n_stale_epochs = 0

    def on_epoch_end(
        self,
        epoch_no: int,
        epoch_loss: float,
        training_network: mx.gluon.nn.HybridBlock,
        trainer: mx.gluon.Trainer,
        best_epoch_info: dict,
        ctx: mx.Context,
    ) -> bool:
        should_continue = True
        copy_parameters(training_network, self.predictor.prediction_net)

        from gluonts.evaluation.backtest import make_evaluation_predictions

        forecast_it, ts_it = make_evaluation_predictions(
            dataset=self.validation_dataset,
            predictor=self.predictor,
            num_samples=self.num_samples,
        )

        agg_metrics, item_metrics = self.evaluator(ts_it, forecast_it)
        current_metric_value = agg_metrics[self.metric]
        self.validation_metric_history.append(current_metric_value)

        if self.verbose:
            print(
                f"Validation metric {self.metric}: {current_metric_value}, best: {self.best_metric_value}"
            )

        if self.is_better(current_metric_value, self.best_metric_value):
            self.best_metric_value = current_metric_value

            if self.restore_best_network:
                training_network.save_parameters("best_network.params")

            self.n_stale_epochs = 0
        else:
            self.n_stale_epochs += 1
            if self.n_stale_epochs == self.patience:
                should_continue = False
                print(
                    f"EarlyStopping callback initiated stop of training at epoch {epoch_no}."
                )

                if self.restore_best_network:
                    print(
                        f"Restoring best network from epoch {epoch_no - self.patience}."
                    )
                    training_network.load_parameters("best_network.params")

        return should_continue

We can now use the custom callback as follows. Note that we are running only a small number of epochs to keep the runtime of the notebook manageable: feel free to increase the number of epochs to properly test the effectiveness of the callback.

[8]:
estimator = SimpleFeedForwardEstimator(prediction_length=prediction_length)
training_network = estimator.create_training_network()
transformation = estimator.create_transformation()

predictor = estimator.create_predictor(
    transformation=transformation, trained_network=training_network
)

es_callback = MetricInferenceEarlyStopping(
    validation_dataset=dataset.test, predictor=predictor, metric="MSE"
)

trainer = Trainer(epochs=5, callbacks=[es_callback])

estimator.trainer = trainer

pred = estimator.train(dataset.train)
100%|██████████| 50/50 [00:00<00:00, 136.43it/s, epoch=1/5, avg_epoch_loss=5.55]
Running evaluation: 414it [00:02, 153.83it/s]
Validation metric MSE: 16590203.479222953, best: inf
100%|██████████| 50/50 [00:00<00:00, 137.55it/s, epoch=2/5, avg_epoch_loss=4.69]
Running evaluation: 414it [00:02, 156.87it/s]
Validation metric MSE: 9028248.932885194, best: 16590203.479222953
100%|██████████| 50/50 [00:00<00:00, 139.02it/s, epoch=3/5, avg_epoch_loss=4.79]
Running evaluation: 414it [00:02, 157.63it/s]
Validation metric MSE: 16308248.984650122, best: 9028248.932885194
100%|██████████| 50/50 [00:00<00:00, 134.38it/s, epoch=4/5, avg_epoch_loss=4.62]
Running evaluation: 414it [00:02, 157.42it/s]
Validation metric MSE: 10582128.785360953, best: 9028248.932885194
100%|██████████| 50/50 [00:00<00:00, 138.63it/s, epoch=5/5, avg_epoch_loss=4.3]
Running evaluation: 414it [00:02, 157.09it/s]
Validation metric MSE: 10019828.282515068, best: 9028248.932885194
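
Since the callback stores every measured value in its validation_metric_history attribute, you can inspect the monitored metric after training:

print(es_callback.validation_metric_history)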