Create an example for MPZCH #3063

Status: Open. Wants to merge 1 commit into base: main.
74 changes: 74 additions & 0 deletions examples/zch/Readme.md
@@ -0,0 +1,74 @@
# Managed Collision Hash Example

This example demonstrates the usage of the managed collision hash feature in TorchRec, which is designed to efficiently handle hash collisions in embedding tables. We include two implementations of the feature: sorted Managed Collision Hash (MCH) and Multi-Probe Zero Collision Hash (MPZCH).

## Folder Structure

```
managed_collision_hash/
├── Readme.md # This documentation file
├── __init__.py # Python package marker
├── main.py # Main script to run the benchmark
└── sparse_arch.py # Implementation of the sparse architecture with managed collision
└── zero_collision_hash_tutorial.ipynb # Jupyter notebook for the motivation of zero collision hash and the use of zero collision hash modules in TorchRec
```

### Introduction of MPZCH

Multi-Probe Zero Collision Hash (MPZCH) is a technique that reduces the collision rate of embedding table lookups. For the concept of hash collisions and why they need to be managed, please refer to the [zero collision hash tutorial](zero_collision_hash_tutorial.ipynb).
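As a quick, self-contained illustration of why collisions need managing (a minimal sketch independent of TorchRec, with made-up table and ID-space sizes): hashing a large ID space into a smaller table by plain modulo forces distinct IDs to share slots.

```python
import random

# Hash a large ID space into a small table with plain modulo.
# Distinct IDs that land on the same slot collide and would be forced
# to share one embedding row, degrading model quality.
table_size = 1000
ids = random.sample(range(10_000_000), 2000)  # 2000 distinct feature IDs
slots = [i % table_size for i in ids]
collisions = len(slots) - len(set(slots))
print(f"{collisions} of {len(ids)} IDs collide")
```

Since 2000 distinct IDs are squeezed into at most 1000 distinct slots, at least 1000 of them must collide; zero-collision hashing avoids this by remapping IDs to exclusive slots.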

An MPZCH module contains two essential tables: the identity table and the metadata table.
The identity table records the mapping from input hash values to remapped IDs: the value stored in each identity table slot is an input hash value, and that hash value's remapped ID is the index of the slot.
The metadata table shares the same length as the identity table. The time when a hash value is inserted into an identity table slot is recorded in the same-indexed slot of the metadata table.

Specifically, MPZCH performs the following two steps:
1. **First Probe**: Check whether there are available or evictable slots in the identity table.
2. **Second Probe**: Check whether the slot indexed by the input hash value is occupied. If not, directly insert the input hash value into that slot. Otherwise, perform a linear probe to find the next available slot. If all slots are occupied, find the next evictable slot whose value has stayed in the table longer than a threshold, and replace the expired hash value with the input one.
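The probing steps above can be sketched in pure Python (a simplified illustration of the idea, not the actual `HashZchManagedCollisionModule` implementation; the table size, TTL handling, and probe order are assumptions for the sketch):

```python
import time
from typing import List, Optional

EMPTY = -1

class TinyZCH:
    """Toy zero-collision hash with an identity table and a metadata table."""

    def __init__(self, size: int, ttl_seconds: float) -> None:
        self.identity: List[int] = [EMPTY] * size   # slot -> stored hash value
        self.metadata: List[float] = [0.0] * size   # slot -> insertion time
        self.ttl = ttl_seconds

    def remap(self, hash_value: int) -> Optional[int]:
        size = len(self.identity)
        start = hash_value % size
        now = time.monotonic()
        # linear probe: reuse an existing slot or insert into the first empty one
        for step in range(size):
            slot = (start + step) % size
            if self.identity[slot] == hash_value:
                return slot  # already mapped: the remapped ID is the slot index
            if self.identity[slot] == EMPTY:
                self.identity[slot] = hash_value
                self.metadata[slot] = now
                return slot
        # table full: evict the first slot whose value has outlived its TTL
        for step in range(size):
            slot = (start + step) % size
            if now - self.metadata[slot] > self.ttl:
                self.identity[slot] = hash_value
                self.metadata[slot] = now
                return slot
        return None  # no available or evictable slot

zch = TinyZCH(size=4, ttl_seconds=3600.0)
remapped = [zch.remap(h) for h in (10, 14, 10)]
print(remapped)  # -> [2, 3, 2]: 10 and 14 both hash to slot 2, but 14 is probed to slot 3
```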

The use of the MPZCH module `HashZchManagedCollisionModule` is introduced with detailed comments in the [sparse_arch.py](sparse_arch.py) file.

The module can be configured to use different eviction policies and parameters.

The detailed function calls are shown in the diagram below:
![MPZCH Module Data Flow](docs/mpzch_module_dataflow.png)

#### Relationship among Important Parameters

The `HashZchManagedCollisionModule` module takes the following important parameters at initialization:
- `num_embeddings`: the number of embeddings in the embedding table
- `num_buckets`: the number of buckets in the hash table

The `num_buckets` is used as the minimal sharding unit for the embedding table. Because MPZCH performs linear probing, when resharding the embedding table we want to avoid separating an input feature ID's remapped index from its hash value onto different ranks. So we make sure they stay in the same bucket, and move the whole bucket during resharding.
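A minimal sketch of why bucket-aligned sharding keeps a hash value and its probed slot together (illustrative only; the exact bucket arithmetic and the assumption that probing stays inside one bucket are simplifications, not TorchRec's implementation):

```python
# num_embeddings rows are split into num_buckets contiguous ranges; each
# bucket is the minimal sharding unit that is moved as a whole.
num_embeddings = 1000
num_buckets = 4
bucket_size = num_embeddings // num_buckets  # 250 rows per bucket

def bucket_of(row: int) -> int:
    return row // bucket_size

# A hash value maps into some bucket; a slot found by probing near it
# falls in the same bucket, so moving the bucket as a whole during
# resharding keeps the hash value and its remapped row on the same rank.
hash_value = 12345
home_row = hash_value % num_embeddings        # row 345 -> bucket 1
probed_row = home_row + 7                     # a nearby probed slot
print(bucket_of(home_row), bucket_of(probed_row))  # -> 1 1
```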

## Usage
We also provide a profiling example of a SparseArch module implemented with different ZCH techniques.
To run the profiling example with sorted ZCH:

```bash
python main.py
```

To run the profiling example with MPZCH:

```bash
python main.py --use_mpzch
```

You can also specify the `batch_size`, `num_iters`, and `num_embeddings_per_table`:
```bash
python main.py --use_mpzch --batch_size 8 --num_iters 100 --num_embeddings_per_table 1000
```

The example allows you to compare the QPS of embedding operations with sorted ZCH and MPZCH. On our server with an A100 GPU, the initial QPS benchmark results with `batch_size=8`, `num_iters=100`, and `num_embeddings_per_table=1000` are presented in the table below:

| ZCH module | QPS |
| --- | --- |
| sorted ZCH | 1371.69 |
| MPZCH | 2750.44 |

And with `batch_size=1024`, `num_iters=1000`, and `num_embeddings_per_table=1000`, the results are:

| ZCH module | QPS |
| --- | --- |
| sorted ZCH | 263827.55 |
| MPZCH | 551306.97 |
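From the two benchmark tables above, the relative speedup of MPZCH over sorted ZCH works out to roughly 2x at both batch sizes:

```python
# QPS numbers copied from the benchmark tables above
small_batch = {"sorted": 1371.69, "mpzch": 2750.44}      # batch_size=8
large_batch = {"sorted": 263827.55, "mpzch": 551306.97}  # batch_size=1024

for name, qps in (("batch_size=8", small_batch), ("batch_size=1024", large_batch)):
    print(f"{name}: MPZCH is {qps['mpzch'] / qps['sorted']:.2f}x faster")
```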
Empty file added examples/zch/__init__.py
Empty file.
Binary file added examples/zch/docs/mpzch_module_dataflow.png
131 changes: 131 additions & 0 deletions examples/zch/main.py
@@ -0,0 +1,131 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict

import argparse
import time

import torch

from torchrec import EmbeddingConfig, KeyedJaggedTensor
from torchrec.distributed.benchmark.benchmark_utils import get_inputs
from tqdm import tqdm

from .sparse_arch import SparseArch


def main(args: argparse.Namespace) -> None:
    """
    Benchmarks the performance of a SparseArch module with or without the MPZCH feature.

    Arguments:
        args: parsed command-line arguments; args.use_mpzch controls whether MPZCH is enabled

    Prints:
        duration: time for the forward passes of the SparseArch module
        qps: queries per second processed by the module
    """
    print(f"Using MPZCH: {args.use_mpzch}")

    # check available devices
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print(f"Using device: {device}")

    # create an embedding configuration
    embedding_config = [
        EmbeddingConfig(
            name="table_0",
            feature_names=["feature_0"],
            embedding_dim=8,
            num_embeddings=args.num_embeddings_per_table,
        ),
        EmbeddingConfig(
            name="table_1",
            feature_names=["feature_1"],
            embedding_dim=8,
            num_embeddings=args.num_embeddings_per_table,
        ),
    ]

    # generate the list of input KeyedJaggedTensors
    input_kjt_list = []
    for _ in range(args.num_iters):
        input_kjt_single = KeyedJaggedTensor.from_lengths_sync(
            keys=["feature_0", "feature_1"],
            # draw 3 * batch_size random IDs from [0, num_embeddings_per_table)
            values=torch.randint(
                0, args.num_embeddings_per_table, (3 * args.batch_size,)
            ),
            lengths=torch.LongTensor([1] * args.batch_size + [2] * args.batch_size),
            weights=None,
        )
        input_kjt_single = input_kjt_single.to(device)
        input_kjt_list.append(input_kjt_single)

    num_requests = args.num_iters * args.batch_size

    # make the model
    model = SparseArch(
        tables=embedding_config,
        device=device,
        return_remapped=True,
        use_mpzch=args.use_mpzch,
        buckets=1,
    )

    # do the forward passes
    if device.type == "cuda":
        torch.cuda.synchronize()
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)

        # record the start time
        starter.record()
        for it_idx in tqdm(range(args.num_iters)):
            input_kjt = input_kjt_list[it_idx]
            ec_out, remapped_ids_out = model(input_kjt)
        # record the end time
        ender.record()
        # wait for all kernels to finish before reading the elapsed time
        torch.cuda.synchronize()
        duration = starter.elapsed_time(ender) / 1000.0  # convert ms to seconds
    else:
        # in CPU mode, MPZCH can only run in inference mode, so we profile the model in eval mode
        model.eval()
        if args.use_mpzch:
            # when using MPZCH modules, we need to manually set the modules to inference mode
            # pyre-ignore
            model._mc_ec._managed_collision_collection._managed_collision_modules[
                "table_0"
            ].reset_inference_mode()
            # pyre-ignore
            model._mc_ec._managed_collision_collection._managed_collision_modules[
                "table_1"
            ].reset_inference_mode()

        start_time = time.time()
        for it_idx in tqdm(range(args.num_iters)):
            input_kjt = input_kjt_list[it_idx]
            ec_out, remapped_ids_out = model(input_kjt)
        end_time = time.time()
        duration = end_time - start_time

    # report throughput
    qps = num_requests / duration
    print(f"qps: {qps}")
    print(f"duration: {duration} seconds")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--use_mpzch", action="store_true", default=False)
    parser.add_argument("--num_iters", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--num_embeddings_per_table", type=int, default=1000)
    args: argparse.Namespace = parser.parse_args()
    main(args)
137 changes: 137 additions & 0 deletions examples/zch/sparse_arch.py
@@ -0,0 +1,137 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict

from typing import Dict, List, Optional, Tuple, Union

import torch
from torch import nn

from torchrec import (
EmbeddingCollection,
EmbeddingConfig,
JaggedTensor,
KeyedJaggedTensor,
KeyedTensor,
)

# For MPZCH
from torchrec.modules.hash_mc_evictions import (
HashZchEvictionConfig,
HashZchEvictionPolicyName,
)

# For MPZCH
from torchrec.modules.hash_mc_modules import HashZchManagedCollisionModule
from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection

# For original MC
from torchrec.modules.mc_modules import (
DistanceLFU_EvictionPolicy,
ManagedCollisionCollection,
MCHManagedCollisionModule,
)

"""
Class SparseArch
An example of SparseArch with 2 tables, each with 2 features.
It looks up the corresponding embedding for incoming KeyedJaggedTensors with 2 features
and returns the corresponding embeddings.

Parameters:
tables(List[EmbeddingConfig]): List of EmbeddingConfig that defines the embedding table
device(torch.device): device on which the embedding table should be placed
buckets(int): number of buckets for each table
input_hash_size(int): input hash size for each table
return_remapped(bool): whether to return remapped features, if so, the return will be
a tuple of (Embedding(KeyedTensor), Remapped_ID(KeyedJaggedTensor)), otherwise, the return will be
a tuple of (Embedding(KeyedTensor), None)
is_inference(bool): whether to use inference mode. In inference mode, the module will not update the embedding table
use_mpzch(bool): whether to use MPZCH or not. If true, the module will use MPZCH managed collision module,
otherwise, it will use original MC managed collision module
"""


class SparseArch(nn.Module):
def __init__(
self,
tables: List[EmbeddingConfig],
device: torch.device,
buckets: int = 4,
input_hash_size: int = 4000,
return_remapped: bool = False,
is_inference: bool = False,
use_mpzch: bool = False,
) -> None:
super().__init__()
self._return_remapped = return_remapped

mc_modules = {}

if (
use_mpzch
): # if using the MPZCH module, we create a HashZchManagedCollisionModule for each table
mc_modules["table_0"] = HashZchManagedCollisionModule(
is_inference=is_inference,
zch_size=(
tables[0].num_embeddings
), # the zch size, that is, the size of local embedding table, should be the same as the size of the embedding table
input_hash_size=input_hash_size, # the input hash size, that is, the size of the input id space
device=device, # the device on which the embedding table should be placed
total_num_buckets=buckets, # the number of buckets, the detailed explanation of the use of buckets can be found in the readme file
eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION, # the eviction policy name, in this example use the single ttl eviction policy, which assume an id is evictable if it has been in the table longer than the ttl (time to live)
eviction_config=HashZchEvictionConfig( # Here we need to specify for each feature, what is the ttl, that is, how long an id can stay in the table before it is evictable
features=[
"feature_0"
], # because we only have one feature "feature_0" in this table, so we only need to specify the ttl for this feature
single_ttl=1, # The unit of ttl is hour. Let's set the ttl to be default to 1, which means an id is evictable if it has been in the table for more than one hour.
),
)
mc_modules["table_1"] = HashZchManagedCollisionModule(
is_inference=is_inference,
zch_size=(tables[1].num_embeddings),
device=device,
input_hash_size=input_hash_size,
total_num_buckets=buckets,
eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
eviction_config=HashZchEvictionConfig(
features=["feature_1"],
single_ttl=1,
),
)
else: # if not using the MPZCH module, we create a MCHManagedCollisionModule for each table
mc_modules["table_0"] = MCHManagedCollisionModule(
zch_size=(tables[0].num_embeddings),
input_hash_size=input_hash_size,
device=device,
eviction_interval=2,
eviction_policy=DistanceLFU_EvictionPolicy(),
)
mc_modules["table_1"] = MCHManagedCollisionModule(
zch_size=(tables[1].num_embeddings),
device=device,
input_hash_size=input_hash_size,
eviction_interval=1,
eviction_policy=DistanceLFU_EvictionPolicy(),
)

self._mc_ec: ManagedCollisionEmbeddingCollection = (
ManagedCollisionEmbeddingCollection(
EmbeddingCollection(
tables=tables,
device=device,
),
ManagedCollisionCollection(
managed_collision_modules=mc_modules,
embedding_configs=tables,
),
return_remapped_features=self._return_remapped,
)
)

def forward(
self, kjt: KeyedJaggedTensor
) -> Tuple[
Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor]
]:
return self._mc_ec(kjt)