From 0528872525f9dd436dc0ac5e1f90302113a4d152 Mon Sep 17 00:00:00 2001
From: Aleksandre Kandelaki
Date: Fri, 16 May 2025 14:50:36 +0200
Subject: [PATCH] Bugfix: Attention computation is not equivalent if fused_attn is false

---
 vggt/layers/attention.py | 60 ++++++++++++++++++++++++++++------------
 1 file changed, 43 insertions(+), 17 deletions(-)

diff --git a/vggt/layers/attention.py b/vggt/layers/attention.py
index ab3089ce..f92740de 100644
--- a/vggt/layers/attention.py
+++ b/vggt/layers/attention.py
@@ -20,23 +20,23 @@ class Attention(nn.Module):
     def __init__(
-        self,
-        dim: int,
-        num_heads: int = 8,
-        qkv_bias: bool = True,
-        proj_bias: bool = True,
-        attn_drop: float = 0.0,
-        proj_drop: float = 0.0,
-        norm_layer: nn.Module = nn.LayerNorm,
-        qk_norm: bool = False,
-        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
-        rope=None,
+            self,
+            dim: int,
+            num_heads: int = 8,
+            qkv_bias: bool = True,
+            proj_bias: bool = True,
+            attn_drop: float = 0.0,
+            proj_drop: float = 0.0,
+            norm_layer: nn.Module = nn.LayerNorm,
+            qk_norm: bool = False,
+            fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
+            rope=None,
     ) -> None:
         super().__init__()
         assert dim % num_heads == 0, "dim should be divisible by num_heads"
         self.num_heads = num_heads
         self.head_dim = dim // num_heads
-        self.scale = self.head_dim**-0.5
+        self.scale = self.head_dim ** -0.5
         self.fused_attn = fused_attn
 
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
@@ -47,6 +47,34 @@ def __init__(
         self.proj_drop = nn.Dropout(proj_drop)
         self.rope = rope
 
+    @staticmethod
+    def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
+                                     is_causal=False, scale=None, enable_gqa=False) -> Tensor:
+        L, S = query.size(-2), key.size(-2)
+        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+        attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+        if is_causal:
+            assert attn_mask is None
+            temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
+            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+            attn_bias.to(query.dtype)
+
+        if attn_mask is not None:
+            if attn_mask.dtype == torch.bool:
+                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+            else:
+                attn_bias = attn_mask + attn_bias
+
+        if enable_gqa:
+            key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
+            value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
+
+        attn_weight = query @ key.transpose(-2, -1) * scale_factor
+        attn_weight += attn_bias
+        attn_weight = torch.softmax(attn_weight, dim=-1)
+        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+        return attn_weight @ value
+
     def forward(self, x: Tensor, pos=None) -> Tensor:
         B, N, C = x.shape
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
@@ -65,11 +93,9 @@ def forward(self, x: Tensor, pos=None) -> Tensor:
                 dropout_p=self.attn_drop.p if self.training else 0.0,
             )
         else:
-            q = q * self.scale
-            attn = q @ k.transpose(-2, -1)
-            attn = attn.softmax(dim=-1)
-            attn = self.attn_drop(attn)
-            x = attn @ v
+            x = self.scaled_dot_product_attention(
+                q, k, v, attn_mask=None, dropout_p=self.attn_drop.p if self.training else 0.0,
+                is_causal=False, scale=self.scale, enable_gqa=False)
 
         x = x.transpose(1, 2).reshape(B, N, C)
         x = self.proj(x)
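
Note (not part of the patch): one way to sanity-check the change is to compare the fused and
non-fused paths with dropout disabled; after this patch the two branches should agree up to
floating-point tolerance. The snippet below is a minimal sketch, assuming the patched module is
importable as vggt.layers.attention; the tensor sizes and tolerance are arbitrary choices.

    import torch
    from vggt.layers.attention import Attention

    torch.manual_seed(0)
    x = torch.randn(2, 16, 64)  # (batch, tokens, dim)

    # Two modules that differ only in the attention code path.
    fused = Attention(dim=64, num_heads=8, fused_attn=True).eval()
    manual = Attention(dim=64, num_heads=8, fused_attn=False).eval()
    manual.load_state_dict(fused.state_dict())  # identical weights

    with torch.no_grad():
        out_fused = fused(x)
        out_manual = manual(x)

    # Expected to print True after this patch (up to floating-point tolerance).
    print(torch.allclose(out_fused, out_manual, atol=1e-6))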