Skip to content

WIP: Multimodal Audio #272

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 34 commits into
base: soham/pixtral-support
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
ec2c9fb
initial data prep
tobyzl2 May 7, 2025
82e4edb
audio dataset changes
tobyzl2 May 9, 2025
9f3e60e
merge
tobyzl2 May 9, 2025
0d1cd96
audio token computation
tobyzl2 May 9, 2025
33b138b
Merge branch 'soham/pixtral-support' of https://github.com/ServiceNow…
tobyzl2 May 9, 2025
40f3882
implement mm packing
tobyzl2 May 10, 2025
a035d0c
merge
tobyzl2 May 13, 2025
94e439c
data updates
tobyzl2 May 15, 2025
543fc0d
changes
tobyzl2 May 16, 2025
6bbdc94
merge
tobyzl2 May 20, 2025
1a20913
layer changes
tobyzl2 May 20, 2025
5ffacab
merge
tobyzl2 May 22, 2025
c5396bc
Merge branch 'soham/pixtral-support' of https://github.com/ServiceNow…
tobyzl2 May 22, 2025
7eea79b
update audio encoder
tobyzl2 May 28, 2025
daf98b3
audio transformer updates
tobyzl2 May 28, 2025
cd167fc
audio conversion
tobyzl2 May 28, 2025
80c0aa2
merge
tobyzl2 May 28, 2025
e0f7dfd
mm loss masking spans
tobyzl2 May 29, 2025
0ae74d1
add lr scale
tobyzl2 May 29, 2025
28a3808
merge
tobyzl2 May 29, 2025
438ba80
mel spec changes
tobyzl2 May 30, 2025
525543a
updates
tobyzl2 May 30, 2025
95526a3
adding audio start and end tokens
tobyzl2 Jun 2, 2025
01dfed7
merge
tobyzl2 Jun 2, 2025
fb23ef8
conversion changes
tobyzl2 Jun 3, 2025
d7d1135
adding data prep sharding
tobyzl2 Jun 3, 2025
012a636
faster mel spec
tobyzl2 Jun 6, 2025
c664444
adding num audio to config
tobyzl2 Jun 12, 2025
ba73939
audio encoder padding updates
tobyzl2 Jun 12, 2025
5667a0a
configurable max pad
tobyzl2 Jun 12, 2025
9f68a5e
small fix
tobyzl2 Jun 16, 2025
c286f8d
debugging updates
tobyzl2 Jun 18, 2025
eb39e7e
working 5b changes
tobyzl2 Jun 23, 2025
a53c89a
small fixes
tobyzl2 Jun 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion fast_llm/data/data/gpt/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class GPTBatch:
sequence_lengths: list[torch.Tensor] | None = None
images: list[torch.Tensor] | None = None
image_positions: list[torch.Tensor] | None = None
audio: list[torch.Tensor] | None = None
audio_positions: list[torch.Tensor] | None = None


def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch:
Expand All @@ -54,16 +56,34 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling
batch_images.append([])
batch_image_positions = []
for sample in batch:
if sample.image_positions is not None:
if sample.image_positions is not None and len(sample.image_positions) > 0:
batch_image_positions.append(torch.from_numpy(sample.image_positions))
else:
batch_image_positions.append([])

has_audio = False
batch_audio = []
for sample in batch:
if sample.audio is not None and sample.audio_positions is not None:
batch_audio.append([torch.from_numpy(audio) for audio in sample.audio])
has_audio = True
else:
batch_audio.append(None)
batch_audio_positions = []
for sample in batch:
if sample.audio_positions is not None:
batch_audio_positions.append(torch.from_numpy(sample.audio_positions))
else:
batch_audio_positions.append([])

return GPTBatch(
token_ids=torch.from_numpy(stacked_ids),
loss_masking_spans=stacked_spans,
sequence_lengths=sequence_lengths,
images=batch_images if has_images else None,
image_positions=batch_image_positions if has_images else None,
audio=batch_audio if has_audio else None,
audio_positions=batch_audio_positions if has_audio else None,
)


Expand Down
10 changes: 10 additions & 0 deletions fast_llm/data/dataset/gpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,13 @@ class GPTSamplingParameters(SamplingParameters):
cross_document_attention: bool = True
patch_size: int | None = None
image_size: int | None = None
aud_downsampling_k: int | None = None
aud_padding_duration: int | None = None
aud_sampling_rate: int | None = None
image_break_token: int | None = None
image_end_token: int | None = None
audio_start_token: int | None = None
audio_end_token: int | None = None
# How many extra tokens to add to the sequence length.
# This is used to provide labels even for the last tokens in the sequence.
extra_tokens: int = 1
Expand Down Expand Up @@ -204,6 +209,11 @@ class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig):
desc="Expected number of pixels in the dataset.",
hint=FieldHint.optional,
)
num_audio: int | None = Field(
default=None,
desc="Expected number of audio in the dataset.",
hint=FieldHint.optional,
)

def build(self) -> "GPTMemmapDataset":
from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
Expand Down
8 changes: 6 additions & 2 deletions fast_llm/data/dataset/gpt/indexed.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,12 @@ class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[Indexe

def get_document_sizes(self) -> np.ndarray:
# TODO: This can be really big.
doc_sizes, im_sizes = self._dataset.get_document_sizes()
return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] if im_sizes else []
doc_sizes, im_sizes, aud_sizes = self._dataset.get_document_sizes()
return (
doc_sizes[self._begin : self._end],
im_sizes[self._begin : self._end] if im_sizes else [],
aud_sizes[self._begin : self._end] if aud_sizes else [],
)

def get_document_size(self, index: int) -> int:
return self._dataset.get_document_size(self._begin + index)
Expand Down
Loading