From 7daf10c8c07f6ba89156944f8c80c04816307034 Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Tue, 4 Nov 2025 00:03:26 -0800
Subject: [PATCH 1/3] Update (base update)

[ghstack-poisoned]

From f6a66d9f859a872ca925d5b6a54fe6aa3b18c9a0 Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Tue, 4 Nov 2025 00:03:26 -0800
Subject: [PATCH 2/3] Update

[ghstack-poisoned]
---
 torchtitan/train.py | 27 +++++++++++++++++++--------
 1 file changed, 19 insertions(+), 8 deletions(-)

diff --git a/torchtitan/train.py b/torchtitan/train.py
index 4d3ed12e8e..365a3700ac 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -410,26 +410,37 @@ def batch_generator(
 
             yield input_dict, labels
 
-    def forward_backward_step(
-        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
-    ) -> torch.Tensor:
-        model_parts = self.model_parts
-        parallel_dims = self.parallel_dims
-
+    def post_dataloading_processing(
+        self, input_dict: dict[str, torch.Tensor], label: torch.Tensor
+    ) -> tuple[
+        dict[str, torch.Tensor], torch.Tensor, dict[str, torch.Tensor], dict[str, Any]
+    ]:
+        """Post processing after data loading."""
         inputs = input_dict["input"]
         extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
         # For arguments, like attention_masks, we have to put them in a separate
         # dict as extra_inputs are not forwarded to other stages in PP, but
         # extra_kwargs are.
-        extra_kwargs = {}
+        extra_kwargs: dict[str, Any] = {}
 
         if getattr(self.model_args, "use_flex_attn", False):
-            extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
+            extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks(
                 input_batch=inputs,
                 tokenizer=self.tokenizer,
                 extra_inputs=extra_inputs,
             )
+        return inputs, label, extra_inputs, extra_kwargs
+
+    def forward_backward_step(
+        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
+    ) -> torch.Tensor:
+        model_parts = self.model_parts
+        parallel_dims = self.parallel_dims
+
+        inputs, label, extra_inputs, extra_kwargs = self.post_dataloading_processing(
+            input_dict, labels
+        )
 
         # apply context parallelism if cp is enabled
         # ensure CP handles the separate freqs_cis buffer for each pp stage
         optional_context_parallel_ctx = (

From 5ede5b6756ec8a3fd28a17dcc4fe22c8d62d154e Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Thu, 6 Nov 2025 23:33:59 -0800
Subject: [PATCH 3/3] Update

[ghstack-poisoned]
---
 torchtitan/train.py | 34 +++++++++++++++++++++++++++++++++-
 1 file changed, 33 insertions(+), 1 deletion(-)

diff --git a/torchtitan/train.py b/torchtitan/train.py
index 727bffd660..2b3d165fff 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -413,7 +413,39 @@ def batch_generator(
     def post_dataloading_process(
         self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
     ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]:
-        """Post processing after data loading."""
+        """
+        Post-processing hook after data loading and before model forward pass.
+
+        This method processes the raw data from the dataloader and prepares it for
+        the model's forward pass. It separates the main input tensor from auxiliary
+        inputs and constructs additional keyword arguments (e.g., attention masks).
+
+        This method can be overridden in subclasses to customize data processing
+        for different training strategies (e.g., converting tensors to DTensors,
+        applying custom transformations, etc.).
+
+        Args:
+            input_dict: Dictionary containing tensors from the dataloader. Must
+                contain an "input" key with the main input tensor. May contain
+                additional keys for auxiliary inputs (e.g., position ids).
+            labels: Target labels for the batch.
+
+        Returns:
+            A tuple of (inputs, labels, extra_inputs, extra_kwargs) where:
+            - inputs: Main input tensor extracted from input_dict["input"].
+            - labels: Target labels (unchanged from input parameter).
+            - extra_inputs: Dict of auxiliary input tensors (all keys except
+              "input" from input_dict). These are passed to the model forward
+              but are NOT forwarded across pipeline parallel stages.
+            - extra_kwargs: Dict of additional keyword arguments for model forward.
+              These ARE forwarded across pipeline parallel stages. Contains
+              attention_masks if flex attention is enabled.
+
+        Note:
+            The distinction between extra_inputs and extra_kwargs is important for
+            pipeline parallelism: extra_kwargs are forwarded to all pipeline stages,
+            while extra_inputs are only available to the first stage.
+        """
         inputs = input_dict["input"]
         extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
         # For arguments, like attention_masks, we have to put them in a separate
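
To make the new hook concrete, here is a minimal sketch (not part of the patch) of
the kind of subclass override the docstring describes: reusing the default
post_dataloading_process and then wrapping inputs/labels as DTensors. The
subclass name, the mesh lookup via self.parallel_dims.world_mesh["tp"], and the
replicated placement are illustrative assumptions, not torchtitan APIs confirmed
by this patch.

# Illustrative sketch only: a hypothetical subclass of torchtitan's Trainer
# that customizes the post_dataloading_process hook added in this stack.
from typing import Any

import torch
from torch.distributed.tensor import DTensor, Replicate

from torchtitan.train import Trainer


class DTensorTrainer(Trainer):
    def post_dataloading_process(
        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]:
        # Reuse the default split into inputs / extra_inputs / extra_kwargs.
        inputs, labels, extra_inputs, extra_kwargs = super().post_dataloading_process(
            input_dict, labels
        )
        # Assumed attribute path: a tensor-parallel submesh of the world mesh.
        tp_mesh = self.parallel_dims.world_mesh["tp"]
        # Wrap the plain tensors as replicated DTensors before the forward pass,
        # one of the customizations the new docstring mentions.
        inputs = DTensor.from_local(inputs, tp_mesh, [Replicate()])
        labels = DTensor.from_local(labels, tp_mesh, [Replicate()])
        return inputs, labels, extra_inputs, extra_kwargs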