Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion paddlex/inference/models/doc_vlm/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from typing import List

from ....modules.doc_vlm.model_list import MODELS
from ....utils.device import TemporaryDeviceChanger
from ....utils.device import TemporaryDeviceChanger, constr_device
from ....utils.env import get_device_type
from ...common.batch_sampler import DocVLMBatchSampler
from ..base import BasePredictor
Expand All @@ -44,6 +44,9 @@ def __init__(self, *args, **kwargs):

super().__init__(*args, **kwargs)
self.device = kwargs.get("device", None)
if self.device is None and self.pp_option is not None:
if self.pp_option.device_type is not None and self.pp_option.device_type != "cpu":
self.device = constr_device(self.pp_option.device_type, str(self.pp_option.device_id))
self.dtype = (
"bfloat16"
if ("npu" in get_device_type() or paddle.amp.is_bfloat16_supported())
Expand Down