122 changes: 122 additions & 0 deletions examples/evalharness/run_bsevalharness_prefix.slurm
@@ -0,0 +1,122 @@
#!/bin/bash
#SBATCH --job-name=run_evalharness-tr13f-6b3
#SBATCH --partition=gpu_p5
#SBATCH --constraint=a100
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1 # crucial - only 1 task per node for the distributed launcher!
#SBATCH --cpus-per-task=8 # number of cores per task
#SBATCH --hint=nomultithread # we get physical cores not logical
#SBATCH --gres=gpu:1 # number of gpus
#SBATCH --time 20:00:00 # maximum execution time (HH:MM:SS)
#SBATCH --output=%x-%j.out # output file name
#SBATCH --account=six@a100
#SBATCH --reservation=hug

set -x -e

source $six_ALL_CCFRWORK/start-muennighofflmeval

echo "START TIME: $(date)"

# a unique identifier for the current eval, ideally corresponding to the model name
VARIANT="tr13f-prefix"


CHECKPOINT_PATH=/gpfsscratch/rech/six/commun/checkpoints/tr13f-6B3-ml-t0/checkpoints/prefix/global_step3100
MEGATRON_DEEPSPEED_REPO=$six_ALL_CCFRSCRATCH/commun/experiments/muennighoff/megdsbslmeval/Megatron-DeepSpeed
export HF_DATASETS_OFFLINE=1
export TRANSFORMERS_OFFLINE=1

export TRANSFORMERS_CACHE=$six_ALL_CCFRWORK/models
export HF_DATASETS_CACHE=$six_ALL_CCFRWORK/datasetseval
export HF_MODULES_CACHE=$six_ALL_CCFRWORK/modules
export HF_METRICS_CACHE=$six_ALL_CCFRWORK/metrics
export TOKENIZERS_PARALLELISM=false

cd $MEGATRON_DEEPSPEED_REPO

TOKENIZER_NAME_OR_PATH=bigscience-catalogue-data-dev/byte-level-bpe-tokenizer-no-norm-250k-whitespace-and-eos-regex-alpha-v3-dedup-lines-articles

PP_SIZE=1
TP_SIZE=1
SEQ_LEN=2048

# different from the training MICRO_BATCH_SIZE - no optimizer memory is needed, so a bigger BS fits
# make it as big as fits on the GPU w/o OOM, but not too close to 100%
EVAL_MICRO_BATCH_SIZE=1

# dummy arguments to make Megatron happy.
MEGATRON_REQUIRED_ARGS=" \
--num-layers -1 \
--hidden-size -1 \
--num-attention-heads -1 \
--seq-length -1 \
--max-position-embeddings -1 \
"


ZERO_STAGE=0

config_json="./ds_config.json"

# DeepSpeed figures out GAS (gradient accumulation steps) dynamically from the dynamic GBS (global batch size) via set_train_batch_size()
cat <<EOT > $config_json
{
"train_micro_batch_size_per_gpu": 1,
"train_batch_size": 1,
"gradient_clipping": 1.0,
"zero_optimization": {
"stage": $ZERO_STAGE
},
"bf16": {
"enabled": false
},
"steps_per_print": 2000,
"wall_clock_breakdown": false
}
EOT


CMD="./tasks/eval_harness/evaluate_bsevalharness_prefix.py \
--load $CHECKPOINT_PATH \
--results_path $VARIANT-results.json \
--tensor-model-parallel-size $TP_SIZE \
--pipeline-model-parallel-size $PP_SIZE \
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path $TOKENIZER_NAME_OR_PATH \
--micro-batch-size $EVAL_MICRO_BATCH_SIZE \
--no-load-optim \
--no-load-rng \
--eval_fp32 \
--inference \
--seq-length $SEQ_LEN \
--task_list copa \
--prefix \
--deepspeed \
--deepspeed_config ds_config.json \
--intermed_results \
--adaptive_seq_len \
--micro_bs_multiplier 8 \
$MEGATRON_REQUIRED_ARGS \
"

GPUS_PER_NODE=1
NNODES=$SLURM_NNODES
MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
MASTER_PORT=6000
export LAUNCHER="python -u -m torch.distributed.run \
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
--rdzv_backend c10d \
--max_restarts 0 \
--tee 3 \
"

export CUDA_LAUNCH_BLOCKING=1

echo $LAUNCHER $CMD

export PYTHONPATH=$MEGATRON_DEEPSPEED_REPO

$LAUNCHER $CMD 2>&1 | tee $VARIANT-eval-harness.log
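A note on the set_train_batch_size() comment above: DeepSpeed keeps the invariant GBS = MBS * GAS * data-parallel world size, and re-derives GAS whenever the harness changes the global batch size on the fly. A minimal sketch of that arithmetic (the invariant is DeepSpeed's; the helper below is hypothetical, not a DeepSpeed API):

```python
# Sketch of the batch-size invariant DeepSpeed maintains; derive_gas is a
# hypothetical helper, not a DeepSpeed API.
def derive_gas(train_batch_size: int, micro_batch_per_gpu: int, world_size: int) -> int:
    # GBS must be divisible by MBS * DP world size.
    assert train_batch_size % (micro_batch_per_gpu * world_size) == 0
    return train_batch_size // (micro_batch_per_gpu * world_size)

# With the ds_config above (GBS=1, MBS=1) on this single-GPU job, GAS is 1.
assert derive_gas(1, 1, 1) == 1
```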
121 changes: 121 additions & 0 deletions examples/evalharness/run_evalharness_prefix.slurm
@@ -0,0 +1,121 @@
#!/bin/bash
#SBATCH --job-name=run_evalharness-tr13f-6B3-prefix
#SBATCH --partition=gpu_p5
#SBATCH --constraint=a100
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1 # crucial - only 1 task per node for the distributed launcher!
#SBATCH --cpus-per-task=8 # number of cores per task
#SBATCH --hint=nomultithread # we get physical cores not logical
#SBATCH --gres=gpu:1 # number of gpus
#SBATCH --time 20:00:00 # maximum execution time (HH:MM:SS)
#SBATCH --output=%x-%j.out # output file name
#SBATCH --account=six@a100
#SBATCH --reservation=hug

set -x -e

source $six_ALL_CCFRWORK/start-py38-pt111

echo "START TIME: $(date)"

# a unique identifier for the current eval, ideally corresponding to the model name
VARIANT="tr13f-6B3-prefix"


CHECKPOINT_PATH=/gpfsscratch/rech/six/commun/checkpoints/tr13f-6B3-ml-t0/checkpoints/prefix/global_step3100
MEGATRON_DEEPSPEED_REPO=$six_ALL_CCFRSCRATCH/commun/experiments/muennighoff/megdsbslmeval/Megatron-DeepSpeed
export HF_DATASETS_OFFLINE=1
export TRANSFORMERS_OFFLINE=1

export TRANSFORMERS_CACHE=$six_ALL_CCFRWORK/models
export HF_DATASETS_CACHE=$six_ALL_CCFRWORK/datasets
export HF_MODULES_CACHE=$six_ALL_CCFRWORK/modules
export HF_METRICS_CACHE=$six_ALL_CCFRWORK/metrics

cd $MEGATRON_DEEPSPEED_REPO

TOKENIZER_NAME_OR_PATH=bigscience-catalogue-data-dev/byte-level-bpe-tokenizer-no-norm-250k-whitespace-and-eos-regex-alpha-v3-dedup-lines-articles

PP_SIZE=1
TP_SIZE=1
SEQ_LEN=2048

# different from the training MICRO_BATCH_SIZE - no optimizer memory is needed, so a bigger BS fits
# make it as big as fits on the GPU w/o OOM, but not too close to 100%
EVAL_MICRO_BATCH_SIZE=1

# dummy arguments to make Megatron happy.
MEGATRON_REQUIRED_ARGS=" \
--num-layers -1 \
--hidden-size -1 \
--num-attention-heads -1 \
--seq-length -1 \
--max-position-embeddings -1 \
"


ZERO_STAGE=0

config_json="./ds_config.json"

# DeepSpeed figures out GAS (gradient accumulation steps) dynamically from the dynamic GBS (global batch size) via set_train_batch_size()
cat <<EOT > $config_json
{
"train_micro_batch_size_per_gpu": 1,
"train_batch_size": 1,
"gradient_clipping": 1.0,
"zero_optimization": {
"stage": $ZERO_STAGE
},
"bf16": {
"enabled": false
},
"steps_per_print": 2000,
"wall_clock_breakdown": false
}
EOT


CMD="./tasks/eval_harness/evaluate_evalharness_prefix.py \
--load $CHECKPOINT_PATH \
--results_path $VARIANT-results.json \
--tensor-model-parallel-size $TP_SIZE \
--pipeline-model-parallel-size $PP_SIZE \
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path $TOKENIZER_NAME_OR_PATH \
--micro-batch-size $EVAL_MICRO_BATCH_SIZE \
--no-load-optim \
--no-load-rng \
--eval_fp32 \
--inference \
--seq-length $SEQ_LEN \
--task_list copa \
--prefix \
--deepspeed \
--deepspeed_config ds_config.json \
--intermed_results \
--adaptive_seq_len \
--micro_bs_multiplier 8 \
$MEGATRON_REQUIRED_ARGS \
"

GPUS_PER_NODE=1
NNODES=$SLURM_NNODES
MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
MASTER_PORT=6000
export LAUNCHER="python -u -m torch.distributed.run \
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
--rdzv_backend c10d \
--max_restarts 0 \
--tee 3 \
"

export CUDA_LAUNCH_BLOCKING=1

echo $LAUNCHER $CMD

export PYTHONPATH=$MEGATRON_DEEPSPEED_REPO

$LAUNCHER $CMD 2>&1 | tee $VARIANT-eval-harness.log
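Once a job finishes, the metrics land in $VARIANT-results.json next to the log. A quick way to inspect them; the top-level "results" key is an assumption based on lm-eval-harness conventions, so adjust to the actual schema the script writes:

```python
import json

# Hypothetical filename; substitute your own VARIANT value.
with open("tr13f-6B3-prefix-results.json") as f:
    results = json.load(f)

# lm-eval-harness style output usually keys metrics by task name.
for task, metrics in results.get("results", {}).items():
    print(task, metrics)
```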
38 changes: 34 additions & 4 deletions tasks/eval_harness/evaluate.py
@@ -13,7 +13,6 @@
import torch.nn.functional as F

from lm_eval.tasks import ALL_TASKS
from pretrain_gpt import model_provider
import numpy as np

import torch
@@ -154,7 +153,9 @@ def _collate(x):
contlens.append(cont)
inplens.append(inplen)

logits = self._model_call(torch.cat(inps, dim=0))
# contlens actually stores the continuation token encodings (contencs), not lengths; the variable name is kept for consistency
prefix_lens = torch.tensor([ilen - len(ctoks) for ilen, ctoks in zip(inplens, contlens)])[:, None]
logits = self._model_call(torch.cat(inps, dim=0), prefix_lens=prefix_lens)
res_len += len(chunk)
if logits is not None:
if self.args.offloadearly:
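The prefix_lens line above subtracts each continuation's length from the full input length, despite the misleading contlens name. A standalone numeric check of that expression, with made-up token lists:

```python
import torch

# Two requests: 10 input tokens with a 3-token continuation,
# and 7 input tokens with a 2-token continuation (toy values).
inplens = [10, 7]
contlens = [[101, 102, 103], [201, 202]]  # continuation encodings, despite the name

prefix_lens = torch.tensor(
    [ilen - len(ctoks) for ilen, ctoks in zip(inplens, contlens)]
)[:, None]

# The prefix (bidirectional-attention) region is everything before the
# continuation: lengths 7 and 5, shaped (batch, 1) for broadcasting.
assert prefix_lens.tolist() == [[7], [5]]
```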
@@ -186,8 +187,13 @@ def _collate(x):
return reord.get_original(res)

def create_model_inputs(self, tokens):

args = get_args()

if args.prefix:
assert len(tokens) == 2
tokens, prefix_lens = tokens

attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
tokens,
self.EOT_TOKEN_ID,
@@ -196,10 +202,19 @@ def create_model_inputs(self, tokens):
args.eod_mask_loss,
prefix_indices=None,
loss_on_targets_only=False)

if args.prefix:
assert len(prefix_lens) == attention_mask.shape[0] == tokens.shape[0]
for i, prefix_len in enumerate(prefix_lens):
assert prefix_len <= attention_mask.shape[-1]
# Attention is paid where the mask is False (True positions are masked out)
attention_mask[i, :, :prefix_len, :prefix_len] = False



return (tokens, position_ids, attention_mask), (tokens, loss_mask)

def _model_call(self, inps):
def _model_call(self, inps, prefix_lens=None):
args = get_args()

if args.deepspeed:
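The mask edit in create_model_inputs above is the core of prefix-LM evaluation: starting from a causal mask, the top-left prefix_len x prefix_len block is cleared so prefix tokens attend to each other bidirectionally. A self-contained sketch of that effect on a single 2-D mask, using the same convention (True positions are masked out):

```python
import torch

def prefix_attention_mask(seq_len: int, prefix_len: int) -> torch.Tensor:
    # Causal mask: True above the diagonal means "masked out".
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    # Clear the prefix block: bidirectional attention over the prefix.
    mask[:prefix_len, :prefix_len] = False
    return mask

m = prefix_attention_mask(seq_len=5, prefix_len=3)
# Position 0 may now attend to position 2, which a causal mask forbids.
assert not m[0, 2].item()
# Outside the prefix, causality still holds: position 3 cannot see 4.
assert m[3, 4].item()
```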
@@ -208,7 +223,15 @@ def _model_call(self, inps):
new_size = ((len(inps) + args.micro_batch_size-1) // args.micro_batch_size) * args.micro_batch_size
padded = F.pad(inps, (0, 0, 0, new_size-len(inps)), value = 0)
# dummy data iterator for pipelining.
data_iterator = list((torch.stack(inp) for inp in utils.chunks(padded, args.micro_batch_size)))
if args.prefix:
assert prefix_lens.shape == (padded.shape[0], 1)
data_iterator = [(torch.stack(inp), torch.stack(pfx)) for inp, pfx in zip(
utils.chunks(padded, args.micro_batch_size),
utils.chunks(prefix_lens, args.micro_batch_size),
)
]
else:
data_iterator = list((torch.stack(inp) for inp in utils.chunks(padded, args.micro_batch_size)))
self.model.micro_batches = len(data_iterator)

if self.adaptive_seq_len:
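When --prefix is set, the branch above zips input micro-batches with their matching prefix lengths so each pipeline step receives an (inputs, prefix_lens) pair. A toy illustration of that pairing, with a stand-in chunks helper since utils.chunks here is Megatron's own:

```python
import torch

def chunks(seq, n):
    # Stand-in for megatron utils.chunks: yield successive slices of size n.
    for i in range(0, len(seq), n):
        yield seq[i:i + n]

padded = torch.arange(8).reshape(4, 2)            # 4 padded inputs (toy data)
prefix_lens = torch.tensor([[1], [2], [1], [2]])  # one length per input
micro_batch_size = 2

data_iterator = [
    (torch.stack(list(inp)), torch.stack(list(pfx)))
    for inp, pfx in zip(chunks(padded, micro_batch_size),
                        chunks(prefix_lens, micro_batch_size))
]

# Two micro-batches, each an (inputs, prefix_lens) pair with aligned rows.
assert len(data_iterator) == 2
assert data_iterator[0][0].shape == (2, 2)
assert data_iterator[0][1].shape == (2, 1)
```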
@@ -348,6 +371,12 @@ def load_ds_checkpoint_and_setup_megatron(args):

# print final arguments.
_print_args(args)

if args.prefix:
from finetune_t0_non_causal_decoder import model_provider
else:
from pretrain_gpt import model_provider

if args.deepspeed:

# Hack #3:
@@ -393,6 +422,7 @@ def tasks_args(parser):
group.add_argument('--adaptive_seq_len', default = False, action='store_true',
help='Should the sequence length be adapted to the batch during evaluation; if in fp16, the results will differ slightly due to numerical errors, but evaluation is greatly sped up.')
group.add_argument('--eval_fp32', default = False, action='store_true', help='Should the evaluation run in fp32')
group.add_argument('--prefix', default=False, action='store_true', help='Prefix LM - bidirectional attention over the input')
group.add_argument('--intermed_results', default = False, action='store_true', help='Whether to print & write intermediate results for each task')
group.add_argument('--bootstrap_iters', type=int, default=100000, help='How many iterations to use for stderr estimation')
group.add_argument('--micro_bs_multiplier', type=int, default=1, help='Increase the global batch size to remove the bubble when pipeline-parallel')