@@ -49,6 +49,8 @@ def get_env_device():
         return "gcu"
     elif "intel_hpu" in paddle.device.get_all_custom_device_type():
         return "intel_hpu"
+    elif "iluvatar_gpu" in paddle.device.get_all_custom_device_type():
+        return "iluvatar_gpu"
     elif paddle.is_compiled_with_rocm():
         return "rocm"
     elif paddle.is_compiled_with_xpu():
@@ -61,7 +63,7 @@ def get_env_device():
 except ImportError:
     fused_rotary_position_embedding = None
 try:
-    if get_env_device() in ["npu", "mlu", "gcu"]:
+    if get_env_device() in ["npu", "mlu", "gcu", "iluvatar_gpu"]:
         from paddle.base import core
 
         for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
@@ -84,7 +86,7 @@ def fusion_rope(
     rotary_emb,
     context_parallel_degree=-1,
 ):
-    if get_env_device() not in ["gcu", "intel_hpu"]:
+    if get_env_device() not in ["gcu", "intel_hpu", "iluvatar_gpu"]:
         assert past_key_value is None, "fuse rotary not support cache kv for now"
     batch_size, seq_length, num_heads, head_dim = query_states.shape
     _, kv_seq_len, num_key_value_heads, _ = key_states.shape
@@ -93,7 +95,7 @@ def fusion_rope(
             get_env_device() == "gpu"
         ), "context parallel only support cuda device for now"
         kv_seq_len *= context_parallel_degree
-    if get_env_device() not in ["gcu", "intel_hpu"]:
+    if get_env_device() not in ["gcu", "intel_hpu", "iluvatar_gpu"]:
         cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
     if get_env_device() == "npu":
         query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[
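
A quick sanity check of what the new branches key on (a minimal sketch; it assumes an Iluvatar CoreX build of Paddle registers the "iluvatar_gpu" custom device type, as the diff implies):

import paddle

# In an Iluvatar custom-device build, the new elif in get_env_device()
# matches this string and returns "iluvatar_gpu", which in turn routes
# fusion_rope() through the custom-op code path added above.
if "iluvatar_gpu" in paddle.device.get_all_custom_device_type():
    print("iluvatar_gpu custom device detected")
else:
    print("no iluvatar_gpu device registered in this build")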