Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions src/zeroband/models/llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import contextlib
from dataclasses import dataclass
from typing import Optional, Tuple
from functools import partial

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -164,6 +165,12 @@ def seqlens_to_docs_tensor(seqlens: list[torch.Tensor]) -> torch.Tensor:
"""
return torch.stack([torch.repeat_interleave(torch.arange(len(seq), device=seq.device), seq) for seq in seqlens])

def document_causal_mask(docs, b, h, q_idx, kv_idx):
    """Mask predicate for flex attention: allow attention only within a document, causally.

    A query position may attend to a key position iff (1) the key does not lie in the
    future (q_idx >= kv_idx) and (2) both positions carry the same document id in
    ``docs``. The ``docs`` argument comes first so callers can bind it with
    ``functools.partial`` and pass the result as a standard ``mask_mod``
    ``(b, h, q_idx, kv_idx)`` callable; ``h`` (the head index) is accepted but unused,
    as the mask is identical across heads.
    """
    not_future = q_idx >= kv_idx
    same_document = docs[b, q_idx] == docs[b, kv_idx]
    return not_future & same_document


def create_block_mask_from_seqlens(seqlens: list[torch.Tensor]) -> BlockMask:
"""Creates a block mask from a list of sequence lengths.
Expand All @@ -181,13 +188,10 @@ def create_block_mask_from_seqlens(seqlens: list[torch.Tensor]) -> BlockMask:
docs = seqlens_to_docs_tensor(seqlens).to("cuda")
batch_size, max_seq_len = docs.shape

def document_causal_mask(b, h, q_idx, kv_idx):
causal_mask = q_idx >= kv_idx
document_mask = docs[b, q_idx] == docs[b, kv_idx]
return causal_mask & document_mask
partial_mask = partial(document_causal_mask, docs)

return create_block_mask(
document_causal_mask,
partial_mask,
batch_size,
None,
max_seq_len,
Expand Down