diff --git a/examples/zch/Readme.md b/examples/zch/Readme.md
new file mode 100644
index 000000000..371034bb0
--- /dev/null
+++ b/examples/zch/Readme.md
@@ -0,0 +1,74 @@
+# Managed Collision Hash Example
+
+This example demonstrates the usage of the managed collision hash feature in TorchRec, which is designed to efficiently handle hash collisions in embedding tables. We include two implementations of the feature: sorted managed collision hash (MCH) and multi-probe zero collision hash (MPZCH).
+
+## Folder Structure
+
+```
+zch/
+├── Readme.md                           # This documentation file
+├── __init__.py                         # Python package marker
+├── docs/
+│   └── mpzch_module_dataflow.png       # Diagram of the MPZCH module data flow
+├── main.py                             # Main script to run the benchmark
+├── sparse_arch.py                      # Implementation of the sparse architecture with managed collision
+└── zero_collision_hash_tutorial.ipynb  # Jupyter notebook on the motivation for zero collision hash and the use of zero collision hash modules in TorchRec
+```
+
+### Introduction to MPZCH
+
+Multi-probe zero collision hash (MPZCH) is a technique to reduce the collision rate of embedding table lookups. For the concept of hash collisions and why we need to manage them, please refer to the [zero collision hash tutorial](zero_collision_hash_tutorial.ipynb).
+
+An MPZCH module contains two essential tables: the identity table and the metadata table.
+The identity table records the mapping from input hash values to remapped IDs: the value stored in each identity table slot is an input hash value, and that hash value's remapped ID is the index of the slot.
+The metadata table has the same length as the identity table. The time at which a hash value is inserted into an identity table slot is recorded in the same-indexed slot of the metadata table.
+
+Specifically, MPZCH includes the following two steps:
+1. **First Probe**: Check whether there are available or evictable slots in the identity table.
+2. **Second Probe**: Check whether the slot indexed by the input hash value is occupied. If not, directly insert the input hash value into that slot. Otherwise, perform a linear probe to find the next available slot. If all the slots are occupied, find the next evictable slot, i.e., one whose value has stayed in the table longer than a threshold, and replace the expired hash value with the input one.
+
+A simplified sketch of this insert-or-evict flow is given below.
+
+The use of the MPZCH module `HashZchManagedCollisionModule` is introduced with detailed comments in the [sparse_arch.py](sparse_arch.py) file.
+
+The module can be configured to use different eviction policies and parameters.
+
+The detailed function calls are shown in the diagram below.
+![MPZCH Module Data Flow](docs/mpzch_module_dataflow.png)
+
+#### Relationship among Important Parameters
+
+The `HashZchManagedCollisionModule` module has two important parameters for initialization:
+- `num_embeddings`: the number of embeddings in the embedding table
+- `num_buckets`: the number of buckets in the hash table
+
+`num_buckets` serves as the minimal sharding unit for the embedding table. Because MPZCH performs linear probing, when resharding the embedding table we want to avoid placing an input feature ID's hash slot and its remapped index on different ranks. So we make sure they stay in the same bucket, and move the whole bucket during resharding.
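+#### Simplified Probing Sketch
+
+The snippet below is a minimal, illustrative sketch of the two-probe insert-or-evict flow described above. It is not the actual TorchRec implementation: the names (`insert`, `identity_table`, `metadata_table`) are hypothetical, the real module operates on tensors with vectorized probing, and it measures TTL in hours rather than seconds.
+
+```python
+import time
+
+EMPTY = -1  # marker for an unoccupied slot
+
+def insert(identity_table, metadata_table, hash_value, ttl_seconds, max_probe):
+    """Toy insert-or-evict for one hash value; returns its remapped ID."""
+    size = len(identity_table)
+    start = hash_value % size
+    # Probe for the slot that already holds this value, or for an empty slot.
+    for step in range(max_probe):
+        slot = (start + step) % size
+        if identity_table[slot] == hash_value:
+            return slot  # already present: reuse its remapped ID
+        if identity_table[slot] == EMPTY:
+            identity_table[slot] = hash_value
+            metadata_table[slot] = time.time()  # record the insertion time
+            return slot
+    # All probed slots are occupied: evict the first expired entry.
+    now = time.time()
+    for step in range(max_probe):
+        slot = (start + step) % size
+        if now - metadata_table[slot] > ttl_seconds:
+            identity_table[slot] = hash_value
+            metadata_table[slot] = now
+            return slot
+    return None  # no available or evictable slot within max_probe
+
+identity_table = [EMPTY] * 1000
+metadata_table = [0.0] * 1000
+print(insert(identity_table, metadata_table, 123456789, ttl_seconds=3600, max_probe=128))
+```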
+
+## Usage
+We also prepare a profiling example of a Sparse Arch implemented with different ZCH techniques.
+
+To run the profiling example with sorted ZCH:
+
+```bash
+python main.py
+```
+
+To run the profiling example with MPZCH:
+
+```bash
+python main.py --use_mpzch
+```
+
+You can also specify the `batch_size`, `num_iters`, and `num_embeddings_per_table`:
+```bash
+python main.py --use_mpzch --batch_size 8 --num_iters 100 --num_embeddings_per_table 1000
+```
+
+The example allows you to compare the QPS of embedding operations with sorted ZCH and MPZCH. On our server with an A100 GPU, the initial QPS benchmark results with `batch_size=8`, `num_iters=100`, and `num_embeddings_per_table=1000` are presented in the table below:
+
+| ZCH module | QPS |
+| --- | --- |
+| sorted ZCH | 1371.69 |
+| MPZCH | 2750.44 |
+
+With `batch_size=1024`, `num_iters=1000`, and `num_embeddings_per_table=1000`, the results are:
+
+| ZCH module | QPS |
+| --- | --- |
+| sorted ZCH | 263827.55 |
+| MPZCH | 551306.97 |
diff --git a/examples/zch/__init__.py b/examples/zch/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/zch/docs/mpzch_module_dataflow.png b/examples/zch/docs/mpzch_module_dataflow.png
new file mode 100644
index 000000000..8ff4ba9e4
Binary files /dev/null and b/examples/zch/docs/mpzch_module_dataflow.png differ
diff --git a/examples/zch/main.py b/examples/zch/main.py
new file mode 100644
index 000000000..18b114b3b
--- /dev/null
+++ b/examples/zch/main.py
@@ -0,0 +1,131 @@
+# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+# pyre-strict
+
+import argparse
+import time
+
+import torch
+
+from torchrec import EmbeddingConfig, KeyedJaggedTensor
+from tqdm import tqdm
+
+from .sparse_arch import SparseArch
+
+
+def main(args: argparse.Namespace) -> None:
+    """
+    This function benchmarks the performance of a SparseArch module with or without the MPZCH feature.
+    Arguments:
+        args: the parsed command-line arguments, including use_mpzch (whether to
+            enable MPZCH), num_iters, batch_size, and num_embeddings_per_table
+    Prints:
+        the QPS and the duration of the forward passes of the SparseArch module
+        with or without MPZCH enabled
+    """
+    print(f"Using MPZCH: {args.use_mpzch}")
+
+    # check available devices
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+    print(f"Using device: {device}")
+
+    # create an embedding configuration
+    embedding_config = [
+        EmbeddingConfig(
+            name="table_0",
+            feature_names=["feature_0"],
+            embedding_dim=8,
+            num_embeddings=args.num_embeddings_per_table,
+        ),
+        EmbeddingConfig(
+            name="table_1",
+            feature_names=["feature_1"],
+            embedding_dim=8,
+            num_embeddings=args.num_embeddings_per_table,
+        ),
+    ]
+
+    # generate the list of input KJTs
+    input_kjt_list = []
+    for _ in range(args.num_iters):
+        input_kjt_single = KeyedJaggedTensor.from_lengths_sync(
+            keys=["feature_0", "feature_1"],
+            # draw 3 * batch_size random IDs in [0, num_embeddings_per_table);
+            # the lengths below sum to the same count
+            values=torch.randint(
+                0, args.num_embeddings_per_table, (3 * args.batch_size,)
+            ),
+            lengths=torch.LongTensor([1] * args.batch_size + [2] * args.batch_size),
+            weights=None,
+        )
+        input_kjt_single = input_kjt_single.to(device)
+        input_kjt_list.append(input_kjt_single)
+
+    num_requests = args.num_iters * args.batch_size
+
+    # make the model
+    model = SparseArch(
+        tables=embedding_config,
+        device=device,
+        return_remapped=True,
+        use_mpzch=args.use_mpzch,
+        buckets=1,
+    )
+
+    # do the forward passes
+    if device.type == "cuda":
+        torch.cuda.synchronize()
+        starter = torch.cuda.Event(enable_timing=True)
+        ender = torch.cuda.Event(enable_timing=True)
+
+        # record the start time
+        starter.record()
+        for it_idx in tqdm(range(args.num_iters)):
+            input_kjt = input_kjt_list[it_idx].to(device)
+            ec_out, remapped_ids_out = model(input_kjt)
+        # record the end time
+        ender.record()
+        # wait for the end event to be recorded
+        torch.cuda.synchronize()
+        duration = starter.elapsed_time(ender) / 1000.0  # convert ms to seconds
+    else:
+        # in CPU mode, MPZCH can only run in inference mode, so we profile the model in eval mode
+        model.eval()
+        if args.use_mpzch:
+            # when using MPZCH modules, we need to manually set the modules to inference mode
+            # pyre-ignore
+            model._mc_ec._managed_collision_collection._managed_collision_modules[
+                "table_0"
+            ].reset_inference_mode()
+            # pyre-ignore
+            model._mc_ec._managed_collision_collection._managed_collision_modules[
+                "table_1"
+            ].reset_inference_mode()
+
+        start_time = time.time()
+        for it_idx in tqdm(range(args.num_iters)):
+            input_kjt = input_kjt_list[it_idx].to(device)
+            ec_out, remapped_ids_out = model(input_kjt)
+        end_time = time.time()
+        duration = end_time - start_time
+    # get qps
+    qps = num_requests / duration
+    print(f"qps: {qps}")
+    # print the duration
+    print(f"duration: {duration} seconds")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--use_mpzch", action="store_true", default=False)
+    parser.add_argument("--num_iters", type=int, default=100)
+    parser.add_argument("--batch_size", type=int, default=8)
+    parser.add_argument("--num_embeddings_per_table", type=int, default=1000)
+    args: argparse.Namespace = parser.parse_args()
+    main(args)
diff --git a/examples/zch/sparse_arch.py b/examples/zch/sparse_arch.py
new file mode 100644
index 000000000..b8be4abaa
--- /dev/null
+++ b/examples/zch/sparse_arch.py
@@ -0,0 +1,137 @@
+# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+# pyre-strict
+
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from torchrec import (
+    EmbeddingCollection,
+    EmbeddingConfig,
+    JaggedTensor,
+    KeyedJaggedTensor,
+    KeyedTensor,
+)
+
+# For MPZCH
+from torchrec.modules.hash_mc_evictions import (
+    HashZchEvictionConfig,
+    HashZchEvictionPolicyName,
+)
+
+# For MPZCH
+from torchrec.modules.hash_mc_modules import HashZchManagedCollisionModule
+from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection
+
+# For original MC
+from torchrec.modules.mc_modules import (
+    DistanceLFU_EvictionPolicy,
+    ManagedCollisionCollection,
+    MCHManagedCollisionModule,
+)
+
+"""
+Class SparseArch
+An example SparseArch with two embedding tables, each serving one feature.
+It looks up the embeddings for incoming KeyedJaggedTensors with two features
+and returns the corresponding embeddings.
+
+Parameters:
+    tables (List[EmbeddingConfig]): list of EmbeddingConfig that defines the embedding tables
+    device (torch.device): device on which the embedding tables should be placed
+    buckets (int): number of buckets for each table
+    input_hash_size (int): input hash size for each table
+    return_remapped (bool): whether to return the remapped features; if so, the return will be
+        a tuple of (Embedding(KeyedTensor), Remapped_ID(KeyedJaggedTensor)); otherwise, the return
+        will be a tuple of (Embedding(KeyedTensor), None)
+    is_inference (bool): whether to use inference mode; in inference mode, the module will not
+        update the embedding tables
+    use_mpzch (bool): whether to use MPZCH; if true, the module uses the MPZCH managed collision
+        module, otherwise it uses the original MCH managed collision module
+"""
+
+
+class SparseArch(nn.Module):
+    def __init__(
+        self,
+        tables: List[EmbeddingConfig],
+        device: torch.device,
+        buckets: int = 4,
+        input_hash_size: int = 4000,
+        return_remapped: bool = False,
+        is_inference: bool = False,
+        use_mpzch: bool = False,
+    ) -> None:
+        super().__init__()
+        self._return_remapped = return_remapped
+
+        mc_modules = {}
+
+        if use_mpzch:
+            # if using MPZCH, we create a HashZchManagedCollisionModule for each table
+            mc_modules["table_0"] = HashZchManagedCollisionModule(
+                is_inference=is_inference,
+                zch_size=(
+                    tables[0].num_embeddings
+                ),  # the ZCH size, i.e., the size of the local embedding table; it should equal the size of the embedding table
+                input_hash_size=input_hash_size,  # despite the name, this is the size of the input ID space
+                device=device,  # the device on which the embedding table should be placed
+                total_num_buckets=buckets,  # the number of buckets; a detailed explanation of buckets can be found in the readme file
+                eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,  # the eviction policy name; this example uses the single-TTL eviction policy, which assumes an ID is evictable if it has been in the table longer than the TTL (time to live)
+                eviction_config=HashZchEvictionConfig(  # here we specify the TTL for each feature, i.e., how long an ID can stay in the table before it becomes evictable
+                    features=[
+                        "feature_0"
+                    ],  # this table only serves one feature, "feature_0", so we only specify the TTL for this feature
+                    single_ttl=1,  # TTL is measured in hours; setting it to 1 means an ID is evictable after more than one hour in the table
+                ),
+            )
+            mc_modules["table_1"] = HashZchManagedCollisionModule(
+                is_inference=is_inference,
+                zch_size=(tables[1].num_embeddings),
+                device=device,
+                input_hash_size=input_hash_size,
+                total_num_buckets=buckets,
+                eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
+                eviction_config=HashZchEvictionConfig(
+                    features=["feature_1"],
+                    single_ttl=1,
+                ),
+            )
+        else:
+            # if not using MPZCH, we create an MCHManagedCollisionModule for each table
+            mc_modules["table_0"] = MCHManagedCollisionModule(
+                zch_size=(tables[0].num_embeddings),
+                input_hash_size=input_hash_size,
+                device=device,
+                eviction_interval=2,
+                eviction_policy=DistanceLFU_EvictionPolicy(),
+            )
+            mc_modules["table_1"] = MCHManagedCollisionModule(
+                zch_size=(tables[1].num_embeddings),
+                device=device,
+                input_hash_size=input_hash_size,
+                eviction_interval=1,
+                eviction_policy=DistanceLFU_EvictionPolicy(),
+            )
+
+        self._mc_ec: ManagedCollisionEmbeddingCollection = (
+            ManagedCollisionEmbeddingCollection(
+                EmbeddingCollection(
+                    tables=tables,
+                    device=device,
+                ),
+                ManagedCollisionCollection(
+                    managed_collision_modules=mc_modules,
+                    embedding_configs=tables,
+                ),
+                return_remapped_features=self._return_remapped,
+            )
+        )
+
+    def forward(
+        self, kjt: KeyedJaggedTensor
+    ) -> Tuple[
+        Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor]
+    ]:
+        return self._mc_ec(kjt)
diff --git a/examples/zch/zero_collision_hash_tutorial.ipynb b/examples/zch/zero_collision_hash_tutorial.ipynb
new file mode 100644
index 000000000..09901f2a9
--- /dev/null
+++ b/examples/zch/zero_collision_hash_tutorial.ipynb
@@ -0,0 +1,452 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Zero-collision Hash Tutorial\n",
+    "This example notebook goes through the following topics:\n",
+    "- Why do we need zero-collision hash?\n",
+    "- How to use the zero-collision modules in TorchRec?\n",
+    "\n",
+    "## Prerequisite\n",
+    "Before diving into the details, let's import all the necessary packages first. This requires you to [have the latest `torchrec` library installed](https://docs.pytorch.org/torchrec/setup-torchrec.html#installation)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from torch import nn\n",
+    "from torchrec import (\n",
+    "    EmbeddingCollection,\n",
+    "    EmbeddingConfig,\n",
+    "    JaggedTensor,\n",
+    "    KeyedJaggedTensor,\n",
+    "    KeyedTensor,\n",
+    ")\n",
+    "\n",
+    "from torchrec.modules.hash_mc_modules import HashZchManagedCollisionModule\n",
+    "from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection\n",
+    "from torchrec.modules.mc_modules import ManagedCollisionCollection\n",
+    "\n",
+    "from torchrec.modules.hash_mc_evictions import (\n",
+    "    HashZchEvictionConfig,\n",
+    "    HashZchEvictionPolicyName,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Hash and Zero Collision Hash\n",
+    "In this section, we discuss two questions:\n",
+    "- Why do we need to hash incoming features?\n",
+    "- Why do we need zero-collision hashing?\n",
+    "\n",
+    "Let's first look at why we need to hash sparse feature inputs in recommendation models.\n",
+    "We first create an embedding table with 1000 embeddings."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# define the number of embeddings\n",
+    "num_embeddings = 1000\n",
+    "table_config = EmbeddingConfig(\n",
+    "    name=\"t1\",\n",
+    "    embedding_dim=16,\n",
+    "    num_embeddings=num_embeddings,\n",
+    "    feature_names=[\"f1\"],\n",
+    ")\n",
+    "ec = EmbeddingCollection(tables=[table_config])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Usually, we regard each input sparse feature ID as the index of an embedding in the embedding table, and fetch the embedding at the corresponding slot. However, while the embedding table size is fixed when the model is instantiated, the number of sparse features, such as video tags, can keep growing. After a while, a sparse feature ID can become larger than the size of our embedding table."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "feature_id = num_embeddings + 1\n",
+    "input_kjt = KeyedJaggedTensor.from_lengths_sync(\n",
+    "    keys=[\"f1\"],\n",
+    "    values=torch.tensor([feature_id]),\n",
+    "    lengths=torch.tensor([1]),\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "At that point, the query will lead to an `index out of range` error."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "try:\n",
+    "    feature_embedding = ec(input_kjt)\n",
+    "except IndexError as e:\n",
+    "    print(f\"Query the embedding table of size {num_embeddings} with sparse feature ID {input_kjt['f1'].values()}\")\n",
+    "    print(f\"This query throws an IndexError: {e}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To avoid this error, we hash the sparse feature ID to a value within the range of the embedding table size, and use the hashed value as the feature ID to query the embedding table.\n",
+    "\n",
+    "For the purpose of demonstration, we use Python's built-in hash function to hash an integer (which does not change the value) and remap it to the range `[0, num_embeddings)` by taking the modulo of `num_embeddings`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def remap(input_jt_value: int, num_embeddings: int):\n",
+    "    input_hash = hash(input_jt_value)\n",
+    "    return input_hash % num_embeddings"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now we can query the embedding table with the remapped ID without an error."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "remapped_id = remap(feature_id, num_embeddings)\n",
+    "remapped_kjt = KeyedJaggedTensor.from_lengths_sync(\n",
+    "    keys=[\"f1\"],\n",
+    "    values=torch.tensor([remapped_id]),\n",
+    "    lengths=torch.tensor([1]),\n",
+    ")\n",
+    "feature_embedding = ec(remapped_kjt)\n",
+    "print(f\"Query the embedding table of size {num_embeddings} with remapped sparse feature ID {remapped_id} from original ID {feature_id}\")\n",
+    "print(f\"This query does not throw an IndexError, and returns the embedding of the remapped ID: {feature_embedding}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Having answered the first question, __Why do we need to hash incoming features?__, we can now answer the second: __Why do we need zero-collision hashing?__\n",
+    "\n",
+    "Because we are casting a large range of values into a small range, some values will be remapped to the same index. For example, our `remap` function gives the same remapped ID for features `num_embeddings + 1` and `1`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "feature_id_1 = 1\n",
+    "feature_id_2 = num_embeddings + 1\n",
+    "remapped_feature_id_1 = remap(feature_id_1, num_embeddings)\n",
+    "remapped_feature_id_2 = remap(feature_id_2, num_embeddings)\n",
+    "print(f\"feature ID {feature_id_1} is remapped to ID {remapped_feature_id_1}\")\n",
+    "print(f\"feature ID {feature_id_2} is remapped to ID {remapped_feature_id_2}\")\n",
+    "print(f\"Check if remapped feature ID {remapped_feature_id_1} and {remapped_feature_id_2} are the same: {remapped_feature_id_1 == remapped_feature_id_2}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In this case, two totally different features share the same embedding. The situation where two feature IDs share the same remapped ID is called a **hash collision**."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "input_kjt = KeyedJaggedTensor.from_lengths_sync(\n",
+    "    keys=[\"f1\"],\n",
+    "    values=torch.tensor([remapped_feature_id_1, remapped_feature_id_2]),\n",
+    "    lengths=torch.tensor([1, 1]),\n",
+    ")\n",
+    "feature_embeddings = ec(input_kjt)\n",
+    "feature_id_1_embedding = feature_embeddings[\"f1\"].values()[0]\n",
+    "feature_id_2_embedding = feature_embeddings[\"f1\"].values()[1]\n",
+    "print(f\"Embedding of feature ID {remapped_feature_id_1} is {feature_id_1_embedding}\")\n",
+    "print(f\"Embedding of feature ID {remapped_feature_id_2} is {feature_id_2_embedding}\")\n",
+    "print(f\"Check if the embeddings of feature ID {remapped_feature_id_1} and {remapped_feature_id_2} are the same: {torch.equal(feature_id_1_embedding, feature_id_2_embedding)}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "\n",
+    "Making two different (and potentially totally irrelevant) features share the same embedding will cause inaccurate recommendations.\n",
+    "Luckily, for many sparse features, though their range can be larger than the embedding table size, their IDs are sparsely distributed over that range.\n",
+    "In other cases, the embedding table may only receive frequent queries for a subset of the features.\n",
+    "So we can design __managed collision hash__ modules to prevent hash collisions from happening.\n",
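+    "\n",
+    "To make the idea concrete before introducing TorchRec's modules, here is a toy sketch of a zero-collision remapper (illustrative only; this is **not** TorchRec's actual implementation). It keeps an *identity table* that records which original ID owns each slot, and linearly probes to the next free slot when a collision happens:\n",
+    "\n",
+    "```python\n",
+    "def zero_collision_remap(feature_id, identity_table):\n",
+    "    num_slots = len(identity_table)\n",
+    "    slot = hash(feature_id) % num_slots\n",
+    "    for step in range(num_slots):  # linear probing\n",
+    "        probe = (slot + step) % num_slots\n",
+    "        if identity_table[probe] is None:  # free slot: claim it\n",
+    "            identity_table[probe] = feature_id\n",
+    "            return probe\n",
+    "        if identity_table[probe] == feature_id:  # already mapped: reuse it\n",
+    "            return probe\n",
+    "    # a real module would evict a stale entry here instead of failing\n",
+    "    raise RuntimeError(\"table is full\")\n",
+    "\n",
+    "identity_table = [None] * num_embeddings\n",
+    "print(zero_collision_remap(feature_id_1, identity_table))  # claims a slot\n",
+    "print(zero_collision_remap(feature_id_2, identity_table))  # probes to a different slot\n",
+    "```"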
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## TorchRec Zero Collision Hash Modules\n",
+    "\n",
+    "TorchRec implements managed collision hash strategies such as *sorted zero collision hash* and *multi-probe zero collision hash (MPZCH)*.\n",
+    "\n",
+    "They hash and remap feature IDs to embedding table indices with (near-)zero collisions.\n",
+    "\n",
+    "In the following, we use the MPZCH module as an example of how to use the zero-collision modules in TorchRec. The MPZCH module is named `HashZchManagedCollisionModule`.\n",
+    "\n",
+    "Let's assume we have two tables, `table_0` and `table_1`, holding embeddings for `feature_0` and `feature_1`, respectively."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# define the table sizes\n",
+    "num_embeddings_table_0 = 1000\n",
+    "num_embeddings_table_1 = 2000\n",
+    "\n",
+    "# create table configs\n",
+    "table_0_config = EmbeddingConfig(\n",
+    "    name=\"table_0\",\n",
+    "    embedding_dim=16,\n",
+    "    num_embeddings=num_embeddings_table_0,\n",
+    "    feature_names=[\"feature_0\"],\n",
+    ")\n",
+    "\n",
+    "table_1_config = EmbeddingConfig(\n",
+    "    name=\"table_1\",\n",
+    "    embedding_dim=16,\n",
+    "    num_embeddings=num_embeddings_table_1,\n",
+    "    feature_names=[\"feature_1\"],\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Before turning the table configurations into an embedding table collection, we instantiate our managed collision modules.\n",
+    "\n",
+    "The managed collision modules for a collection of embedding tables are organized as a dictionary of `{table_name: mc_module_for_the_table}`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mc_modules = {}\n",
+    "\n",
+    "# instantiate the managed collision modules; each argument is explained in the comments below\n",
+    "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
+    "buckets = 4\n",
+    "input_hash_size = 10000\n",
+    "mc_modules[\"table_0\"] = HashZchManagedCollisionModule(\n",
+    "    is_inference=False,  # whether the module is used in inference or not\n",
+    "    zch_size=(\n",
+    "        table_0_config.num_embeddings\n",
+    "    ),  # the size of the embedding table\n",
+    "    input_hash_size=input_hash_size,  # despite the name, this is the size of the input ID space\n",
+    "    device=device,  # the device on which the embedding table should be placed\n",
+    "    total_num_buckets=buckets,  # the number of buckets; a detailed explanation of buckets can be found in the readme file\n",
+    "    eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,  # the eviction policy name; this example uses the single-TTL eviction policy, which assumes an ID is evictable if it has been in the table longer than the TTL (time to live)\n",
+    "    eviction_config=HashZchEvictionConfig(  # here we specify the TTL for each feature, i.e., how long an ID can stay in the table before it becomes evictable\n",
+    "        features=[\n",
+    "            \"feature_0\"\n",
+    "        ],  # this table only serves one feature, \"feature_0\", so we only specify the TTL for this feature\n",
+    "        single_ttl=1,  # TTL is measured in hours; setting it to 1 means an ID is evictable after more than one hour in the table\n",
+    "    ),\n",
+    ")\n",
+    "mc_modules[\"table_1\"] = HashZchManagedCollisionModule(\n",
+    "    is_inference=False,\n",
+    "    zch_size=(table_1_config.num_embeddings),\n",
+    "    device=device,\n",
+    "    input_hash_size=input_hash_size,\n",
+    "    total_num_buckets=buckets,\n",
+    "    eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,\n",
+    "    eviction_config=HashZchEvictionConfig(\n",
+    "        features=[\"feature_1\"],\n",
+    "        single_ttl=1,\n",
+    "    ),\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "For embedding tables with managed collision modules, TorchRec uses a wrapper module `ManagedCollisionEmbeddingCollection` that contains both the embedding table collection and the managed collision modules. Users only need to pass their table configurations and managed collision modules to the wrapper."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mc_ec = ManagedCollisionEmbeddingCollection(\n",
+    "    EmbeddingCollection(\n",
+    "        tables=[\n",
+    "            table_0_config,\n",
+    "            table_1_config,\n",
+    "        ],\n",
+    "        device=device,\n",
+    "    ),\n",
+    "    ManagedCollisionCollection(\n",
+    "        managed_collision_modules=mc_modules,\n",
+    "        embedding_configs=[\n",
+    "            table_0_config,\n",
+    "            table_1_config,\n",
+    "        ],\n",
+    "    ),\n",
+    "    return_remapped_features=True,  # whether to return the remapped feature IDs\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The `ManagedCollisionEmbeddingCollection` module will perform remapping and table lookups for the input. Users only need to pass keyed jagged tensor queries into the module."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "input_kjt = KeyedJaggedTensor.from_lengths_sync(\n",
+    "    keys=[\"feature_0\", \"feature_1\"],\n",
+    "    values=torch.tensor([1000, 10001, 2000, 20001]),\n",
+    "    lengths=torch.tensor([1, 1, 1, 1]),\n",
+    ")\n",
+    "for feature_name, feature_jt in input_kjt.to_dict().items():\n",
+    "    print(f\"feature name: {feature_name}, feature jt: {feature_jt}\")\n",
+    "    print(f\"feature jt values: {feature_jt.values()}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "output_embeddings, remapped_ids = mc_ec(input_kjt)\n",
+    "# show output embeddings\n",
+    "for feature_name, feature_embedding in output_embeddings.items():\n",
+    "    print(f\"feature name: {feature_name}, feature embedding: {feature_embedding}\")\n",
+    "# show remapped ids\n",
+    "for feature_name, feature_jt in remapped_ids.to_dict().items():\n",
+    "    print(f\"feature name: {feature_name}, feature jt values: {feature_jt.values()}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now we have a basic example of how to use the managed collision modules in TorchRec.\n",
+    "\n",
+    "We also provide a profiling example to compare the efficiency of the sorted ZCH and MPZCH modules. Check the [Readme](Readme.md) file for more details."
+   ]
+  }
+ ],
+ "metadata": {
+  "fileHeader": "",
+  "fileUid": "ee3845a2-a85b-4a8e-8c42-9ce4690c9956",
+  "isAdHoc": false,
+  "kernelspec": {
+   "display_name": "torchrec",
+   "language": "python",
+   "name": "bento_kernel_torchrec"
+  },
+  "language_info": {
+   "name": "python"
+  },
+  "orig_nbformat": 4
+ }
+}
diff --git a/torchrec/modules/hash_mc_modules.py b/torchrec/modules/hash_mc_modules.py
index fe5a0ce19..0001d19f9 100644
--- a/torchrec/modules/hash_mc_modules.py
+++ b/torchrec/modules/hash_mc_modules.py
@@ -191,7 +191,7 @@ class HashZchManagedCollisionModule(ManagedCollisionModule):
 
     def __init__(
         self,
-        zch_size: int,
+        zch_size: int,  # number of embeddings in the table
         device: torch.device,
         total_num_buckets: int,
         max_probe: int = 128,
@@ -213,9 +213,12 @@ def __init__(
             zch_size % total_num_buckets == 0
         ), f"please pass output segments if not uniform buckets {zch_size=}, {total_num_buckets=}"
         output_segments = [
-            (zch_size // total_num_buckets) * bucket
-            for bucket in range(total_num_buckets + 1)
-        ]
+            (zch_size // total_num_buckets)
+            * bucket  # output_segments lists the first index of each bucket, plus the end index of the last bucket
+            for bucket in range(
+                total_num_buckets + 1
+            )  # e.g., with a table of size 100 and total_num_buckets=4, each bucket holds 100 // 4 = 25 embeddings, and the first indices are 0, 25, 50, 75
+        ]  # together with the end index of the last bucket, output_segments is [0, 25, 50, 75, 100]
 
         super().__init__(
             device=device,