 import jax.numpy as jnp
 from jax import lax
 from jax import nn as jnn
-from jax.experimental.pallas.ops.tpu import (
-    flash_attention as flash_attention_tpu,
+from jax.experimental.pallas.ops.tpu.splash_attention import (
+    splash_attention_kernel,
+)
+from jax.experimental.pallas.ops.tpu.splash_attention import (
+    splash_attention_mask,
 )
 
 from keras.src import backend
@@ -1036,6 +1039,8 @@ def _can_use_flash_attention(query, key, value, bias, raise_error=False):
             )
         return False
 
+    if jax.devices()[0].platform == "tpu":
+        return True
     try:
         # Check if cuDNN is installed and raise RuntimeError if cuDNN is not
         # detected
@@ -1109,6 +1114,38 @@ def _dot_product_attention_core(
     return jnp.einsum("BNTS,BSNH->BTNH", probs, value)
 
 
+def wrap_flash_attention(
+    query, key, value, decoder_segment_ids, custom_mask=None
+):
+    if decoder_segment_ids is not None:
+        assert query.shape[2] == decoder_segment_ids.q.shape[1], (
+            "Sharding along sequence dimension not allowed in tpu kernel "
+            "attention"
+        )
+
+    if custom_mask is not None:
+        mask = splash_attention_mask.NumpyMask(mask=custom_mask)
+
+    else:
+        mask = splash_attention_mask.CausalMask(
+            shape=(query.shape[2], query.shape[2])
+        )
+
+    # Create multi-head mask
+    multi_head_mask = splash_attention_mask.MultiHeadMask(
+        masks=(mask,) * query.shape[1]
+    )
+    splash_kernel = splash_attention_kernel.make_splash_mha(
+        mask=multi_head_mask,
+        head_shards=1,
+        q_seq_shards=1,
+    )
+
+    return jax.vmap(splash_kernel)(
+        query, key, value, segment_ids=decoder_segment_ids
+    )
+
+
 def dot_product_attention(
     query,
     key,
@@ -1134,17 +1171,26 @@ def dot_product_attention(
         # Use `raise_error=True` to provide more details if the inputs failed to
         # use flash attention
         _can_use_flash_attention(query, key, value, bias, raise_error=True)
-    if jax.devices()[0].platform == "tpu" and flash_attention:
-        # Use TPU-optimized flash attention from Pallas
-        return flash_attention_tpu(
+
+    if jax.devices()[0].platform == "tpu":
+        # Transpose to ('batch', 'heads', 'length', 'kv')
+        query = jnp.transpose(query, axes=(0, 2, 1, 3))
+        key = jnp.transpose(key, axes=(0, 2, 1, 3))
+        value = jnp.transpose(value, axes=(0, 2, 1, 3))
+        B, H, S, KV = query.shape
+
+        segment_ids = jnp.ones([B, S])
+        # {token_ids, padding_mask, segment_ids} enable packing
+        out = wrap_flash_attention(
             query,
             key,
             value,
-            ab=bias,
-            segment_ids=mask,
-            causal=is_causal,
-            sm_scale=scale,
+            splash_attention_kernel.SegmentIds(segment_ids, segment_ids),
+            custom_mask=mask,
         )
+        out = jnp.transpose(out, axes=(0, 2, 1, 3))
+        return out
+
     # `dot_product_attention` is only available in jax>=0.4.31
     if hasattr(jax.nn, "dot_product_attention"):
         return jax.nn.dot_product_attention(
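For reference, the following is a minimal, self-contained sketch of how the TPU path added in this diff drives the splash-attention kernel. It assumes a TPU runtime with a JAX version that ships jax.experimental.pallas.ops.tpu.splash_attention; the batch/head/sequence sizes, dtypes, and variable names below are illustrative only and are not part of this change.

import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_kernel,
    splash_attention_mask,
)

# Illustrative TPU-friendly shapes: (batch, heads, q_length, head_dim).
B, H, S, D = 2, 4, 1024, 128
query = jnp.zeros((B, H, S, D), dtype=jnp.bfloat16)
key = jnp.zeros((B, H, S, D), dtype=jnp.bfloat16)
value = jnp.zeros((B, H, S, D), dtype=jnp.bfloat16)

# Causal mask replicated across heads, as in wrap_flash_attention above.
mask = splash_attention_mask.CausalMask(shape=(S, S))
multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * H)
splash_kernel = splash_attention_kernel.make_splash_mha(
    mask=multi_head_mask,
    head_shards=1,
    q_seq_shards=1,
)

# One segment per sequence (no packing); SegmentIds carries query and kv ids.
segment_ids = jnp.ones((B, S), dtype=jnp.int32)
decoder_segment_ids = splash_attention_kernel.SegmentIds(
    q=segment_ids, kv=segment_ids
)

# The kernel operates on a single example; jax.vmap adds the batch dimension.
out = jax.vmap(splash_kernel)(
    query, key, value, segment_ids=decoder_segment_ids
)
print(out.shape)  # (B, H, S, D)

In dot_product_attention itself, the same call is reached after transposing the inputs from (batch, length, heads, head_dim) to (batch, heads, length, head_dim), and the kernel output is transposed back before being returned.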