Enable Loss Fn in Graph PP #247
Conversation
autoparallel/api.py
Outdated
- def _export(model: torch.nn.Module, inputs: tuple[Any]) -> torch.nn.Module:
+ def _export(
+     model: torch.nn.Module, model_wrapper: Callable, inputs: tuple[Any]
Please add a note to the docstring explaining what model_wrapper is for.
Added a comprehensive docstring.
autoparallel/api.py
Outdated
  model_wrapper: Callable
  if self.loss_fn is not None:

      def model_with_loss(inputs, target) -> Any:
When we call model_wrapper in export, we just pass it *inputs, which apparently expands to (inputs, target)? That part is a little confusing for UX.
Added some comments and made the format consistent now.
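The wrapper pattern under discussion can be sketched roughly as follows. This is illustrative only, not the PR's actual code: the class and method names are invented, and `model` is typed as a plain callable (in the real code it is a `torch.nn.Module`) so the sketch runs without torch:

```python
from typing import Any, Callable, Optional


class AutoParallelSketch:
    """Illustrative stand-in: shows how one model_wrapper entrypoint can
    cover both the with-loss and no-loss cases uniformly."""

    def __init__(self, model: Callable, loss_fn: Optional[Callable] = None):
        self.model = model
        self.loss_fn = loss_fn

    def build_model_wrapper(self) -> Callable:
        model_wrapper: Callable
        if self.loss_fn is not None:
            # Export calls model_wrapper(*inputs). With a loss, "inputs" is
            # the pair (inputs, target), so the star-expansion lands one
            # element on each parameter.
            def model_with_loss(inputs, target) -> Any:
                return self.loss_fn(self.model(inputs), target)

            model_wrapper = model_with_loss
        else:
            # Without a loss the wrapper is just the forward pass.
            def model_forward(inputs) -> Any:
                return self.model(inputs)

            model_wrapper = model_forward
        return model_wrapper
```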
autoparallel/graph_pp_runner.py
Outdated
  self._accumulate_stage_grads_and_clear_states(stage)
  if stage.is_last and has_targets_and_loss:
      losses = kwargs["losses"]
      losses.clear()
This is a bit confusing. I am expecting that we take the losses from kwargs and use them, so why do we immediately clear them and replace them with "schedule internal losses"? What are those?
In the last stage of PP, the loss is actually appended to an internally maintained list called schedule._internal_losses. At the end of the pipeline step, the user-provided list of losses is extended with schedule._internal_losses. And yeah, the losses.clear() was wrong, so I removed it; we shouldn't manage the user-provided losses list, we should just extend it.
examples/example_ds3_pp.py
Outdated
  # Tracing input functions
  tracing_input_fn = make_input_fn(spmd_batch_size, "tokens", device)
  tracing_input_fn_after_first_stage = make_input_fn(
Isn't the last-stage output a different shape than the embeddings for the middle layers?
Yeah, for that I have:
shape_inference_output_fn_last_stage = ...
      )
      return target_fn


  # Tracing input functions
Remind me, why do we have to make our own tracing functions? Are we not using shape inference inside the pipelining runtime? Oh, AutoP needs this, I guess.
- We need the tracing functions for AutoP.
- We need the shape_inference functions for PP to run with fake_pg.
- We need the runtime functions for generating inputs/targets for the actual run.
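A minimal sketch of how such an input-function factory could look. The `make_input_fn(batch_size, key, device)` call shape is taken from the example above, but the body, the `seq_len`/`vocab_size`/`hidden_dim` parameters, and the dtype choices are all assumptions:

```python
def make_input_fn(batch_size, key, device, seq_len=128, vocab_size=32000,
                  hidden_dim=1024):
    """Return a zero-arg function that produces fresh inputs for tracing,
    shape inference, or the actual run. Sketch only; real shapes may differ."""

    def input_fn():
        import torch  # deferred so the factory itself works without torch

        if key == "tokens":
            # First stage consumes token ids.
            return torch.randint(
                0, vocab_size, (batch_size, seq_len), device=device
            )
        # After the first stage, inputs are hidden states, not token ids.
        return torch.randn(batch_size, seq_len, hidden_dim, device=device)

    return input_fn
```

Having one factory per role (tracing, shape inference, runtime) keeps the three call sites uniform even though they need tensors on different devices or with fake backing.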
autoparallel/api.py
Outdated
          print_output=False, include_stride=True, include_device=True
      ),
  )
  print(gm.graph)
:( use tlparse to get your graphs
Sorry! This was for my own debugging.
- def _export(model: torch.nn.Module, inputs: tuple[Any]) -> torch.nn.Module:
+ def _export(
+     model: torch.nn.Module, model_wrapper: Callable, inputs: tuple[Any, ...]
Nit: could model_wrapper be optional? If not provided, just use the model itself as the tracing entrypoint?
The model wrapper is never None or Optional. Even if we don't use a loss function, we still construct a model wrapper. I feel this is a good API even for future use, e.g. when we decide to fold in an optimizer.