gluonts.torch.model.predictor 模块#
- class gluonts.torch.model.predictor.PyTorchPredictor(input_names: List[str], prediction_net: torch.nn.modules.module.Module, batch_size: int, prediction_length: int, input_transform: gluonts.transform._base.Transformation, forecast_generator: gluonts.model.forecast_generator.ForecastGenerator = gluonts.model.forecast_generator.SampleForecastGenerator(), output_transform: Optional[Callable[[Dict[str, Any], numpy.ndarray], numpy.ndarray]] = None, lead_time: int = 0, device: Union[str, torch.device] = 'auto')[source]#
- 基类: - gluonts.model.predictor.RepresentablePredictor- classmethod deserialize(path: pathlib.Path, device: Optional[Union[torch.device, str]] = None) gluonts.torch.model.predictor.PyTorchPredictor[source]#
- 从给定路径加载序列化的预测器。 - 参数
- path – 序列化预测器文件的路径。 
- **kwargs – 可选的上下文/设备参数,用于预测器。如果未传递任何内容,将优先使用 GPU,否则使用 CPU。 
 
 
 - property network: torch.nn.modules.module.Module#
 - predict(dataset: gluonts.dataset.Dataset, num_samples: Optional[int] = None) Iterator[gluonts.model.forecast.Forecast][source]#
- 计算给定数据集中时间序列的预测结果。此方法未在此抽象类中实现;请使用其子类之一。 :param dataset: 包含要预测的时间序列的数据集。 - 返回
- 预测结果的迭代器,顺序与提供的数据集迭代器相同。 
- 返回类型
- Iterator[Forecast] 
 
 - to(device: Union[str, torch.device]) gluonts.torch.model.predictor.PyTorchPredictor[source]#