gluonts.torch.model.predictor 模块#
- class gluonts.torch.model.predictor.PyTorchPredictor(input_names: List[str], prediction_net: torch.nn.modules.module.Module, batch_size: int, prediction_length: int, input_transform: gluonts.transform._base.Transformation, forecast_generator: gluonts.model.forecast_generator.ForecastGenerator = gluonts.model.forecast_generator.SampleForecastGenerator(), output_transform: Optional[Callable[[Dict[str, Any], numpy.ndarray], numpy.ndarray]] = None, lead_time: int = 0, device: Union[str, torch.device] = 'auto')[source]#
基类:
gluonts.model.predictor.RepresentablePredictor
- classmethod deserialize(path: pathlib.Path, device: Optional[Union[torch.device, str]] = None) gluonts.torch.model.predictor.PyTorchPredictor [source]#
从给定路径加载序列化的预测器。
- 参数
path – 序列化预测器文件的路径。
**kwargs – 可选的上下文/设备参数,用于预测器。如果未传递任何内容,将优先使用 GPU,否则使用 CPU。
- property network: torch.nn.modules.module.Module#
- predict(dataset: gluonts.dataset.Dataset, num_samples: Optional[int] = None) Iterator[gluonts.model.forecast.Forecast] [source]#
计算给定数据集中时间序列的预测结果。此方法未在此抽象类中实现;请使用其子类之一。 :param dataset: 包含要预测的时间序列的数据集。
- 返回
预测结果的迭代器,顺序与提供的数据集迭代器相同。
- 返回类型
Iterator[Forecast]
- to(device: Union[str, torch.device]) gluonts.torch.model.predictor.PyTorchPredictor [source]#