
Commit 6a53a31

Author: rrutmann
Commit message: Merge: main into current branch
1 parent 2c963e3; commit 6a53a31

File tree: 3 files changed (+15, -10 lines)

src/modalities/main.py

Lines changed: 1 addition & 2 deletions
@@ -20,7 +20,6 @@
 from modalities.logging_broker.subscriber import MessageSubscriberIF
 from modalities.registry.components import COMPONENTS
 from modalities.registry.registry import Registry
-from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_num_parallel_ranks
 from modalities.trainer import Trainer
 from modalities.util import get_synced_experiment_id_of_run, get_total_number_of_trainable_parameters, print_rank_0
 from modalities.utils.logger_utils import get_logger
@@ -158,7 +157,7 @@ def run(self, components: TrainingComponentsInstantiationModel):
             gradient_acc_steps=components.settings.step_profile.gradient_accumulation_steps,
             gradient_clipper=components.gradient_clipper,
             global_num_tokens_per_train_step=global_num_tokens_per_train_step,
-            dp_degree=components.settings.step_profile.dp_degree,
+            device_mesh=components.device_mesh,
             mfu_calculator=components.mfu_calculator,
         )
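
Note on the wiring above: run() now passes the whole components.device_mesh to the Trainer instead of a pre-computed dp_degree, so the Trainer can derive the parallelism degrees itself. As a point of reference only (not part of this commit), a device mesh of this kind could be built with PyTorch's init_device_mesh; the dimension names and sizes below are illustrative assumptions, not the values modalities actually uses.

# Hypothetical sketch: constructing a DeviceMesh such as the one handed over
# via components.device_mesh. Names and sizes are assumptions for illustration.
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

def build_device_mesh_sketch() -> DeviceMesh:
    # 8 ranks split into 2 pipeline stages x 2 replica groups x 2 shards.
    mesh = init_device_mesh(
        "cuda",
        mesh_shape=(2, 2, 2),
        mesh_dim_names=("pp", "dp_replicate", "dp_shard"),
    )
    # Per-dimension sizes can be read back from the named sub-meshes.
    assert mesh["pp"].size() == 2
    assert mesh["dp_replicate"].size() * mesh["dp_shard"].size() == 4
    return mesh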

src/modalities/trainer.py

Lines changed: 13 additions & 7 deletions
@@ -5,18 +5,19 @@
 import torch
 import torch.distributed as dist
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.device_mesh import DeviceMesh
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler
 
 from modalities.batch import DatasetBatch, EvaluationResultBatch, ResultItem
 from modalities.checkpointing.stateful.app_state import AppState
-from modalities.config.instantiation_models import MeshDefinition
 from modalities.dataloader.dataloader import LLMDataLoader
 from modalities.logging_broker.messages import ExperimentStatus, MessageTypes, ProgressUpdate
 from modalities.logging_broker.publisher import MessagePublisher
 from modalities.loss_functions import Loss
 from modalities.models.model import model_predict_batch
 from modalities.models.parallelism.pipeline_parallelism import Pipeline
+from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_degree
 from modalities.running_env.fsdp.reducer import Reducer
 from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF
 from modalities.training.training_progress import TrainingProgress
@@ -37,7 +38,7 @@ def __init__(
         evaluation_result_publisher: MessagePublisher[EvaluationResultBatch],
         gradient_acc_steps: int,
         global_num_tokens_per_train_step: int,
-        mesh_definition: MeshDefinition,
+        device_mesh: DeviceMesh | None,
         num_seen_train_steps: int,
         global_num_seen_tokens: int,
         num_target_steps: int,
@@ -54,7 +55,8 @@ def __init__(
             evaluation_result_publisher (MessagePublisher[EvaluationResultBatch]): Evaluation result publisher.
             gradient_acc_steps (int): Gradient accumulation steps.
             global_num_tokens_per_train_step (int): Global number of tokens per train step.
-            mesh_definition (MeshDefinition): Mesh definition.
+            dp_degree (int): Data parallelism degree.
+            pp_degree (int): Pipeline parallelism degree.
             num_seen_train_steps (int): Number of seen train steps.
             global_num_seen_tokens (int): Global number of seen tokens.
             num_target_steps (int): Number of target steps.
@@ -66,12 +68,16 @@ def __init__(
             None
         """
         self.global_rank = global_rank
-        self.pp_degree = mesh_definition.pp_degree
+        if device_mesh is not None:
+            self.dp_degree = get_parallel_degree(device_mesh, [ParallelismDegrees.DP_REPLICATE, ParallelismDegrees.DP_SHARD])
+            self.pp_degree = get_parallel_degree(device_mesh, [ParallelismDegrees.PP])
+        else:
+            self.dp_degree = dist.get_world_size()
+            self.pp_degree = 1
         self.progress_publisher = progress_publisher
         self.evaluation_result_publisher = evaluation_result_publisher
         self.gradient_acc_steps = gradient_acc_steps
         self.global_num_tokens_per_train_step = global_num_tokens_per_train_step
-        self.dp_degree = mesh_definition.dp_degree
         self.num_seen_train_steps = num_seen_train_steps
         self.num_target_steps = num_target_steps
         self.num_target_tokens = num_target_tokens
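
With a mesh available, both degrees are resolved in __init__ via get_parallel_degree, and a missing mesh falls back to pure data parallelism (dp_degree = dist.get_world_size(), pp_degree = 1). The real helper lives in modalities.running_env.fsdp.device_mesh; a minimal sketch of what it plausibly computes, assuming each ParallelismDegrees value names a mesh dimension and that absent dimensions count as degree 1:

# Illustrative sketch only -- not the actual modalities implementation.
from enum import Enum

from torch.distributed.device_mesh import DeviceMesh

class ParallelismDegrees(Enum):
    # Assumed shape of the real enum; member values are guesses.
    DP_REPLICATE = "dp_replicate"
    DP_SHARD = "dp_shard"
    PP = "pp"

def get_parallel_degree(device_mesh: DeviceMesh, keys: list[ParallelismDegrees]) -> int:
    # Multiply the sizes of the requested mesh dimensions; dimensions that are
    # not part of the mesh contribute a factor of 1.
    degree = 1
    dim_names = device_mesh.mesh_dim_names or ()
    for key in keys:
        if key.value in dim_names:
            degree *= device_mesh[key.value].size()
    return degree
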
@@ -287,9 +293,9 @@ def train(
             tensor=cumulated_losses,
             operation=dist.ReduceOp.SUM,
             # 1.) summed batch loss / (num batches * world size)
-            # 2.) last batch loss / (world size / num_pipeline_parallel_ranks)
+            # 2.) last batch loss / (world size / pp_degree)
             post_processing_fun=lambda t: torch.stack(
-                [t[0] / t[-1], t[1] / dist.get_world_size() * self.num_pipeline_parallel_ranks]
+                [t[0] / t[-1], t[1] / dist.get_world_size() * self.pp_degree]
             ),
         )
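
For concreteness, the renamed pp_degree enters the loss normalization as a correction factor: the all-reduce sums the last-batch loss over all ranks, but with pipeline parallelism only world_size / pp_degree ranks are assumed to hold a loss value, so the mean is taken over that count (numbers below are illustrative).

# Worked example of the post-processing above (illustrative values only).
world_size = 8                   # dist.get_world_size()
pp_degree = 2                    # pipeline-parallel degree
summed_last_batch_loss = 10.0    # result of the SUM all-reduce
mean_last_batch_loss = summed_last_batch_loss / world_size * pp_degree  # 10.0 / 4 == 2.5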

tests/conftest.py

Lines changed: 1 addition & 1 deletion
@@ -213,7 +213,7 @@ def trainer(progress_publisher_mock, gradient_clipper_mock):
        global_num_seen_tokens=0,
        num_target_tokens=100,
        num_target_steps=10,
-        num_pipeline_parallel_ranks=1,
+        pp_degree=1,
    )