Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion paddlex/inference/models/common/static_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def _create(
logging.debug("`device_id` has been set to None")

if (
self._option.device_type in ("gpu", "dcu", "npu", "mlu", "gcu", "xpu")
self._option.device_type in ("gpu", "dcu", "npu", "mlu", "gcu", "xpu", "iluvatar_gpu")
and self._option.device_id is None
):
self._option.device_id = 0
Expand Down Expand Up @@ -447,6 +447,12 @@ def _create(
# Delete unsupported passes in dcu
config.delete_pass("conv2d_add_act_fuse_pass")
config.delete_pass("conv2d_add_fuse_pass")
elif self._option.device_type == "iluvatar_gpu":
config.enable_custom_device("iluvatar_gpu", int(self._option.device_id))
if hasattr(config, "enable_new_ir"):
config.enable_new_ir(self._option.enable_new_ir)
if hasattr(config, "enable_new_executor"):
config.enable_new_executor()
else:
assert self._option.device_type == "cpu"
config.disable_gpu()
Expand Down
8 changes: 5 additions & 3 deletions paddlex/inference/models/common/vlm/fusion_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def get_env_device():
return "gcu"
elif "intel_hpu" in paddle.device.get_all_custom_device_type():
return "intel_hpu"
elif "iluvatar_gpu" in paddle.device.get_all_custom_device_type():
return "iluvatar_gpu"
elif paddle.is_compiled_with_rocm():
return "rocm"
elif paddle.is_compiled_with_xpu():
Expand All @@ -61,7 +63,7 @@ def get_env_device():
except ImportError:
fused_rotary_position_embedding = None
try:
if get_env_device() in ["npu", "mlu", "gcu"]:
if get_env_device() in ["npu", "mlu", "gcu", "iluvatar_gpu"]:
from paddle.base import core

for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
Expand All @@ -84,7 +86,7 @@ def fusion_rope(
rotary_emb,
context_parallel_degree=-1,
):
if get_env_device() not in ["gcu", "intel_hpu"]:
if get_env_device() not in ["gcu", "intel_hpu", "iluvatar_gpu"]:
assert past_key_value is None, "fuse rotary not support cache kv for now"
batch_size, seq_length, num_heads, head_dim = query_states.shape
_, kv_seq_len, num_key_value_heads, _ = key_states.shape
Expand All @@ -93,7 +95,7 @@ def fusion_rope(
get_env_device() == "gpu"
), "context parallel only support cuda device for now"
kv_seq_len *= context_parallel_degree
if get_env_device() not in ["gcu", "intel_hpu"]:
if get_env_device() not in ["gcu", "intel_hpu", "iluvatar_gpu"]:
cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
if get_env_device() == "npu":
query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[
Expand Down
2 changes: 1 addition & 1 deletion paddlex/inference/utils/pp_option.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class PaddlePredictorOption(object):
"mkldnn",
"mkldnn_bf16",
)
SUPPORT_DEVICE = ("gpu", "cpu", "npu", "xpu", "mlu", "dcu", "gcu")
SUPPORT_DEVICE = ("gpu", "cpu", "npu", "xpu", "mlu", "dcu", "gcu", "iluvatar_gpu")

def __init__(self, **kwargs):
super().__init__()
Expand Down
2 changes: 2 additions & 0 deletions paddlex/repo_apis/PaddleOCR_api/text_rec/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def update_device(self, device: str):
"Global.use_npu": False,
"Global.use_mlu": False,
"Global.use_gcu": False,
"Global.use_iluvatar_gpu": False,
}

device_cfg = {
Expand All @@ -258,6 +259,7 @@ def update_device(self, device: str):
"mlu": {"Global.use_mlu": True},
"npu": {"Global.use_npu": True},
"gcu": {"Global.use_gcu": True},
"iluvatar_gpu": {"Global.use_iluvatar_gpu": True},
}
default_cfg.update(device_cfg[device])
self.update(default_cfg)
Expand Down
2 changes: 1 addition & 1 deletion paddlex/utils/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)
from .flags import DISABLE_DEV_MODEL_WL

SUPPORTED_DEVICE_TYPE = ["cpu", "gpu", "xpu", "npu", "mlu", "gcu", "dcu"]
SUPPORTED_DEVICE_TYPE = ["cpu", "gpu", "xpu", "npu", "mlu", "gcu", "dcu", "iluvatar_gpu"]


def constr_device(device_type, device_ids):
Expand Down