diff --git a/paddlex/inference/models/doc_vlm/predictor.py b/paddlex/inference/models/doc_vlm/predictor.py index c7fb3a0c77..b53b233439 100644 --- a/paddlex/inference/models/doc_vlm/predictor.py +++ b/paddlex/inference/models/doc_vlm/predictor.py @@ -18,7 +18,7 @@ from typing import List from ....modules.doc_vlm.model_list import MODELS -from ....utils.device import TemporaryDeviceChanger +from ....utils.device import TemporaryDeviceChanger, constr_device from ....utils.env import get_device_type from ...common.batch_sampler import DocVLMBatchSampler from ..base import BasePredictor @@ -44,6 +44,14 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.device = kwargs.get("device", None) + if self.device is None and self.pp_option is not None: + if self.pp_option.device_type is not None: + device_ids = ( + None + if self.pp_option.device_id is None + else [self.pp_option.device_id] + ) + self.device = constr_device(self.pp_option.device_type, device_ids) self.dtype = ( "bfloat16" if ("npu" in get_device_type() or paddle.amp.is_bfloat16_supported())