From b624f9048e73b419d49360d36774e9183876e2aa Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sat, 25 Feb 2023 16:30:01 +0100 Subject: [PATCH 1/4] Support saving and loading 8-bit block weights --- src/petals/bloom/from_pretrained.py | 8 +++++++- src/petals/cli/convert_model.py | 19 +++++++++++++------ src/petals/server/server.py | 4 +++- src/petals/utils/convert_block.py | 9 +-------- 4 files changed, 24 insertions(+), 16 deletions(-) diff --git a/src/petals/bloom/from_pretrained.py b/src/petals/bloom/from_pretrained.py index 9f1d12b22..750ae0c1e 100644 --- a/src/petals/bloom/from_pretrained.py +++ b/src/petals/bloom/from_pretrained.py @@ -23,11 +23,12 @@ from petals.bloom.block import WrappedBloomBlock from petals.server.block_utils import get_block_size from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for +from petals.utils.convert_block import replace_8bit_linear logger = get_logger(__name__) CLIENT_BRANCH = "main" -BLOCK_BRANCH_PREFIX = "block_" +BLOCK_BRANCH_PREFIX = "int8_block" def load_pretrained_block( @@ -38,6 +39,8 @@ def load_pretrained_block( use_auth_token: Optional[str] = None, cache_dir: Optional[str] = None, max_disk_space: Optional[int] = None, + load_in_8bit=False, + device: Optional[Union[str, torch.device]] = None, ) -> WrappedBloomBlock: """Load one BLOOM block from a converted model. See convert_model.py (or README.md) on how to convert it.""" assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}" @@ -49,6 +52,9 @@ def load_pretrained_block( with init_empty_weights(): block = WrappedBloomBlock(config) + if load_in_8bit: + block = replace_8bit_linear(block) + block = block.to(device) state_dict = _load_state_dict( converted_model_name_or_path, diff --git a/src/petals/cli/convert_model.py b/src/petals/cli/convert_model.py index 95b08e439..eb26df79b 100644 --- a/src/petals/cli/convert_model.py +++ b/src/petals/cli/convert_model.py @@ -15,16 +15,17 @@ logger = get_logger(__name__) -DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto") - +DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, int8=torch.int8, auto="auto") def main(): parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.") parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained") parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub") - parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype") - parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder") + parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto", + help="Load initial model in this dtype") + parser.add_argument("--output_path", type=str, default="./converted_model", + help="Track output repo to this folder") parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo") parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch") parser.add_argument( @@ -41,7 +42,6 @@ def main(): if args.model == "bigscience/bloom" and free_ram_gb < 400: logger.warning(f"ACHTUNG! 
converting bloom-176b will use up 350-400GB RAM, you have {free_ram_gb:.3f} free") - assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}" if os.path.exists(args.output_path) and ( len(os.listdir(args.output_path)) != 0 or not os.path.isdir(args.output_path) ): @@ -54,8 +54,15 @@ def main(): config.dht_prefix = args.output_repo model = BloomModel.from_pretrained( - args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype] + args.model, use_auth_token=args.use_auth_token, revision=args.revision, + torch_dtype=DTYPE_MAP[args.torch_dtype] if args.torch_dtype != "int8" else "float16", + load_in_8bit=args.torch_dtype == "int8", + device_map={"word_embeddings": "cuda", "word_embeddings_layernorm": "cuda", "h": "cuda", "ln_f": "cuda"} ) + if args.torch_dtype == "int8": + # trigger weight quantization + model = model.cuda() + if args.resize_token_embeddings: logger.info(f"Resizing token embeddings, new size = {args.resize_token_embeddings}") model.resize_token_embeddings(args.resize_token_embeddings) diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 4f2a6456c..25f5ac069 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -401,8 +401,10 @@ def create( use_auth_token=use_auth_token, cache_dir=cache_dir, max_disk_space=max_disk_space, + load_in_8bit=load_in_8bit, + device=device, ) - block = convert_block(block, block_config, tensor_parallel_devices, device, load_in_8bit, freeze=True) + block = convert_block(block, block_config, tensor_parallel_devices, device, freeze=True) backend_dtype = next(block.parameters()).dtype if torch_dtype == "auto" else torch_dtype blocks[module_uid] = TransformerBackend( diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py index b58cd1aeb..25ccbb7d6 100644 --- a/src/petals/utils/convert_block.py +++ b/src/petals/utils/convert_block.py @@ -24,12 +24,10 @@ def convert_block( config: BloomConfig, tensor_parallel_devices: Sequence[torch.device], output_device: torch.device, - load_in_8bit: bool, - threshold: float = 6.0, freeze: bool = True, ) -> tp.TensorParallel: """ - Optimize a transformer block for use in a Petals server, apply tensor parallelism and/or LLM.8bit quantization + Optimize a transformer block for use in a Petals server and apply tensor parallelism :note: some optimizations will modify the input block in-place! 
:param block: a single transformer block, either pre-trained or newly initialized @@ -37,8 +35,6 @@ def convert_block( :param tensor_parallel_devices: if specified, use tensor parallelism to split the model between these devices :note: if there is only a single device, model wil still be wrapped with TensorParallel (for uniformity) :param output_device: if tensor_parallel_devices is True, output - :param load_in_8bit: if True, use LLM.int8() quantization to reduce the model memory footprint - :param threshold: a quantization threshold from LLM.int8() paper ( https://arxiv.org/abs/2208.07339 ) :param freeze: if True (default), make all module parameters non-trainable :return: a module that acts like the original block, but runs with all specified optimizations @@ -49,9 +45,6 @@ def convert_block( block = make_tensor_parallel(block, config, tensor_parallel_devices, output_device=output_device) - if load_in_8bit: - block = replace_8bit_linear(block, threshold=threshold) - for shard, device in zip(block.module_shards, block.devices): shard.to(device) From d70019f2b6f55709f6004ee8a202ca7c143087a7 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sat, 25 Feb 2023 16:39:00 +0100 Subject: [PATCH 2/4] Fix formatting --- src/petals/bloom/from_pretrained.py | 2 +- src/petals/cli/convert_model.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/petals/bloom/from_pretrained.py b/src/petals/bloom/from_pretrained.py index 750ae0c1e..ddbb81e27 100644 --- a/src/petals/bloom/from_pretrained.py +++ b/src/petals/bloom/from_pretrained.py @@ -22,8 +22,8 @@ from petals.bloom.block import WrappedBloomBlock from petals.server.block_utils import get_block_size -from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for from petals.utils.convert_block import replace_8bit_linear +from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for logger = get_logger(__name__) diff --git a/src/petals/cli/convert_model.py b/src/petals/cli/convert_model.py index eb26df79b..e37db0ac5 100644 --- a/src/petals/cli/convert_model.py +++ b/src/petals/cli/convert_model.py @@ -17,15 +17,16 @@ DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, int8=torch.int8, auto="auto") + def main(): parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.") parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained") parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub") - parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto", - help="Load initial model in this dtype") - parser.add_argument("--output_path", type=str, default="./converted_model", - help="Track output repo to this folder") + parser.add_argument( + "--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto", help="Load initial model in this dtype" + ) + parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder") parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo") parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch") parser.add_argument( @@ -54,10 +55,12 @@ def main(): config.dht_prefix = args.output_repo model = BloomModel.from_pretrained( - 
args.model, use_auth_token=args.use_auth_token, revision=args.revision, + args.model, + use_auth_token=args.use_auth_token, + revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype] if args.torch_dtype != "int8" else "float16", load_in_8bit=args.torch_dtype == "int8", - device_map={"word_embeddings": "cuda", "word_embeddings_layernorm": "cuda", "h": "cuda", "ln_f": "cuda"} + device_map={"word_embeddings": "cuda", "word_embeddings_layernorm": "cuda", "h": "cuda", "ln_f": "cuda"}, ) if args.torch_dtype == "int8": # trigger weight quantization From 556f0fabe08d28ab8e11f9e65820fcc1a7027895 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sat, 25 Feb 2023 16:45:13 +0100 Subject: [PATCH 3/4] Set device_map only for int8 --- src/petals/cli/convert_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/petals/cli/convert_model.py b/src/petals/cli/convert_model.py index e37db0ac5..4d6e59b6a 100644 --- a/src/petals/cli/convert_model.py +++ b/src/petals/cli/convert_model.py @@ -60,7 +60,7 @@ def main(): revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype] if args.torch_dtype != "int8" else "float16", load_in_8bit=args.torch_dtype == "int8", - device_map={"word_embeddings": "cuda", "word_embeddings_layernorm": "cuda", "h": "cuda", "ln_f": "cuda"}, + device_map="auto" if args.torch_dtype == "int8" else None, ) if args.torch_dtype == "int8": # trigger weight quantization From a610f4d7445476003710d8eabeb7dbde4f5a4976 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sat, 25 Feb 2023 17:00:10 +0100 Subject: [PATCH 4/4] Remove load_in_8bit from convert_block --- src/petals/server/throughput.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index ac43759d8..2b9c0fe5b 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -13,7 +13,7 @@ from petals.bloom.block import WrappedBloomBlock from petals.server.block_utils import resolve_block_dtype -from petals.utils.convert_block import convert_block +from petals.utils.convert_block import convert_block, replace_8bit_linear from petals.utils.disk_cache import DEFAULT_CACHE_DIR logger = get_logger(__name__) @@ -149,7 +149,9 @@ def measure_compute_rps( tensor_parallel_devices = (device,) with torch.inference_mode(): block = WrappedBloomBlock(config).to(dtype) - block = convert_block(block, config, tensor_parallel_devices, device, load_in_8bit=load_in_8bit, freeze=True) + if load_in_8bit: + block = replace_8bit_linear(block) + block = convert_block(block, config, tensor_parallel_devices, device, freeze=True) cache = None elapsed = 0
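
For readers following along, here is a rough usage sketch (not part of the patches) of how a server would load a single block after this series lands: quantization now happens inside load_pretrained_block() via replace_8bit_linear(), and convert_block() only handles tensor parallelism and freezing. The repo name, block index, and device below are illustrative values, not taken from the patches.

import torch
from transformers import BloomConfig

from petals.bloom.from_pretrained import load_pretrained_block
from petals.utils.convert_block import convert_block

device = torch.device("cuda")

# Config of the converted model (repo name is just an example, the default
# --output_repo from convert_model.py).
config = BloomConfig.from_pretrained("bigscience/test-bloomd")

# With load_in_8bit=True, the block's linear layers are swapped for 8-bit ones
# before the converted state dict is loaded, so pre-quantized weights can be
# restored directly instead of being re-quantized at server startup.
block = load_pretrained_block(
    "bigscience/test-bloomd",
    block_index=0,
    config=config,
    torch_dtype=torch.float16,
    load_in_8bit=True,
    device=device,
)

# convert_block() no longer takes load_in_8bit; it only applies tensor
# parallelism (a single device still gets wrapped, for uniformity) and freezing.
block = convert_block(block, config, (device,), device, freeze=True)

The point of moving quantization out of convert_block() is that convert_model.py can now save already-quantized int8 block weights, and servers can load them as-is rather than quantizing float16 weights on every startup.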