gluonts.torch.model.predictor 模块#

class gluonts.torch.model.predictor.PyTorchPredictor(input_names: List[str], prediction_net: torch.nn.modules.module.Module, batch_size: int, prediction_length: int, input_transform: gluonts.transform._base.Transformation, forecast_generator: gluonts.model.forecast_generator.ForecastGenerator = gluonts.model.forecast_generator.SampleForecastGenerator(), output_transform: Optional[Callable[[Dict[str, Any], numpy.ndarray], numpy.ndarray]] = None, lead_time: int = 0, device: Union[str, torch.device] = 'auto')[source]#

基类: gluonts.model.predictor.RepresentablePredictor

classmethod deserialize(path: pathlib.Path, device: Optional[Union[torch.device, str]] = None) gluonts.torch.model.predictor.PyTorchPredictor[source]#

从给定路径加载序列化的预测器。

参数
  • path – 序列化预测器文件的路径。

  • **kwargs – 可选的上下文/设备参数,用于预测器。如果未传递任何内容,将优先使用 GPU,否则使用 CPU。

property network: torch.nn.modules.module.Module#
predict(dataset: gluonts.dataset.Dataset, num_samples: Optional[int] = None) Iterator[gluonts.model.forecast.Forecast][source]#

计算给定数据集中时间序列的预测结果。此方法未在此抽象类中实现;请使用其子类之一。 :param dataset: 包含要预测的时间序列的数据集。

返回

预测结果的迭代器,顺序与提供的数据集迭代器相同。

返回类型

Iterator[Forecast]

serialize(path: pathlib.Path) None[source]#
to(device: Union[str, torch.device]) gluonts.torch.model.predictor.PyTorchPredictor[source]#