
Conversation

sanketpurandare (Contributor) commented Nov 12, 2025

Stack from ghstack (oldest at bottom):

[ghstack-poisoned]
sanketpurandare added a commit that referenced this pull request Nov 12, 2025
ghstack-source-id: b2f75c8
Pull Request resolved: #247
meta-cla bot added the CLA Signed label Nov 12, 2025
sanketpurandare added a commit that referenced this pull request Nov 12, 2025
ghstack-source-id: 56b4b85
Pull Request resolved: #247
sanketpurandare added a commit that referenced this pull request Nov 12, 2025
ghstack-source-id: ad07d83
Pull Request resolved: #247

-def _export(model: torch.nn.Module, inputs: tuple[Any]) -> torch.nn.Module:
+def _export(
+    model: torch.nn.Module, model_wrapper: Callable, inputs: tuple[Any]
Contributor:

Please add a comment to the docstring explaining what model_wrapper is for.

Contributor Author (sanketpurandare):

Added a comprehensive docstring.
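
For reference, a minimal sketch of the kind of docstring meant here (the wording is an illustration, not the PR's exact text):

    def _export(
        model: torch.nn.Module, model_wrapper: Callable, inputs: tuple[Any, ...]
    ) -> torch.nn.Module:
        """Export model to a graph module by tracing model_wrapper.

        Args:
            model: The eager module whose parameters and buffers are traced.
            model_wrapper: The callable that export actually traces. It calls
                model and, when a loss function is configured, also computes
                the loss so that the traced graph contains the loss computation.
            inputs: Example inputs passed to model_wrapper as *inputs; when a
                loss is traced, this tuple expands to (inputs, target).
        """
        ...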

model_wrapper: Callable
if self.loss_fn is not None:

    def model_with_loss(inputs, target) -> Any:
Contributor:

When we call 'model_wrapper' in export, we just pass it *inputs, which apparently expands to (inputs, target)? That part is a little confusing for UX.

Contributor Author (sanketpurandare):

Wrote some comments and have a consistent format now.
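
A self-contained sketch of the resulting convention; make_model_wrapper and its body are illustrative assumptions, not the PR's actual helper:

    from typing import Any, Callable, Optional

    import torch

    def make_model_wrapper(
        model: torch.nn.Module, loss_fn: Optional[Callable]
    ) -> Callable:
        # Export always traces a wrapper. When a loss_fn is configured, the
        # wrapper takes (inputs, target), so _export calling
        # model_wrapper(*inputs) expands the example-input tuple into
        # exactly those two arguments.
        if loss_fn is not None:

            def model_with_loss(inputs, target) -> Any:
                output = model(inputs)
                return loss_fn(output, target)

            return model_with_loss

        def model_only(inputs) -> Any:
            return model(inputs)

        return model_only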

self._accumulate_stage_grads_and_clear_states(stage)
if stage.is_last and has_targets_and_loss:
    losses = kwargs["losses"]
    losses.clear()
Contributor:

This is a bit confusing: I'd expect us to take the losses from kwargs and use them. Why do we immediately clear them and replace them with 'schedule internal losses'? What are those?

Contributor Author (sanketpurandare):

So in the last stage of PP, the loss is actually appended to an internally maintained list called schedule._internal_losses. At the end of the pipeline step, the user-provided list of losses is extended with schedule._internal_losses. Yeah, the losses.clear() was wrong and I removed it; we shouldn't manage the user-provided list of losses, we should just extend it.
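
To make the fix concrete, a toy illustration (the values are made up, and internal_losses stands in for schedule._internal_losses):

    # The user passes in a list they own; the schedule accumulates losses
    # internally and hands them over at the end of the pipeline step.
    user_losses: list[float] = []
    internal_losses: list[float] = [0.71, 0.64]

    # Wrong: user_losses.clear() would mutate state the runtime doesn't own.
    # Right: only extend the user's list with the schedule's internal losses.
    user_losses.extend(internal_losses)
    assert user_losses == [0.71, 0.64]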


# Tracing input functions
tracing_input_fn = make_input_fn(spmd_batch_size, "tokens", device)
tracing_input_fn_after_first_stage = make_input_fn(
Contributor:

Isn't the last-stage output a different shape than the embeddings for the middle layers?

Contributor Author (sanketpurandare) Nov 13, 2025:

Yeah, for that I have:

shape_inference_output_fn_last_stage = ...
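
Roughly, the distinction looks like this; the shapes and the middle-stage counterpart are assumptions, only shape_inference_output_fn_last_stage appears in the diff:

    import torch

    # Assumed shapes, for illustration only.
    batch_size, seq_len, dim, vocab_size = 8, 2048, 4096, 32_000

    def shape_inference_output_fn_middle_stage() -> torch.Tensor:
        # Middle stages hand activations of shape (batch, seq, hidden_dim)
        # to the next stage.
        return torch.empty(batch_size, seq_len, dim, device="meta")

    def shape_inference_output_fn_last_stage() -> torch.Tensor:
        # The last stage projects to the vocabulary, so its output shape
        # differs: (batch, seq, vocab_size).
        return torch.empty(batch_size, seq_len, vocab_size, device="meta")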

)
return target_fn

# Tracing input functions
Contributor:

Remind me, why do we have to make our own tracing functions? Are we not using shape inference inside the pipelining runtime? Oh, AutoP needs this, I guess.

Contributor Author (sanketpurandare) Nov 13, 2025:

  1. We need the tracing functions for AutoP.
  2. We need the shape_inference functions for PP to run with fake_pg.
  3. We need the runtime functions for generating inputs/targets for the actual run.
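
To spell those three roles out, a hedged sketch (names, shapes, and bodies are assumptions, not the PR's helpers):

    import torch

    # Hypothetical shapes, for illustration only.
    batch_size, seq_len, vocab_size = 8, 2048, 32_000
    device = torch.device("cpu")

    # 1. Tracing fn: produces example tensors for AutoP's export pass.
    def tracing_input_fn() -> torch.Tensor:
        return torch.randint(0, vocab_size, (batch_size, seq_len), device=device)

    # 2. Shape-inference fn: meta tensors are enough for PP to infer stage
    #    I/O shapes under a fake process group, with no real data needed.
    def shape_inference_input_fn() -> torch.Tensor:
        return torch.empty(batch_size, seq_len, dtype=torch.long, device="meta")

    # 3. Runtime fn: generates the actual inputs/targets for the real run.
    def runtime_input_fn() -> torch.Tensor:
        return torch.randint(0, vocab_size, (batch_size, seq_len), device=device)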

print_output=False, include_stride=True, include_device=True
),
)
print(gm.graph)
Member:

:( use tlparse to get your graphs

Contributor Author (sanketpurandare):

Sorry! This was for my own debugging.

sanketpurandare added a commit that referenced this pull request Nov 13, 2025
ghstack-source-id: 3c83666
Pull Request resolved: #247

-def _export(model: torch.nn.Module, inputs: tuple[Any]) -> torch.nn.Module:
+def _export(
+    model: torch.nn.Module, model_wrapper: Callable, inputs: tuple[Any, ...]
Contributor:

Nit: model_wrapper can be optional? If not provided, just use the model itself as the tracing entrypoint?

Contributor Author (sanketpurandare):

The model wrapper is never None or Optional. Even if we don't use a loss function, we still construct a model wrapper. I feel this would be a good API even for future use, when we decide to use an optimizer.

sanketpurandare added a commit that referenced this pull request Nov 13, 2025
ghstack-source-id: 53b1817
Pull Request resolved: #247
sanketpurandare added a commit that referenced this pull request Nov 14, 2025
ghstack-source-id: 4bc2ada
Pull Request resolved: #247
@sanketpurandare sanketpurandare changed the base branch from gh/sanketpurandare/1/base to main November 14, 2025 01:47
@sanketpurandare sanketpurandare merged commit a8d46ea into main Nov 14, 2025
4 of 6 checks passed