@@ -18,29 +18,36 @@ settings:
1818 checkpointing_interval_in_steps : 32
1919 evaluation_interval_in_steps : 32
2020 consistency_enforcement :
21- enforce_tokens_per_step_consistency : true
21+ enforce_tokens_per_step_consistency : false
2222 enforce_last_step_logged : false
2323 enforce_last_step_evaluated : false
2424 enforce_last_step_checkpointed : false
2525 step_profile :
2626 gradient_accumulation_steps : 1
2727 local_train_micro_batch_size : 1
2828 sequence_length : 256
29+ dp_degree :
30+ instance_key : dp_degree
31+ pass_type : BY_REFERENCE
2932 training_target :
3033 num_target_tokens :
3134 component_key : number_conversion
3235 variant_key : num_tokens_from_packed_mem_map_dataset_continuous
3336 config :
3437 dataset_path : ${settings.paths.train_dataset_path}
3538 sequence_length : ${settings.step_profile.sequence_length}
36- num_ranks : ${settings.cuda_env.world_size}
39+ dp_degree :
40+ instance_key : dp_degree
41+ pass_type : BY_REFERENCE
3742 local_micro_batch_size : ${settings.step_profile.local_train_micro_batch_size}
3843 gradient_accumulation_steps : ${settings.step_profile.gradient_accumulation_steps}
3944 num_target_steps : # for the batch progress subscriber
4045 component_key : number_conversion
4146 variant_key : num_steps_from_num_tokens
4247 config :
43- num_ranks : ${settings.cuda_env.world_size}
48+ dp_degree :
49+ instance_key : dp_degree
50+ pass_type : BY_REFERENCE
4451 local_micro_batch_size : ${settings.step_profile.local_train_micro_batch_size}
4552 global_num_tokens : ${settings.training_target.num_target_tokens}
4653 sequence_length : ${settings.step_profile.sequence_length}
@@ -172,9 +179,18 @@ device_mesh:
172179 config :
173180 device_type : cuda
174181 data_parallel_replicate_degree : 1
175- data_parallel_shard_degree : ${settings.cuda_env.world_size} # i.e., fully sharded
182+ data_parallel_shard_degree : -1
176183 world_size : ${settings.cuda_env.world_size}
177184
185+ dp_degree :
186+ component_key : number_conversion
187+ variant_key : parallel_degree
188+ config : # get the parallel degree from the device mesh
189+ device_mesh :
190+ instance_key : device_mesh
191+ pass_type : BY_REFERENCE
192+ parallelism_methods : [dp_shard, dp_replicate]
193+
178194app_state :
179195 component_key : app_state
180196 variant_key : raw
@@ -326,17 +342,14 @@ evaluation_subscriber:
326342 directory : wandb_storage
327343 config_file_path : ${settings.config_file_path}
328344
329- # mfu_calculator:
330- # component_key: mfu_calculator
331- # variant_key: gpt2
332- # config:
333- # n_layer: ${model_raw.config.n_layer}
334- # sequence_length: ${settings.step_profile.sequence_length}
335- # n_embd: ${model_raw.config.n_embd}
336- # world_size: ${settings.cuda_env.world_size}
337- # raw_model:
338- # instance_key: model_raw
339- # pass_type: BY_REFERENCE
340- # wrapped_model:
341- # instance_key: initialized_model
342- # pass_type: BY_REFERENCE
345+ mfu_calculator :
346+ component_key : mfu_calculator
347+ variant_key : gpt2
348+ config :
349+ n_layer : ${model_raw.config.n_layer}
350+ sequence_length : ${settings.step_profile.sequence_length}
351+ n_embd : ${model_raw.config.n_embd}
352+ world_size : ${settings.cuda_env.world_size}
353+ wrapped_model :
354+ instance_key : initialized_model
355+ pass_type : BY_REFERENCE
0 commit comments