Fix bugs in initial_load_in_hf when enable_weight_tying=true in Qwen3 #1999
Conversation
Add checks for weight tying in state_dict processing
Co-authored-by: Shuhua Yu <[email protected]>
```python
    self.model_args.enable_weight_tying
    and "lm_head.weight" not in hf_state_dict
):
    if "model.embed_tokens.weight" in hf_state_dict:
```
Why this `if`? Shouldn't we assert the existence of the embedding?
My guess is that this `if` was copied from somewhere where PP can be enabled, so the embedding is on some ranks but not others. But with PP, we'd also require the embedding and lm_head to be on the same rank -- otherwise, how would you be able to load the lm_head weights?
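For context, a minimal self-contained sketch of the guard under discussion, pulled out into a standalone helper. The helper name `patch_tied_lm_head` and its signature are illustrative assumptions, not the actual torchtitan code. With the inner `if`, a missing embedding is skipped silently and `lm_head.weight` is never populated:

```python
import torch

def patch_tied_lm_head(
    hf_state_dict: dict[str, torch.Tensor], enable_weight_tying: bool
) -> dict[str, torch.Tensor]:
    # With weight tying, HF Qwen3 checkpoints may omit "lm_head.weight",
    # so reuse the embedding tensor before converting keys to the native model.
    if enable_weight_tying and "lm_head.weight" not in hf_state_dict:
        if "model.embed_tokens.weight" in hf_state_dict:
            hf_state_dict["lm_head.weight"] = hf_state_dict["model.embed_tokens.weight"]
    return hf_state_dict
```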
Thanks for raising this point, added an assertion here.
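A rough sketch of what the assertion-based version might look like, reusing the same hypothetical helper from above (an illustration, not the exact diff):

```python
import torch

def patch_tied_lm_head(
    hf_state_dict: dict[str, torch.Tensor], enable_weight_tying: bool
) -> dict[str, torch.Tensor]:
    if enable_weight_tying and "lm_head.weight" not in hf_state_dict:
        # With weight tying there must be an embedding to tie lm_head to;
        # fail loudly instead of silently leaving lm_head.weight unset.
        assert "model.embed_tokens.weight" in hf_state_dict, (
            "enable_weight_tying=true but model.embed_tokens.weight is "
            "missing from the HF state_dict"
        )
        hf_state_dict["lm_head.weight"] = hf_state_dict["model.embed_tokens.weight"]
    return hf_state_dict
```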
tianyu-l left a comment:
sgtm
…pytorch#1999)

Rebased on main to merge this pr: pytorch#1964

Co-authored-by: William <[email protected]>
Co-authored-by: Achazwl <[email protected]>
Rebased on main to merge this pr: #1964