diff --git a/lightglue/lightglue.py b/lightglue/lightglue.py
index c81bd07..631411d 100644
--- a/lightglue/lightglue.py
+++ b/lightglue/lightglue.py
@@ -132,9 +132,12 @@ def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
             s = q.shape[-1] ** -0.5
             sim = torch.einsum("...id,...jd->...ij", q, k) * s
             if mask is not None:
-                sim.masked_fill(~mask, -float("inf"))
+                sim = sim.masked_fill(~mask, -float("inf"))
             attn = F.softmax(sim, -1)
-            return torch.einsum("...ij,...jd->...id", attn, v)
+            m = torch.einsum("...ij,...jd->...id", attn, v)
+            if mask is not None:
+                m = m.nan_to_num()
+            return m
 
 
 class SelfBlock(nn.Module):
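
Note (not part of the patch): `Tensor.masked_fill` is out-of-place, so without the reassignment the mask was silently ignored; once masking takes effect, a query whose keys are all masked gets a row of `-inf` scores, softmax turns that row into NaNs, and `nan_to_num()` replaces them with zeros. A minimal standalone sketch of both issues, assuming one query and three fully masked keys:

import torch
import torch.nn.functional as F

sim = torch.zeros(1, 3)                       # scores for 1 query and 3 keys
mask = torch.tensor([[False, False, False]])  # every key masked for this query

sim.masked_fill(~mask, -float("inf"))         # out-of-place: result is discarded
print(sim)                                    # tensor([[0., 0., 0.]]) -- mask had no effect

sim = sim.masked_fill(~mask, -float("inf"))   # reassigned, as in the patch
attn = F.softmax(sim, -1)
print(attn)                                   # tensor([[nan, nan, nan]]) -- all -inf row

v = torch.randn(3, 4)
out = torch.einsum("...ij,...jd->...id", attn, v)
print(out.nan_to_num())                       # zeros instead of NaNs for the masked query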