From 9e3a3fb9b731e1aa10024dca34f7413badfd69d9 Mon Sep 17 00:00:00 2001
From: mohbasit
Date: Fri, 15 Aug 2025 20:41:59 +0000
Subject: [PATCH] make inner function partial

---
 src/zeroband/models/llama/model.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/src/zeroband/models/llama/model.py b/src/zeroband/models/llama/model.py
index d9650358..0e7c7da1 100644
--- a/src/zeroband/models/llama/model.py
+++ b/src/zeroband/models/llama/model.py
@@ -14,6 +14,7 @@
 import contextlib
 from dataclasses import dataclass
 from typing import Optional, Tuple
+from functools import partial
 
 import torch
 import torch.nn.functional as F
@@ -164,6 +165,12 @@ def seqlens_to_docs_tensor(seqlens: list[torch.Tensor]) -> torch.Tensor:
     """
     return torch.stack([torch.repeat_interleave(torch.arange(len(seq), device=seq.device), seq) for seq in seqlens])
 
+def document_causal_mask(docs, b, h, q_idx, kv_idx):
+    """Causal mask that also blocks attention across document boundaries: a query token may only attend to earlier tokens in its own document."""
+    causal_mask = q_idx >= kv_idx
+    document_mask = docs[b, q_idx] == docs[b, kv_idx]
+    return causal_mask & document_mask
+
 def create_block_mask_from_seqlens(seqlens: list[torch.Tensor]) -> BlockMask:
     """Creates a block mask from a list of sequence lengths.
 
@@ -181,13 +188,10 @@ def create_block_mask_from_seqlens(seqlens: list[torch.Tensor]) -> BlockMask:
     docs = seqlens_to_docs_tensor(seqlens).to("cuda")
     batch_size, max_seq_len = docs.shape
 
-    def document_causal_mask(b, h, q_idx, kv_idx):
-        causal_mask = q_idx >= kv_idx
-        document_mask = docs[b, q_idx] == docs[b, kv_idx]
-        return causal_mask & document_mask
+    partial_mask = partial(document_causal_mask, docs)
 
     return create_block_mask(
-        document_causal_mask,
+        partial_mask,
         batch_size,
         None,
         max_seq_len,
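
Usage sketch (illustrative, not part of the patch): hoisting the inner closure
to module level and binding `docs` with functools.partial keeps the
(b, h, q_idx, kv_idx) signature that FlexAttention's create_block_mask expects,
while turning the mask into a reusable top-level callable instead of a closure
recreated on every call. A minimal, self-contained example follows; it assumes
PyTorch >= 2.5 (for torch.nn.attention.flex_attention) and a CUDA device, and
the `docs` tensor is invented data standing in for seqlens_to_docs_tensor(seqlens):

    from functools import partial

    import torch
    from torch.nn.attention.flex_attention import create_block_mask

    def document_causal_mask(docs, b, h, q_idx, kv_idx):
        # Attend only to non-future tokens (causal) that belong to the
        # same packed document as the query token.
        causal_mask = q_idx >= kv_idx
        document_mask = docs[b, q_idx] == docs[b, kv_idx]
        return causal_mask & document_mask

    # Hypothetical packed batch: two documents of lengths 3 and 5 in one row.
    docs = torch.tensor([[0, 0, 0, 1, 1, 1, 1, 1]], device="cuda")
    batch_size, max_seq_len = docs.shape

    # partial() binds `docs` up front; the result is still a callable with the
    # (b, h, q_idx, kv_idx) signature required of a mask_mod. H=None broadcasts
    # the mask over all attention heads, mirroring the call in the patch.
    partial_mask = partial(document_causal_mask, docs)
    block_mask = create_block_mask(partial_mask, batch_size, None, max_seq_len, max_seq_len)

A common reason for this kind of hoist, hedged since the patch does not state
its motivation, is that a module-level function plus an explicit partial is
picklable and plays more predictably with torch.compile than a closure defined
inside the calling function.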