Enabling ZeroBubbleV schedule in Graph PP #250
base: gh/sanketpurandare/2/base
Conversation
autoparallel/graph_pp_runner.py (Outdated)
```python
assert not any(
    grad is None for grad in grads_to_accumulate
), "All grads are None"
```
Suggested change:

```diff
-assert not any(
-    grad is None for grad in grads_to_accumulate
-), "All grads are None"
+assert not all(
+    grad is None for grad in grads_to_accumulate
+), "All grads are None"
```
Actually, shouldn't we keep it as `any` and change the string to match? Or do we not care if some grads are None?
Yeah, we should keep it as `any` and change the string.
I think some grads can be None, e.g. forward outputs that don't require gradients; the usual torch.compile fw/bw graphs would have None in the backward graph outputs, as required by the custom autograd function API.
It's up to the partitioner/graph-pass splitters and the runtime wrapper implementation, though, as long as we know which grad belongs to which param.
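As a minimal illustration of that point (a toy custom autograd Function, not the AutoParallel-generated graphs): PyTorch requires `backward` to return one entry per forward input, and that entry is None for inputs that don't need a gradient.

```python
import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # One grad per forward input: a real tensor for x,
        # None for the non-differentiable `factor` argument.
        return grad_out * ctx.factor, None

x = torch.randn(3, requires_grad=True)
Scale.apply(x, 2.0).sum().backward()
print(x.grad)  # tensor([2., 2., 2.]); the None grad for `factor` is simply dropped
```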
Switched it back to `any`.
autoparallel/graph_pp_runner.py (Outdated)
```python
assert num_placeholders == len(
    fw_args
), f"Mismatched number of inputs to fwd, {len([n for n in fw_module.graph.nodes if n.op == 'placeholder'])}, {len(fw_args)}"
```

Suggested change:

```diff
-), f"Mismatched number of inputs to fwd, {len([n for n in fw_module.graph.nodes if n.op == 'placeholder'])}, {len(fw_args)}"
+), f"Mismatched number of inputs to fwd: expected {num_placeholders}, got {len(fw_args)}"
```
No longer needed, we made the change upstream.
Code pointer?
Removed the args len check
Stack from ghstack (oldest at bottom):