diff --git a/compare_tp_weights.py b/compare_tp_weights.py
new file mode 100644
index 000000000..69675e64c
--- /dev/null
+++ b/compare_tp_weights.py
@@ -0,0 +1,85 @@
+
+# usage:
+# python compare_tp_weights.py input_layernorm.weight 40 2 .
+
+# input_layernorm.weight
+# input_layernorm.bias
+# post_attention_layernorm.weight
+# post_attention_layernorm.bias
+
+# one liner for just 2 weights comparison
+# python -c 'import torch, sys; k=sys.argv[1]; a,b = map(torch.load, sys.argv[2:4]); print("Exact match" if torch.testing.assert_close(a[k], b[k], rtol=0.0, atol=0.0, check_device=False) is None else "Mismatch")' input_layernorm.weight layer_03-model_00-model_states.pt layer_03-model_01-model_states.pt
+
+# 13B
+# cd /gpfsdsstore/projects/rech/six/commun/checkpoints/tr1-13B/tr1-13B-with-optim/global_step168000
+# python ~/compare_tp_weights.py input_layernorm.weight 40 2 .
+
+# 104B
+# cd /gpfsssd/scratch/rech/six/commun/checkpoints/tr8b-104B/checkpoints/emb-norm/global_step16800
+#
+# python ~/compare_tp_weights.py input_layernorm.weight 64 4 . > ~/104B.input_layernorm.weight.txt
+# python ~/compare_tp_weights.py post_attention_layernorm.weight 64 4 . > ~/104B.post_attention_layernorm.weight.txt
+# python ~/compare_tp_weights.py input_layernorm.bias 64 4 . > ~/104B.input_layernorm.bias.txt
+# python ~/compare_tp_weights.py post_attention_layernorm.bias 64 4 . > ~/104B.post_attention_layernorm.bias.txt
+
+# other 104B checkpoints:
+
+# cd /gpfsssd/scratch/rech/six/commun/checkpoints/tr8b-104B/to-back-up/tr8b-104B/checkpoints/cl-exp-02/global_step10500
+# mismatched 68
+#
+# cd /gpfsssd/scratch/rech/six/commun/checkpoints/tr8-104B-wide/experiment11/global_step15660
+# mismatched
+#
+# cd /gpfsssd/scratch/rech/six/commun/checkpoints/tr8-104B-wide/experiment06/global_step5100
+# python ~/compare_tp_weights.py input_layernorm.weight 32 4
+# **all matched**
+#
+# python ~/compare_tp_weights.py post_attention_layernorm.weight 32 4
+# not matched
+
+
+# # 104B/176B embed-norm check
+# python -c 'import torch, sys; k=sys.argv[1]; a,b = map(torch.load, sys.argv[2:4]); print("Exact match" if torch.testing.assert_close(a[k], b[k], rtol=0.0, atol=0.0, check_device=False) is None else "Mismatch")' word_embeddings.norm.weight layer_01-model_00-model_states.pt layer_01-model_01-model_states.pt
+# python -c 'import torch, sys; k=sys.argv[1]; a,b = map(torch.load, sys.argv[2:4]); print("Exact match" if torch.testing.assert_close(a[k], b[k], rtol=0.0, atol=0.0, check_device=False) is None else "Mismatch")' word_embeddings.norm.weight layer_01-model_01-model_states.pt layer_01-model_02-model_states.pt
+# python -c 'import torch, sys; k=sys.argv[1]; a,b = map(torch.load, sys.argv[2:4]); print("Exact match" if torch.testing.assert_close(a[k], b[k], rtol=0.0, atol=0.0, check_device=False) is None else "Mismatch")' word_embeddings.norm.weight layer_01-model_02-model_states.pt layer_01-model_03-model_states.pt
+
+# same on cpu
+# python -c 'import torch, sys; k=sys.argv[1]; a=torch.load(sys.argv[2], map_location=torch.device("cpu")); b=torch.load(sys.argv[3], map_location=torch.device("cpu")); print("Exact match" if torch.testing.assert_close(a[k], b[k], rtol=0.0, atol=0.0, check_device=False) is None else "Mismatch")' word_embeddings.norm.weight layer_01-model_00-model_states.pt layer_01-model_01-model_states.pt
+
+# # 176B
+# cd /gpfsssd/scratch/rech/six/commun/checkpoints/tr11-176B-ml/checkpoints/main/global_step16400
+# python ~/compare_tp_weights.py input_layernorm.weight 70 4 . > ~/176B.input_layernorm.weight.txt
+# python ~/compare_tp_weights.py post_attention_layernorm.weight 70 4 . > ~/176B.post_attention_layernorm.weight.txt
+# python ~/compare_tp_weights.py input_layernorm.bias 70 4 . > ~/176B.input_layernorm.bias.txt
+# python ~/compare_tp_weights.py post_attention_layernorm.bias 70 4 . > ~/176B.post_attention_layernorm.bias.txt
+
+
+import torch, sys
+
+
+key, nlayers, tp_size, checkpoint_dir = sys.argv[1:5]
+
+print(f"checking key={key}")
+matched, mismatched = 0, 0
+for layer_id in range(int(nlayers)):
+    for tp in range(int(tp_size)-1):
+        f1 = f"{checkpoint_dir}/layer_{3+layer_id:02d}-model_{tp:02d}-model_states.pt"
+        f2 = f"{checkpoint_dir}/layer_{3+layer_id:02d}-model_{tp+1:02d}-model_states.pt"
+        c1 = torch.load(f1)
+        c2 = torch.load(f2)
+        # print(f1)
+        # print(f2)
+        header = f"layer_id={layer_id}: {tp}-{tp+1}"
+        try:
+            torch.testing.assert_close(c1[key], c2[key], rtol=0.0, atol=0.0, check_device=False)
+            print(f"✓ {header}")
+            matched += 1
+        except:
+            print(f"✗ {header}")
+            mismatched += 1
+            #raise
+
+print(f"Matched : {matched}")
+print(f"Mismatched: {mismatched}")
diff --git a/compare_tp_weights_cpu.py b/compare_tp_weights_cpu.py
new file mode 100644
index 000000000..7c07bcabb
--- /dev/null
+++ b/compare_tp_weights_cpu.py
@@ -0,0 +1,42 @@
+
+# usage:
+# python compare_tp_weights_cpu.py input_layernorm.weight 40 2 .
+
+
+# 13B
+# cd /gpfsdsstore/projects/rech/six/commun/checkpoints/tr1-13B/tr1-13B-with-optim/global_step168000
+# python ~/compare_tp_weights.py input_layernorm.weight 40 2 .
+
+# 104B
+# cd /gpfsssd/scratch/rech/six/commun/checkpoints/tr8b-104B/checkpoints/emb-norm/global_step16800
+# python ~/compare_tp_weights.py input_layernorm.weight 64 4 .
+
+
+import torch, sys
+
+
+key, nlayers, tp_size, checkpoint_dir = sys.argv[1:5]
+
+print(f"checking key={key}")
+matched, mismatched = 0, 0
+for layer_id in range(int(nlayers)):
+    for tp in range(int(tp_size)-1):
+        f1 = f"{checkpoint_dir}/layer_{3+layer_id:02d}-model_{tp:02d}-model_states.pt"
+        f2 = f"{checkpoint_dir}/layer_{3+layer_id:02d}-model_{tp+1:02d}-model_states.pt"
+        c1 = torch.load(f1, map_location=torch.device('cpu'))
+        c2 = torch.load(f2, map_location=torch.device('cpu'))
+        # print(f1)
+        # print(f2)
+        header = f"layer_id={layer_id}: {tp}-{tp+1}"
+        try:
+            torch.testing.assert_close(c1[key], c2[key], rtol=0.0, atol=0.0, check_device=False)
+            print(f"✓ {header}")
+            matched += 1
+        except:
+            print(f"✗ {header}")
+            mismatched += 1
+            #raise
+
+print(f"Matched : {matched}")
+print(f"Mismatched: {mismatched}")
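The two scripts above only check one key at a time for an exact match. When a mismatch shows up it is often useful to know by how much the shards disagree; here is a minimal standalone sketch of that idea (the helper name and the example file names are illustrations, not part of the scripts above):

import torch

def report_drift(file_a, file_b):
    # load two TP shard files of the same layer onto cpu and compare every shared key
    a = torch.load(file_a, map_location="cpu")
    b = torch.load(file_b, map_location="cpu")
    for key in sorted(a.keys() & b.keys()):
        if a[key].shape != b[key].shape:
            print(f"{key}: shape mismatch {tuple(a[key].shape)} vs {tuple(b[key].shape)}")
            continue
        # only replicated params (layer norms, row-parallel biases) are expected to be
        # identical across TP ranks; sharded weights hold different slices by design
        max_diff = (a[key].float() - b[key].float()).abs().max().item()
        print(f"{key}: max abs diff = {max_diff:.3e}")

# e.g. report_drift("layer_03-model_00-model_states.pt", "layer_03-model_01-model_states.pt")
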
diff --git a/megatron/data/data_samplers.py b/megatron/data/data_samplers.py
index c8109b3d2..3edf9ea8b 100644
--- a/megatron/data/data_samplers.py
+++ b/megatron/data/data_samplers.py
@@ -54,7 +54,8 @@ def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
     # Torch dataloader.
     return torch.utils.data.DataLoader(dataset,
                                        batch_sampler=batch_sampler,
-                                       num_workers=num_workers,
+                                       num_workers=args.num_workers,
+                                       generator=torch.Generator().manual_seed(args.seed),
                                        pin_memory=True)
 
 class MegatronPretrainingSampler:
diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py
index 563566b70..391e19978 100644
--- a/megatron/model/fused_layer_norm.py
+++ b/megatron/model/fused_layer_norm.py
@@ -21,6 +21,7 @@ from packaging import version
 
 import torch
+from megatron import mpu
 from torch import nn
 from torch.nn.parameter import Parameter
 import torch.nn.functional as F
 
@@ -37,7 +38,6 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
 
   @staticmethod
   def forward(ctx, input, weight, bias, normalized_shape, eps):
-
     ctx.normalized_shape = normalized_shape
     ctx.eps = eps
     input_ = input.contiguous()
@@ -96,7 +96,29 @@ def reset_parameters(self):
 
     init.zeros_(self.bias)
 
 
+  def forward_old(self, input):
+#     weights = [torch.empty_like(self.weight) for tp in range(mpu.get_tensor_model_parallel_world_size())]
+#     torch.distributed.all_gather(weights, self.weight, group=mpu.get_tensor_model_parallel_group())
+#     biases = [torch.empty_like(self.bias) for tp in range(mpu.get_tensor_model_parallel_world_size())]
+#     torch.distributed.all_gather(biases, self.bias, group=mpu.get_tensor_model_parallel_group())
+#     if any(torch.any(weight != self.weight) for weight in weights):
+#         if mpu.get_tensor_model_parallel_rank() == 0:
+#             print("Weight sync failed")
+#             print(weights)
+#     if any(torch.any(bias != self.bias) for bias in biases):
+#         if mpu.get_tensor_model_parallel_rank() == 0:
+#             print("Bias sync failed")
+#             print(biases)
+
+    return FusedLayerNormAffineFunction.apply(
+        input, self.weight, self.bias, self.normalized_shape, self.eps)
+
+
   def forward(self, input):
+
+    torch.distributed.all_reduce(self.weight, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())
+    torch.distributed.all_reduce(self.bias, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())
+
     if self.use_meg_ds_fused_layer_norm:
       return FusedLayerNormAffineFunction.apply(
           input, self.weight, self.bias, self.normalized_shape, self.eps)
diff --git a/megatron/mpu/mappings.py b/megatron/mpu/mappings.py
index 821d9acfe..6056f94f6 100644
--- a/megatron/mpu/mappings.py
+++ b/megatron/mpu/mappings.py
@@ -82,6 +82,7 @@ def symbolic(graph, input_):
 
     @staticmethod
     def forward(ctx, input_):
+        # TODO: we need to assert that the input_ are all the same within a group
        return input_
 
     @staticmethod
@@ -102,6 +103,7 @@ def forward(ctx, input_):
 
     @staticmethod
     def backward(ctx, grad_output):
+        # TODO: we need to assert that the grad_output are all the same within a group
        return grad_output
 
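For context on the fused_layer_norm change above: layer-norm weight and bias are replicated across tensor-parallel ranks rather than sharded, so once they drift apart they can be pulled back together by averaging them over the TP group. A minimal sketch of that pattern outside of Megatron (it assumes an already-initialized NCCL process group; torch.distributed.ReduceOp.AVG needs torch>=1.11, which is what the requirements.txt bump further down is about):

import torch
import torch.distributed as dist

def sync_replicated_param(param: torch.Tensor, group=None) -> None:
    # average, in place, a parameter that is supposed to be bit-identical on every rank
    # of `group`; after the call all ranks hold the same (averaged) values again
    dist.all_reduce(param.data, op=dist.ReduceOp.AVG, group=group)

# the patched MixedFusedLayerNorm.forward applies the same operation to self.weight and
# self.bias with the tensor-model-parallel group on every forward pass
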
diff --git a/megatron/training.py b/megatron/training.py
index bbf6623e3..8e2b4a6de 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -367,6 +367,93 @@ def get_learning_rate_scheduler(optimizer):
 
     return lr_scheduler
 
+def sync_layer_norm(n, p):
+
+    rank = torch.distributed.get_rank()
+
+    print(f'rank {rank} processing {n}')
+
+    #return
+
+    # # Here is how you can access the fp32 version of the bf16 param and the fp32 optim states
+    # #
+    # # Note that there is an all_reduce called on all dp ranks when `get_full_hp_param` is called -
+    # # so it's not free
+    # #
+    # # a. fp32 param
+    # fp32_param = p.get_full_hp_param()
+    # torch.set_printoptions(sci_mode=False, precision=6)
+    # print(f'rank {rank} bf16 = {p}')
+    # print(f'rank {rank} fp32 = {fp32_param}')
+    # torch.testing.assert_close(p, fp32_param, rtol=4e-3, atol=0, check_dtype=False)
+
+    # # b. fp32 optim states
+    # for key in ['exp_avg', 'exp_avg_sq']:
+    #     full_optim_state = p.get_full_hp_param(optim_state_key=key)
+    #     print(f'rank {rank} full optim state {key} = {full_optim_state}')
+
+    # 1. bf16
+    #print(f'rank {rank} before reduce p = {p}')
+    torch.distributed.all_reduce(p, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())
+    #print(f'rank {rank} after reduce p = {p}')
+
+    if p._hp_mapping is not None:
+        #print(f'rank {rank} fixing hp for input_layernorm')
+        #p._hp_mapping.update_hp()
+
+        # 2. fp32
+        hp = p._hp_mapping.hp_fragment
+        torch.distributed.all_reduce(hp, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())
+
+        # 3. optim states
+        for key in ['exp_avg', 'exp_avg_sq']:
+            optim_state_fragment = p._hp_mapping.get_optim_state_fragment(key)
+            #print(f'rank {rank} before reduce optim state fragment {key} = {optim_state_fragment}')
+            torch.distributed.all_reduce(optim_state_fragment, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())
+            #print(f'rank {rank} after reduce optim state fragment {key} = {optim_state_fragment}')
+
+
+def sync_all_layer_norms(model):
+    # syncs weight+bias for each of the following layer norms (via averaging across TP ranks):
+    # 1. word embedding front: word_embeddings.norm
+    # 2. transformer block input_layernorm x 70
+    # 3. transformer block post_attention_layernorm x 70
+    # 4. word embedding head - I think it's just weight + bias w/o a proper name in the last layer file layer_0X-model_0X-model_states.pt, see: https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/affff3d2927864c6948075700c672971782441f4/megatron/model/gpt_model.py#L267
+
+    import re
+    layer_norms_params_end_with = [
+        "word_embeddings.norm.weight", "word_embeddings.norm.bias",
+        "input_layernorm.weight", "input_layernorm.bias",
+        "post_attention_layernorm.weight", "post_attention_layernorm.bias",
+        "self_attention.dense.bias", "mlp.dense_4h_to_h.bias",
+    ]
+
+    for n, p in model.named_parameters():
+        #print(n)
+        # XXX: would be much simpler to re-do this logic to traverse children modules and act on isinstance of MixedFusedLayerNorm instead
+        # 1. first the easy-to-identify layer norm params, as each has a unique suffix
+        for end in layer_norms_params_end_with:
+            if n.endswith(end):
+                sync_layer_norm(n, p)
+
+        # 2. now the last layer norm that has no prefix
+        # hack: (\d\d): MixedFusedLayerNorm() is hanging there w/o any prefix name, so need to match something like:
+        # /^6.weight$/ or /^6.bias$/
+        if mpu.is_pipeline_last_stage() and re.match(r'^\d+\.(weight|bias)$', n):
+            sync_layer_norm(n, p)
+
+
+def sync_all_torch_random_state():
+    torch_rng_state = torch.get_rng_state().cuda()
+    # We use rank 1 as the source of truth and send its state to the other TP ranks
+    torch.distributed.broadcast(
+        torch_rng_state,
+        src=mpu.get_tensor_model_parallel_src_rank() + 1,
+        group=mpu.get_tensor_model_parallel_group()
+    )
+    torch.set_rng_state(torch_rng_state.cpu())
+
+
 def setup_model_and_optimizer(model_provider_func):
     """Setup model and optimizer."""
     args = get_args()
@@ -416,9 +503,17 @@ def setup_model_and_optimizer(model_provider_func):
         torch.distributed.barrier()
         timers('load-checkpoint').stop()
         timers.log(['load-checkpoint'])
+        print_rank_0(f'module = {model[0]}')
+
+        # turn on to enable layer norm syncing
+        if 1:
+            sync_all_layer_norms(model[0].module)
+            sync_all_torch_random_state()
     else:
         args.iteration = 0
 
+    torch.distributed.barrier()
+
     # We only support local DDP with multiple micro-batches.
     if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
         assert args.DDP_impl == 'local'
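sync_all_torch_random_state above uses a broadcast rather than an average: one TP rank is picked as the source of truth for the CPU RNG and its state is pushed to the rest of the group. A stripped-down sketch of that recipe (src_rank and group are placeholders, and an initialized NCCL process group is assumed):

import torch
import torch.distributed as dist

def broadcast_torch_rng_state(src_rank: int, group=None) -> None:
    # torch.get_rng_state() returns a CPU ByteTensor; NCCL can only broadcast GPU tensors
    state = torch.get_rng_state().cuda()
    dist.broadcast(state, src=src_rank, group=group)
    # every rank now overwrites its CPU RNG state with the one received from src_rank
    torch.set_rng_state(state.cpu())
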
diff --git a/requirements.txt b/requirements.txt
index da76b5e44..47e11bf04 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,9 +6,11 @@ pybind11
 regex
 six
 tensorboard
-torch>=1.7
+torch>=1.11
 transformers
-DeepSpeed @ git+https://github.com/microsoft/DeepSpeed.git
+# for now using this branch for bf16 work
+DeepSpeed @ git+https://github.com/microsoft/DeepSpeed.git@olruwase/bf16-updates
+#DeepSpeed @ git+https://github.com/microsoft/DeepSpeed.git
 # versions from HF transformers
 black==21.4b0
 isort>=5.5.4
diff --git a/run_bf16.sh b/run_bf16.sh
index fd3a48398..f4295cc30 100755
--- a/run_bf16.sh
+++ b/run_bf16.sh
@@ -36,26 +36,6 @@ ZERO_STAGE=0
 #GLOBAL_BATCH=128
 #WORKER_STR="-i worker-0"
 
-
-TP=1
-PP=1
-DP=2
-WORLD_SIZE=$((TP*PP*DP))
-HIDDEN=1024
-LAYERS=24
-SEQ=1024
-GLOBAL_BATCH=1
-WORKER_STR=""
-
-MICRO_BATCH=1
-
-LR=6.0e-4
-MIN_LR=6.0e-5
-DTYPE="bf16"
-EXP_DIR=${HOME}/experiments/results/bf16
-LOG_DIR="${EXP_DIR}/tensorboard/tp${TP}_pp${PP}_dp${DP}_hd${HIDDEN}_nl${LAYERS}_gbsz${GLOBAL_BATCH}_mbsz${MICRO_BATCH}_z${ZERO_STAGE}_LR_${LR}_${MIN_LR}_${DTYPE}_fix3"
-mkdir -p $LOG_DIR
-
 while [[ $# -gt 0 ]]
 do
 key="$1"
@@ -66,30 +46,69 @@ case $key in
     ;;
     -z|--zero-stage)
     ZERO_STAGE=$2;
+    shift
     shift
     ;;
     *)
-    echo "Unknown argument(s)"
-    usage
+    echo "Unknown argument(s): $key"
     exit 1
     shift
     ;;
 esac
 done
 
+TP=4
+PP=1
+DP=2
+WORLD_SIZE=$((TP*PP*DP))
+
+HIDDEN=1024
+LAYERS=24
+NHEADS=32
+SEQ=1024
+
+#LAYERS=2
+#HIDDEN=8
+#NHEADS=2
+#SEQ=8
+
+GLOBAL_BATCH=64
+WORKER_STR=""
+EXIT_ITERS=10
+TRAIN_SAMPLES=1000000
+MICRO_BATCH=32
+LR=1.0e-1
+MIN_LR=1.0e-1
+DTYPE="bf16"
+RUN_VERSION=1
+EXP_DIR=${HOME}/experiments/results/bf16
+RUN_TAG="tp${TP}_pp${PP}_dp${DP}_hd${HIDDEN}_nl${LAYERS}_gbsz${GLOBAL_BATCH}_mbsz${MICRO_BATCH}_z${ZERO_STAGE}_LR_${LR}_${MIN_LR}_${DTYPE}_train_${EXIT_ITERS}_v${RUN_VERSION}"
+LOG_DIR="${EXP_DIR}/tensorboard/${RUN_TAG}"
+mkdir -p $LOG_DIR
+export BIT16_DUMP_FILE="${EXP_DIR}/${RUN_TAG}.txt"
+CHECKPOINT_DIR="./checkpoints/${DTYPE}_z${ZERO_STAGE}_tp${TP}_pp${PP}_dp${DP}_nl${LAYERS}_exit_${EXIT_ITERS}_v${RUN_VERSION}"
 
 options=" \
     --tensor-model-parallel-size $TP \
     --pipeline-model-parallel-size $PP \
     --num-layers $LAYERS \
     --hidden-size $HIDDEN \
-    --num-attention-heads 32 \
+    --num-attention-heads ${NHEADS} \
     --seq-length $SEQ \
-    --loss-scale 12 \
     --max-position-embeddings $SEQ \
     --micro-batch-size $MICRO_BATCH \
    --global-batch-size $GLOBAL_BATCH \
-    --train-iters 1000 \
+    --optimizer adam \
+    --adam-eps 1e-8 \
+    --lr-warmup-samples 5 \
+    --min-lr 1e-6 \
+    --lr-decay-style cosine \
+    --lr-decay-samples 12 \
+    --override-lr-scheduler \
+    --clip-grad 1.0 \
+    --weight-decay 1e-1 \
+    --embed-layernorm \
+    --partition-activations \
     --lr $LR \
     --min-lr $MIN_LR \
     --lr-decay-style cosine \
@@ -100,17 +119,21 @@ options=" \
     --vocab-file ${VOCAB_PATH} \
     --merge-file ${MERGE_PATH} \
     --save-interval 10000 \
-    --split 98,2,0 \
-    --clip-grad 1.0 \
     --weight-decay 0.1 \
     --adam-beta1 0.9 \
     --adam-beta2 0.95 \
     --init-method-std 0.006 \
     --${DTYPE} \
     --checkpoint-activations \
-    --exit-interval 10000 \
+    --train-samples ${TRAIN_SAMPLES} \
+    --exit-interval ${EXIT_ITERS} \
+    --seed 42 \
+    --load ${CHECKPOINT_DIR} \
+    --save ${CHECKPOINT_DIR} \
     --tensorboard-dir $LOG_DIR
     "
+# --split 10,0,0 \
+# --rampup-batch-size 2 2 1_000 \
 
 if [[ ${USE_DEEPSPEED} -eq 1 ]]; then
@@ -155,7 +178,8 @@ WORKER_STR="--num_nodes 1 --num_gpus $WORLD_SIZE"
 #WORKER_STR="-i worker-0:0,1,2,3"
 #run_cmd="deepspeed -i worker-0:0,1,2,3 ${DIR}/pretrain_gpt.py $@ ${options}"
 #run_cmd="deepspeed -i worker-0 ${DIR}/pretrain_gpt.py $@ ${options}"
-run_cmd="deepspeed --master_port 29700 $WORKER_STR ${DIR}/pretrain_gpt.py $@ ${options}"
+
+run_cmd="deepspeed --master_port 29600 $WORKER_STR ${DIR}/pretrain_gpt.py $@ ${options}"
 
 echo ${run_cmd}
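The knobs in run_bf16.sh have to stay mutually consistent: the launcher is asked for WORLD_SIZE = TP*PP*DP GPUs, and the global batch has to be reachable from the micro batch and the data-parallel degree through an integer number of gradient-accumulation steps. A quick arithmetic check of the values used above (plain Python, nothing Megatron-specific):

TP, PP, DP = 4, 1, 2
WORLD_SIZE = TP * PP * DP                  # 8 GPUs requested from the launcher
MICRO_BATCH, GLOBAL_BATCH = 32, 64

assert GLOBAL_BATCH % (MICRO_BATCH * DP) == 0, "global batch must be a multiple of micro batch * DP"
grad_accum_steps = GLOBAL_BATCH // (MICRO_BATCH * DP)
print(f"world_size={WORLD_SIZE} grad_accum_steps={grad_accum_steps}")  # world_size=8 grad_accum_steps=1
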
diff --git a/tests/ds_config_bf16.json b/tests/ds_config_bf16.json
new file mode 100644
index 000000000..1f02566c9
--- /dev/null
+++ b/tests/ds_config_bf16.json
@@ -0,0 +1,13 @@
+{
+  "train_micro_batch_size_per_gpu": 1,
+  "train_batch_size": 16,
+  "gradient_clipping": 1.0,
+  "zero_optimization": {
+    "stage": 0
+  },
+  "bf16": {
+    "enabled": true
+  },
+  "steps_per_print": 2000,
+  "wall_clock_breakdown": false
+}
\ No newline at end of file
diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py
index 25921c12a..ed383e17a 100644
--- a/tests/test_tensor_parallel.py
+++ b/tests/test_tensor_parallel.py
@@ -293,6 +293,5 @@ def test_tokenizer_raise_error_make_vocab_size_divisible_by(self):
 
         self.assertEqual(str(exc_info.value), "5121 is not divisible by 128")
 
-
 if __name__ == '__main__':
     unittest.main()
diff --git a/tests/test_training.py b/tests/test_training.py
index c77cb9af2..b65a051e5 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -20,6 +20,8 @@
 import re
 import unittest
 from pathlib import Path
+
+import torch
 from parameterized import parameterized
 
 from megatron.testing_utils import (
@@ -31,7 +33,7 @@
     require_bnb_non_decorator,
     require_deepspeed,
     require_torch_gpu,
-    set_seed
+    set_seed, torch_assert_equal
 )
 
 set_seed(42)
@@ -50,7 +52,7 @@ def get_3d_dimensions():
         dp_size = 2
         pp_size = 2
         tp_size = 2
-    if num_gpus >= 4:
+    elif num_gpus >= 4:
         dp_size = 1
         pp_size = 2
         tp_size = 2
@@ -592,3 +594,115 @@ def test_skip_train_iteration(self):
         train_iterations = range(1,10)
         for i in train_iterations:
             self.assertTrue(f"iteration {i:8d}/" in cs.out)
+
+    @parameterized.expand(["bf16", "fp16"])
+    def test_layer_norm_consistent(self, variation):
+        src_dir = self.src_dir
+        output_dir = self.get_auto_remove_tmp_dir()
+        num_gpus = 2
+        seq_len = 128
+        data_dir = f"{self.data_dir}/gpt2"
+        args = f"""
+            --tensor-model-parallel-size {2}
+            --pipeline-model-parallel-size {1}
+            --distributed-backend nccl
+
+            --log-interval 1
+            --save-interval 10
+            --eval-interval 10
+            --eval-iters 5
+            --checkpoint-activations
+            --partition-activations
+            --exit-interval {20}
+
+            --merge-file {data_dir}/gpt2-tiny-merges.txt
+            --vocab-file {data_dir}/gpt2-tiny-vocab.json
+            --save {output_dir}/checkpoints
+            --load {output_dir}/checkpoints
+            --data-path {data_dir}/meg-gpt2-openwebtext_text_document
+            --tensorboard-dir {output_dir}/tensorboard
+            --tensorboard-queue-size 5
+            --log-timers-to-tensorboard
+            --log-batch-size-to-tensorboard
+            --log-validation-ppl-to-tensorboard
+
+            --num-layers 2
+            --hidden-size 64
+            --num-attention-heads 2
+            --seq-length {seq_len}
+            --max-position-embeddings 1024
+            --micro-batch-size 2
+            --global-batch-size 16
+
+            --optimizer adam
+            --adam-beta1 0.9
+            --adam-beta2 0.95
+            --adam-eps 1e-8
+            --lr 1e-1
+            --clip-grad 1.0
+            --weight-decay 1e-1
+            --embed-layernorm
+
+            --log-level debug
+            --log-level-replica info
+
+            --rampup-batch-size 2 2 200
+            --train-samples 200
+
+            --position-embedding-type alibi
+        """.split()
+
+        ds_args = f"""
+            --deepspeed
+            --deepspeed-activation-checkpointing
+        """.split()
+
+        if variation == "bf16":
+            args.append("--bf16")
+            ds_args += [
+                "--zero-stage", "0",
+                "--deepspeed_config", f"{self.test_file_dir_str}/ds_config_bf16.json"
+            ]
+        elif variation == "fp16":
+            args.append("--fp16")
+            ds_args += [
+                "--zero-stage", "1",
+                "--deepspeed_config", f"{self.test_file_dir_str}/ds_config.json"
+            ]
+
+        # args, ds_args, num_gpus = self.get_variation_config("base", output_dir, n_samples=200)
+
+        script = [f"{src_dir}/pretrain_gpt.py"]
+        launcher = get_launcher(num_gpus)
+        cmd = launcher + script + args + ds_args
+        # keep for quick debug
+        # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] + cmd)); die
+
+        with CaptureStdout() as cs:
+            execute_subprocess_async(cmd, env=self.get_env())
+
+        checkpoints = ["global_step10", "global_step20"]
+
+        # Check transformer layer norm
+        keys_to_compare = ["input_layernorm.weight", "input_layernorm.bias", "post_attention_layernorm.weight", "post_attention_layernorm.bias"]
+        files_to_compare = [[f"layer_{layer_id:02d}-model_{tp:02d}-model_states.pt" for tp in range(num_gpus)] for layer_id in [3, 4]]
+        for checkpoint in checkpoints:
+            checkpoint_path = os.path.join(output_dir, "checkpoints", checkpoint)
+            for key in keys_to_compare:
+                for files in files_to_compare:
+                    weights = [torch.load(os.path.join(checkpoint_path, file))[key] for file in files]
+                    ref = weights[0]
+                    for weight in weights[1:]:
+                        torch_assert_equal(ref, weight, check_device=False)
+
+        # Check embed layer norm
+        keys_to_compare = ["word_embeddings.norm.weight"]
+        files_to_compare = [[f"layer_{layer_id:02d}-model_{tp:02d}-model_states.pt" for tp in range(num_gpus)] for layer_id in [1]]
+        for checkpoint in checkpoints:
+            checkpoint_path = os.path.join(output_dir, "checkpoints", checkpoint)
+            for key in keys_to_compare:
+                for files in files_to_compare:
+                    weights = [torch.load(os.path.join(checkpoint_path, file))[key] for file in files]
+                    ref = weights[0]
+                    for weight in weights[1:]:
+                        torch_assert_equal(ref, weight, check_device=False)
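Both the new test and the compare_tp_weights scripts need to know which layer_XX shard file holds which parameters (layer_01 carries word_embeddings.norm.*, the transformer blocks start at layer_03). When in doubt, a quick way to inspect a checkpoint's layout is to dump the keys of every rank-0 shard; a small sketch, assuming the layer_XX-model_YY-model_states.pt naming used throughout this diff:

import glob
import sys

import torch

checkpoint_dir = sys.argv[1]
for path in sorted(glob.glob(f"{checkpoint_dir}/layer_*-model_00-model_states.pt")):
    # each file is a plain dict of tensors; the keys are what compare_tp_weights.py expects
    state_dict = torch.load(path, map_location="cpu")
    print(path)
    for key, tensor in state_dict.items():
        print(f"    {key:50} {tuple(tensor.shape)}")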