 import torch
 import torch.distributed as dist
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.device_mesh import DeviceMesh
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler
 
 from modalities.batch import DatasetBatch, EvaluationResultBatch, ResultItem
 from modalities.checkpointing.stateful.app_state import AppState
-from modalities.config.instantiation_models import MeshDefinition
 from modalities.dataloader.dataloader import LLMDataLoader
 from modalities.logging_broker.messages import ExperimentStatus, MessageTypes, ProgressUpdate
 from modalities.logging_broker.publisher import MessagePublisher
 from modalities.loss_functions import Loss
 from modalities.models.model import model_predict_batch
 from modalities.models.parallelism.pipeline_parallelism import Pipeline
+from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_degree
 from modalities.running_env.fsdp.reducer import Reducer
 from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF
 from modalities.training.training_progress import TrainingProgress
@@ -37,7 +38,7 @@ def __init__(
         evaluation_result_publisher: MessagePublisher[EvaluationResultBatch],
         gradient_acc_steps: int,
         global_num_tokens_per_train_step: int,
-        mesh_definition: MeshDefinition,
+        device_mesh: DeviceMesh | None,
         num_seen_train_steps: int,
         global_num_seen_tokens: int,
         num_target_steps: int,
@@ -54,7 +55,8 @@ def __init__(
             evaluation_result_publisher (MessagePublisher[EvaluationResultBatch]): Evaluation result publisher.
             gradient_acc_steps (int): Gradient accumulation steps.
             global_num_tokens_per_train_step (int): Global number of tokens per train step.
-            mesh_definition (MeshDefinition): Mesh definition.
+            device_mesh (DeviceMesh | None): Device mesh used to derive the data parallel and
+                pipeline parallel degrees. If None, pure data parallelism across all ranks is assumed.
             num_seen_train_steps (int): Number of seen train steps.
             global_num_seen_tokens (int): Global number of seen tokens.
             num_target_steps (int): Number of target steps.
@@ -66,12 +68,16 @@ def __init__(
             None
         """
         self.global_rank = global_rank
-        self.pp_degree = mesh_definition.pp_degree
+        if device_mesh is not None:
+            self.dp_degree = get_parallel_degree(device_mesh, [ParallelismDegrees.DP_REPLICATE, ParallelismDegrees.DP_SHARD])
+            self.pp_degree = get_parallel_degree(device_mesh, [ParallelismDegrees.PP])
+        else:
+            self.dp_degree = dist.get_world_size()
+            self.pp_degree = 1
         self.progress_publisher = progress_publisher
         self.evaluation_result_publisher = evaluation_result_publisher
         self.gradient_acc_steps = gradient_acc_steps
         self.global_num_tokens_per_train_step = global_num_tokens_per_train_step
-        self.dp_degree = mesh_definition.dp_degree
         self.num_seen_train_steps = num_seen_train_steps
         self.num_target_steps = num_target_steps
         self.num_target_tokens = num_target_tokens
@@ -287,9 +293,9 @@ def train(
                    tensor=cumulated_losses,
                    operation=dist.ReduceOp.SUM,
                    # 1.) summed batch loss / (num batches * world size)
-                   # 2.) last batch loss / (world size / num_pipeline_parallel_ranks)
+                   # 2.) last batch loss / (world size / pp_degree)
                    post_processing_fun=lambda t: torch.stack(
-                       [t[0] / t[-1], t[1] / dist.get_world_size() * self.num_pipeline_parallel_ranks]
+                       [t[0] / t[-1], t[1] / dist.get_world_size() * self.pp_degree]
                    ),
                )
 
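For readers of the post_processing_fun above: the summed last-batch loss is divided by world_size / pp_degree rather than by world_size, since with pipeline parallelism only the ranks of one pipeline stage report a loss. A minimal standalone sketch of that normalization with made-up numbers; the helper name and values below are illustrative and not part of this commit:

```python
import torch

def normalize_last_batch_loss(summed_loss: torch.Tensor, world_size: int, pp_degree: int) -> torch.Tensor:
    # After an all-reduce SUM, only world_size / pp_degree ranks have contributed a
    # last-batch loss, so dividing by world_size alone would under-count the average.
    return summed_loss / world_size * pp_degree

# Assumed toy setup: 8 ranks with pipeline parallel degree 2 -> 4 loss contributors.
world_size, pp_degree = 8, 2
summed = torch.tensor(4 * 0.75)  # four contributing ranks, each reporting a loss of 0.75
print(normalize_last_batch_loss(summed, world_size, pp_degree))  # tensor(0.7500)
```

With pp_degree == 1 this reduces to the usual division by the world size, which matches the fallback path in __init__ when no device mesh is provided.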