Commit fa06a81

[ILUVATAR_GPU] Support for iluvatar_gpu (#4565)

1 parent: 4127933
File tree: 5 files changed, +15 −5 lines

paddlex/inference/models/common/static_infer.py (6 additions, 0 deletions)

@@ -447,6 +447,12 @@ def _create(
             # Delete unsupported passes in dcu
             config.delete_pass("conv2d_add_act_fuse_pass")
             config.delete_pass("conv2d_add_fuse_pass")
+        elif self._option.device_type == "iluvatar_gpu":
+            config.enable_custom_device("iluvatar_gpu", int(self._option.device_id))
+            if hasattr(config, "enable_new_ir"):
+                config.enable_new_ir(self._option.enable_new_ir)
+            if hasattr(config, "enable_new_executor"):
+                config.enable_new_executor()
         else:
             assert self._option.device_type == "cpu"
             config.disable_gpu()
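
For reference, the new branch follows the standard Paddle Inference custom-device flow. Below is a minimal standalone sketch of the same configuration; the model file paths and device id 0 are placeholders, and the hasattr guards mirror the version checks in the diff for Paddle builds that predate the new IR and new executor APIs.

    import paddle.inference as paddle_infer

    # Placeholder model files; in PaddleX these come from the model directory.
    config = paddle_infer.Config("inference.pdmodel", "inference.pdiparams")
    # Route execution to the Iluvatar card registered as a custom device.
    config.enable_custom_device("iluvatar_gpu", 0)
    # Guard the newer PIR / new-executor switches so older Paddle builds still work.
    if hasattr(config, "enable_new_ir"):
        config.enable_new_ir(True)
    if hasattr(config, "enable_new_executor"):
        config.enable_new_executor()
    predictor = paddle_infer.create_predictor(config)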

paddlex/inference/models/common/vlm/fusion_ops.py (5 additions, 3 deletions)

@@ -49,6 +49,8 @@ def get_env_device():
         return "gcu"
     elif "intel_hpu" in paddle.device.get_all_custom_device_type():
         return "intel_hpu"
+    elif "iluvatar_gpu" in paddle.device.get_all_custom_device_type():
+        return "iluvatar_gpu"
     elif paddle.is_compiled_with_rocm():
         return "rocm"
     elif paddle.is_compiled_with_xpu():

@@ -61,7 +63,7 @@ def get_env_device():
 except ImportError:
     fused_rotary_position_embedding = None
 try:
-    if get_env_device() in ["npu", "mlu", "gcu"]:
+    if get_env_device() in ["npu", "mlu", "gcu", "iluvatar_gpu"]:
         from paddle.base import core

         for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):

@@ -84,7 +86,7 @@ def fusion_rope(
     rotary_emb,
     context_parallel_degree=-1,
 ):
-    if get_env_device() not in ["gcu", "intel_hpu"]:
+    if get_env_device() not in ["gcu", "intel_hpu", "iluvatar_gpu"]:
         assert past_key_value is None, "fuse rotary not support cache kv for now"
     batch_size, seq_length, num_heads, head_dim = query_states.shape
     _, kv_seq_len, num_key_value_heads, _ = key_states.shape

@@ -93,7 +95,7 @@ def fusion_rope(
             get_env_device() == "gpu"
         ), "context parallel only support cuda device for now"
         kv_seq_len *= context_parallel_degree
-    if get_env_device() not in ["gcu", "intel_hpu"]:
+    if get_env_device() not in ["gcu", "intel_hpu", "iluvatar_gpu"]:
         cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
     if get_env_device() == "npu":
         query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[
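
The get_env_device() change relies on the Iluvatar plugin registering itself with Paddle's custom-device runtime. A minimal sketch of that detection, assuming the plugin is installed and CUSTOM_DEVICE_ROOT points at its kernel libraries:

    import paddle

    # Custom-device plugins report their type names once installed;
    # "iluvatar_gpu" appears here only when the Iluvatar plugin is present.
    custom_types = paddle.device.get_all_custom_device_type()
    if "iluvatar_gpu" in custom_types:
        env_device = "iluvatar_gpu"
    elif paddle.is_compiled_with_rocm():
        env_device = "rocm"
    else:
        env_device = "gpu" if paddle.is_compiled_with_cuda() else "cpu"
    print(env_device)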

paddlex/inference/utils/pp_option.py (1 addition, 1 deletion)

@@ -54,7 +54,7 @@ class PaddlePredictorOption(object):
         "mkldnn",
         "mkldnn_bf16",
     )
-    SUPPORT_DEVICE = ("gpu", "cpu", "npu", "xpu", "mlu", "dcu", "gcu")
+    SUPPORT_DEVICE = ("gpu", "cpu", "npu", "xpu", "mlu", "dcu", "gcu", "iluvatar_gpu")

     def __init__(self, **kwargs):
         super().__init__()
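
Adding "iluvatar_gpu" to SUPPORT_DEVICE lets the predictor option accept it as a device type. As a rough illustration of the kind of whitelist check this tuple backs (the helper below is hypothetical, not the actual PaddleX validation code):

    SUPPORT_DEVICE = ("gpu", "cpu", "npu", "xpu", "mlu", "dcu", "gcu", "iluvatar_gpu")

    def check_device_type(device_type: str) -> str:
        # Hypothetical validator: reject anything outside the whitelist.
        if device_type not in SUPPORT_DEVICE:
            raise ValueError(f"unsupported device type: {device_type!r}")
        return device_type

    check_device_type("iluvatar_gpu")  # accepted after this commit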

paddlex/repo_apis/PaddleOCR_api/text_rec/config.py (2 additions, 0 deletions)

@@ -249,6 +249,7 @@ def update_device(self, device: str):
             "Global.use_npu": False,
             "Global.use_mlu": False,
             "Global.use_gcu": False,
+            "Global.use_iluvatar_gpu": False,
         }

         device_cfg = {
@@ -258,6 +259,7 @@ def update_device(self, device: str):
             "mlu": {"Global.use_mlu": True},
             "npu": {"Global.use_npu": True},
             "gcu": {"Global.use_gcu": True},
+            "iluvatar_gpu": {"Global.use_iluvatar_gpu": True},
         }
         default_cfg.update(device_cfg[device])
         self.update(default_cfg)
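
The update_device() change follows the existing pattern: every Global.use_* flag defaults to False and only the flag for the requested device is switched on. A condensed, illustrative version of that lookup (the dicts are abridged to the flags visible in the hunks):

    default_cfg = {
        "Global.use_npu": False,
        "Global.use_mlu": False,
        "Global.use_gcu": False,
        "Global.use_iluvatar_gpu": False,
    }
    device_cfg = {
        "npu": {"Global.use_npu": True},
        "mlu": {"Global.use_mlu": True},
        "gcu": {"Global.use_gcu": True},
        "iluvatar_gpu": {"Global.use_iluvatar_gpu": True},
    }
    device = "iluvatar_gpu"
    default_cfg.update(device_cfg[device])
    # Only Global.use_iluvatar_gpu ends up True; the other flags stay False.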

paddlex/utils/device.py (1 addition, 1 deletion)

@@ -25,7 +25,7 @@
 )
 from .flags import DISABLE_DEV_MODEL_WL

-SUPPORTED_DEVICE_TYPE = ["cpu", "gpu", "xpu", "npu", "mlu", "gcu", "dcu"]
+SUPPORTED_DEVICE_TYPE = ["cpu", "gpu", "xpu", "npu", "mlu", "gcu", "dcu", "iluvatar_gpu"]


 def constr_device(device_type, device_ids):
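
With "iluvatar_gpu" in SUPPORTED_DEVICE_TYPE, device strings for Iluvatar cards pass PaddleX's device parsing. The snippet below is illustrative only; it assumes constr_device joins a device type and an id list into the "type:ids" form PaddleX uses for device strings, which is not shown in this diff.

    from paddlex.utils.device import constr_device

    # Assumption: constr_device("iluvatar_gpu", [0, 1]) yields something like "iluvatar_gpu:0,1".
    device_str = constr_device("iluvatar_gpu", [0, 1])
    print(device_str)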
