[trainer] feat: vlm support for sft engine #3729
Conversation
Code Review
This pull request introduces support for Vision-Language Models (VLM) in the SFT trainer. The changes are comprehensive, affecting dataset creation, data processing, and the training engine to handle multi-modal inputs. While the overall approach is sound, I've identified a few critical issues related to data processing and model output handling that could lead to runtime errors or incorrect behavior. Please address these points to ensure the stability and correctness of the new VLM functionality.
```python
for conv in messages:
    for content in conv["content"]:
        for k, v in content.items():
            if v is None:
                content.pop(k)
                break
```
This loop for removing `None` values from the `content` dictionary is incorrect. The `break` statement causes the inner loop to exit after removing only the first key with a `None` value. If a `content` dictionary contains multiple `None` values, the subsequent ones will not be removed, which could lead to unexpected behavior or errors downstream. Additionally, modifying a dictionary while iterating over its items is unsafe and can lead to unpredictable results.

A more robust and Pythonic way to achieve this is to use a list comprehension to rebuild the list of content dictionaries, filtering out any keys with `None` values.
```python
for conv in messages:
    conv["content"] = [
        {k: v for k, v in content.items() if v is not None} for content in conv["content"]
    ]
```
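The difference between the two loops can be demonstrated with a small, hypothetical `messages` value (the keys below are illustrative, not taken from the PR):

```python
import copy

# Hypothetical messages value with multiple None keys in one content dict
messages = [
    {"role": "user", "content": [
        {"type": "image", "text": None, "image": "img.png", "video": None},
        {"type": "text", "text": "Describe the image.", "image": None},
    ]},
]

def strip_none_buggy(msgs):
    """Original loop: `break` exits after the first None key is removed."""
    msgs = copy.deepcopy(msgs)
    for conv in msgs:
        for content in conv["content"]:
            for k, v in content.items():
                if v is None:
                    content.pop(k)
                    break  # later None keys in this dict survive
    return msgs

def strip_none_fixed(msgs):
    """Suggested fix: rebuild each content dict, dropping every None value."""
    msgs = copy.deepcopy(msgs)
    for conv in msgs:
        conv["content"] = [
            {k: v for k, v in content.items() if v is not None}
            for content in conv["content"]
        ]
    return msgs

buggy = strip_none_buggy(messages)
fixed = strip_none_fixed(messages)
assert "video" in buggy[0]["content"][0]      # second None key was missed
assert "video" not in fixed[0]["content"][0]  # all None keys removed
```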
```python
if hasattr(output, "last_hidden_state"):
    logits = output.last_hidden_state
else:
    logits = output.logits
```
There is a potential type and shape mismatch here. `output.last_hidden_state` typically contains embeddings from the transformer body, with a shape of `(batch_size, seq_len, hidden_size)`. However, the variable is named `logits`, and it is later used in `logprobs_from_logits`, which expects model logits of shape `(batch_size, seq_len, vocab_size)`. Assigning hidden states to `logits` will likely cause a dimension mismatch error and incorrect calculations downstream. If the intention is to handle models that don't have a language model head applied, the hidden states should be passed through an LM head before being treated as logits.
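To illustrate the shape issue, here is a minimal sketch using NumPy (all names and sizes are hypothetical; in the real model the projection would be the model's own LM-head module):

```python
import numpy as np

# Hypothetical dimensions for illustration only
batch_size, seq_len, hidden_size, vocab_size = 2, 8, 16, 32

rng = np.random.default_rng(0)
hidden_states = rng.standard_normal((batch_size, seq_len, hidden_size))
W_lm_head = rng.standard_normal((hidden_size, vocab_size))  # stand-in LM-head weight

# Treating hidden states as logits leaves the wrong last dimension:
# hidden_size, not vocab_size
assert hidden_states.shape[-1] == hidden_size

# Projecting through the LM head yields vocab-sized logits, which is
# what a logprobs_from_logits-style function expects
logits = hidden_states @ W_lm_head
assert logits.shape == (batch_size, seq_len, vocab_size)
```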
`verl/workers/engine/utils.py` (outdated)
```python
if "multi_modal_inputs" in micro_batches[0]:
    multi_modal_data = micro_batches[0]["multi_modal_inputs"]
    for batch, indexes in zip(micro_batches, batch_idx_list, strict=False):
        batch["multi_modal_inputs"] = [multi_modal_data[i] for i in indexes]
```
This block of code for distributing `multi_modal_inputs` will cause a `TypeError` when `use_dynamic_bsz` is `False`. In that case, `batch_idx_list` is `None`, and passing `None` to `zip()` is not allowed. This logic is only applicable when dynamic batch sizing is enabled, as `batch_idx_list` is only populated in that scenario. To fix this, the block should be moved inside the `if use_dynamic_bsz:` block, right after `rearrange_micro_batches` is called.
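A minimal sketch of the proposed structure (function and parameter names here are assumed for illustration; the real code lives in `verl/workers/engine/utils.py`):

```python
def split_micro_batches(batch, use_dynamic_bsz, rearrange_micro_batches, plain_split):
    """Sketch: only redistribute multi_modal_inputs when batch_idx_list exists."""
    if use_dynamic_bsz:
        micro_batches, batch_idx_list = rearrange_micro_batches(batch)
        # Safe here: batch_idx_list is a real list of per-micro-batch indexes
        if "multi_modal_inputs" in micro_batches[0]:
            multi_modal_data = micro_batches[0]["multi_modal_inputs"]
            for mb, indexes in zip(micro_batches, batch_idx_list):
                mb["multi_modal_inputs"] = [multi_modal_data[i] for i in indexes]
    else:
        micro_batches = plain_split(batch)
        batch_idx_list = None  # zip(micro_batches, None) here would raise TypeError
    return micro_batches, batch_idx_list
```

With this guard, the `False` branch never touches `batch_idx_list`, so the `TypeError` cannot occur.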
Before we merge into `main`, can we also add some correctness validation to the PR description, such as training curves in wandb or similar logs? cc: @vermouth1992
@ccclyu Thanks for your advice. I updated the training curves.
@vermouth1992 The FSDP part is finished. The losses are all close, but the grad norms are slightly different. Do you think that's OK?
The grad norms cannot match exactly, but that does not seem to affect the losses, and it's really difficult to find the root cause.
What does this PR do?
This PR introduces support for VLM in SFT training.
It builds upon the work in PR #3590 and PR #3589, incorporating their contributions while fixing existing bugs. Due to the significant number of new features added to the main branch since the original PRs were opened, rebasing them became impractical. This new PR serves as a consolidated and up-to-date implementation.
The code is tested through
test_sft_engine_vllm_all.sh
and the result is as follow. (megatron is not tested by now because it's loss is too high.)