
Commit 537c524

divyashreepathihalli authored and rtg0795 committed
Fix flash attention TPU error (#20994)
* Fix flash attention TPU error
* fix space
* fix default mask
* update default mask if none check in wrapping function instead
1 parent 1154a6d commit 537c524
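
The "default mask" bullets refer to how the new TPU wrapper chooses its attention mask: an explicitly supplied mask is wrapped as a dense NumpyMask, otherwise the wrapper falls back to a causal mask over the full sequence. A minimal sketch of that selection logic (the choose_mask helper and the 1024-token sequence length are illustrative assumptions, not code from this commit):

import numpy as np
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask

SEQ_LEN = 1024  # assumed sequence length for illustration

def choose_mask(custom_mask=None):
    # Mirrors the "default mask if none" check added in this commit.
    if custom_mask is not None:
        # An explicit (seq, seq) boolean array becomes a dense mask.
        return splash_attention_mask.NumpyMask(custom_mask)
    # No mask supplied: default to causal attention.
    return splash_attention_mask.CausalMask(shape=(SEQ_LEN, SEQ_LEN))

default_mask = choose_mask()
lower_triangular = np.tril(np.ones((SEQ_LEN, SEQ_LEN), dtype=bool))
explicit_mask = choose_mask(lower_triangular)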

File tree

1 file changed: +55 -9 lines changed
  • keras/src/backend/jax


keras/src/backend/jax/nn.py

+55 -9
@@ -6,8 +6,11 @@
 import jax.numpy as jnp
 from jax import lax
 from jax import nn as jnn
-from jax.experimental.pallas.ops.tpu import (
-    flash_attention as flash_attention_tpu,
+from jax.experimental.pallas.ops.tpu.splash_attention import (
+    splash_attention_kernel,
+)
+from jax.experimental.pallas.ops.tpu.splash_attention import (
+    splash_attention_mask,
 )
 
 from keras.src import backend
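
The two splash-attention modules imported above replace the old Pallas flash_attention import: splash_attention_mask describes what each head may attend to, and splash_attention_kernel turns that description into a fused TPU kernel. A short sketch of how they compose, mirroring the wrap_flash_attention helper added further down (head count and sequence length are assumptions; the resulting kernel only executes on TPU hardware):

from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_kernel,
    splash_attention_mask,
)

NUM_HEADS, SEQ_LEN = 4, 1024  # assumed sizes for illustration

# One causal mask per head, combined into a single multi-head mask.
causal = splash_attention_mask.CausalMask(shape=(SEQ_LEN, SEQ_LEN))
multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(causal,) * NUM_HEADS)

# Build the fused multi-head attention kernel with no head or sequence sharding.
splash_kernel = splash_attention_kernel.make_splash_mha(
    mask=multi_head_mask,
    head_shards=1,
    q_seq_shards=1,
)
# The kernel operates on a single batch element; the new helper maps it over
# the batch dimension with jax.vmap.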
@@ -1036,6 +1039,8 @@ def _can_use_flash_attention(query, key, value, bias, raise_error=False):
             )
         return False
 
+    if jax.devices()[0].platform == "tpu":
+        return True
     try:
         # Check if cuDNN is installed and raise RuntimeError if cuDNN is not
         # detected
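
The early return above makes _can_use_flash_attention report success on TPU without running the cuDNN checks, which only apply to GPU. The gate is simply the platform of the first visible JAX device; a tiny sketch (the _on_tpu helper name is hypothetical):

import jax

def _on_tpu():
    # Same device check this commit uses to pick the splash-attention path.
    return jax.devices()[0].platform == "tpu"

print(_on_tpu())  # True on a TPU VM, False on CPU/GPU hosts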
@@ -1109,6 +1114,38 @@ def _dot_product_attention_core(
     return jnp.einsum("BNTS,BSNH->BTNH", probs, value)
 
 
+def wrap_flash_attention(
+    query, key, value, decoder_segment_ids, custom_mask=None
+):
+    if decoder_segment_ids is not None:
+        assert query.shape[2] == decoder_segment_ids.q.shape[1], (
+            "Sharding along sequence dimension not allowed in tpu kernel "
+            "attention"
+        )
+
+    if custom_mask is not None:
+        mask = splash_attention_mask.NumpyMask(mask=custom_mask)
+
+    else:
+        mask = splash_attention_mask.CausalMask(
+            shape=(query.shape[2], query.shape[2])
+        )
+
+    # Create multi-head mask
+    multi_head_mask = splash_attention_mask.MultiHeadMask(
+        masks=(mask,) * query.shape[1]
+    )
+    splash_kernel = splash_attention_kernel.make_splash_mha(
+        mask=multi_head_mask,
+        head_shards=1,
+        q_seq_shards=1,
+    )
+
+    return jax.vmap(splash_kernel)(
+        query, key, value, segment_ids=decoder_segment_ids
+    )
+
+
 def dot_product_attention(
     query,
     key,
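
The new wrap_flash_attention helper expects (batch, heads, length, head_dim) inputs and a SegmentIds pair, and vmaps the splash kernel over the batch. A hedged usage sketch (shapes are assumptions chosen to fit the kernel's block sizes; this only runs on a TPU runtime, and with no custom_mask it defaults to the causal mask):

import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
from keras.src.backend.jax.nn import wrap_flash_attention

B, H, S, D = 2, 4, 1024, 128  # assumed batch, heads, seq length, head dim
query = jnp.zeros((B, H, S, D), dtype=jnp.bfloat16)
key = jnp.zeros((B, H, S, D), dtype=jnp.bfloat16)
value = jnp.zeros((B, H, S, D), dtype=jnp.bfloat16)

# One segment per sequence, matching the default built in dot_product_attention.
segment_ids = jnp.ones((B, S))
decoder_segment_ids = splash_attention_kernel.SegmentIds(segment_ids, segment_ids)

out = wrap_flash_attention(query, key, value, decoder_segment_ids)
print(out.shape)  # (2, 4, 1024, 128)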
@@ -1134,17 +1171,26 @@ def dot_product_attention(
         # Use `raise_error=True` to provide more details if the inputs failed to
         # use flash attention
         _can_use_flash_attention(query, key, value, bias, raise_error=True)
-    if jax.devices()[0].platform == "tpu" and flash_attention:
-        # Use TPU-optimized flash attention from Pallas
-        return flash_attention_tpu(
+
+    if jax.devices()[0].platform == "tpu":
+        # Transpose to ('batch', 'heads', 'length', 'kv')
+        query = jnp.transpose(query, axes=(0, 2, 1, 3))
+        key = jnp.transpose(key, axes=(0, 2, 1, 3))
+        value = jnp.transpose(value, axes=(0, 2, 1, 3))
+        B, H, S, KV = query.shape
+
+        segment_ids = jnp.ones([B, S])
+        # {token_ids, padding_mask, segment_ids} enable packing
+        out = wrap_flash_attention(
             query,
             key,
             value,
-            ab=bias,
-            segment_ids=mask,
-            causal=is_causal,
-            sm_scale=scale,
+            splash_attention_kernel.SegmentIds(segment_ids, segment_ids),
+            custom_mask=mask,
         )
+        out = jnp.transpose(out, axes=(0, 2, 1, 3))
+        return out
     # `dot_product_attention` is only available in jax>=0.4.31
     if hasattr(jax.nn, "dot_product_attention"):
         return jax.nn.dot_product_attention(
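
End to end, the TPU branch above transposes Keras' (batch, length, heads, head_dim) tensors into the (batch, heads, length, head_dim) layout the splash kernel expects, runs the vmapped kernel, and transposes the result back. A hedged sketch through the public op, assuming the JAX backend is active (KERAS_BACKEND=jax) and the keras.ops.dot_product_attention signature of this Keras version:

import numpy as np
import keras

B, T, N, H = 2, 1024, 4, 128  # assumed (batch, length, heads, head_dim)
query = np.random.normal(size=(B, T, N, H)).astype("float32")
key = np.random.normal(size=(B, T, N, H)).astype("float32")
value = np.random.normal(size=(B, T, N, H)).astype("float32")

# On a TPU host this now routes through wrap_flash_attention; on CPU/GPU it
# falls back to jax.nn.dot_product_attention or the einsum implementation.
out = keras.ops.dot_product_attention(query, key, value, is_causal=True)
print(out.shape)  # (2, 1024, 4, 128)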
