Update column_wise to allow uneven sharding #3050

Closed · wants to merge 1 commit

14 changes: 7 additions & 7 deletions torchrec/distributed/quant_state.py
@@ -149,6 +149,7 @@ def _initialize_torch_state( # noqa: C901
# pyre-fixme[16]: `ShardedQuantEmbeddingModuleState` has no attribute
# `_table_name_to_tensors_list_qbias`.
self._table_name_to_tensors_list_qbias: Dict[str, List[torch.Tensor]] = {}
_table_name_to_shard_idx: Dict[str, int] = {}

for tbe, config in tbes.items():
for (tbe_split_w, tbe_split_qscale, tbe_split_qbias), table in zip(
@@ -165,6 +166,12 @@ def _initialize_torch_state( # noqa: C901
metadata: ShardMetadata = copy.deepcopy(table.local_metadata)
metadata.shard_sizes = [tbe_split_w.size(0), tbe_split_w.size(1)]

if table.name not in _table_name_to_shard_idx:
_table_name_to_shard_idx[table.name] = 0

column_idx = _table_name_to_shard_idx[table.name]
_table_name_to_shard_idx[table.name] = column_idx + 1

# TODO(ivankobzarev): "meta" sharding support: cleanup when copy to "meta" moves all tensors to "meta"
# pyre-ignore
if metadata.placement.device != tbe_split_w.device:
@@ -199,12 +206,6 @@ def _initialize_torch_state( # noqa: C901
]:
assert table.local_metadata
metadata: ShardMetadata = copy.deepcopy(table.local_metadata)
shard_sizes = metadata.shard_sizes
shard_offsets = metadata.shard_offsets

shard_sizes_cols = shard_sizes[1]
shard_offsets_cols = shard_offsets[1]

parameter_sharding: ParameterSharding = (
table_name_to_parameter_sharding[table.name]
)
Expand All @@ -221,7 +222,6 @@ def _initialize_torch_state( # noqa: C901
torch.empty([])
] * num_shards

column_idx = int(shard_offsets_cols / shard_sizes_cols)
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[No...
table_name_to_tensors_list[table.name][
column_idx
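
The quant_state.py change above swaps the offset-derived column index (previously `column_idx = int(shard_offsets_cols / shard_sizes_cols)`) for a per-table running counter, because dividing a shard's column offset by its own width only identifies the shard when every shard of the table has the same width. A minimal standalone sketch of the difference, with illustrative names (not the TorchRec code itself):

```python
# Minimal sketch: why offset // width mis-indexes uneven column shards, and how a
# running per-table counter (as in the change above) sidesteps it. Illustrative only.
from typing import Dict, List, Tuple

def index_by_offset(shards: List[Tuple[int, int]]) -> List[int]:
    # Old scheme: column_idx = shard column offset // this shard's column width.
    # Only valid when every shard of the table has the same width.
    return [offset // width for offset, width in shards]

def index_by_counter(table: str, shards: List[Tuple[int, int]]) -> List[int]:
    # New scheme: keep a Dict[table_name, next_index] and hand out indices in
    # the order shards are visited, independent of their widths.
    next_idx: Dict[str, int] = {}
    out: List[int] = []
    for _offset, _width in shards:
        idx = next_idx.setdefault(table, 0)
        out.append(idx)
        next_idx[table] = idx + 1
    return out

even_shards = [(0, 128), (128, 128), (256, 128), (384, 128)]
uneven_shards = [(0, 63), (63, 128), (191, 65), (256, 256)]

print(index_by_offset(even_shards))                # [0, 1, 2, 3]
print(index_by_offset(uneven_shards))              # [0, 0, 2, 1] -- collisions and gaps
print(index_by_counter("table_0", uneven_shards))  # [0, 1, 2, 3]
```
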
50 changes: 35 additions & 15 deletions torchrec/distributed/sharding_plan.py
@@ -621,7 +621,8 @@ def _parameter_sharding_generator(


def column_wise(
ranks: List[int],
ranks: Optional[List[int]] = None,
size_per_rank: Optional[List[int]] = None,
) -> ParameterShardingGenerator:
"""
Returns a generator of ParameterShardingPlan for `ShardingType::COLUMN_WISE` for construct_module_sharding_plan.
@@ -648,22 +649,41 @@ def _parameter_sharding_generator(
device_type: str,
sharder: ModuleSharder[nn.Module],
) -> ParameterSharding:
if param.shape[1] % len(ranks) != 0:
raise ValueError(
f"column dim of {param.shape[1]} cannot be evenly divided across {ranks}"
if size_per_rank is None:
assert ranks is not None

if param.shape[1] % len(ranks) != 0:
raise ValueError(
f"column dim of {param.shape[1]} cannot be evenly divided across {ranks}"
)
shard_dim = param.shape[1] // len(ranks)
size_and_offsets = _get_parameter_size_offsets(
param,
ShardingType.COLUMN_WISE,
local_size,
world_size,
col_wise_shard_dim=shard_dim,
)
shard_dim = param.shape[1] // len(ranks)
size_and_offsets = _get_parameter_size_offsets(
param,
ShardingType.COLUMN_WISE,
local_size,
world_size,
col_wise_shard_dim=shard_dim,
)

size_offset_ranks = []
for (size, offset), rank in zip(size_and_offsets, ranks):
size_offset_ranks.append((size, offset, rank))
size_offset_ranks = []
for (size, offset), rank in zip(size_and_offsets, ranks):
size_offset_ranks.append((size, offset, rank))
else:
size_offset_ranks = []
(rows, cols) = param.shape
cur_offset = 0
prev_offset = 0
for rank, cur_size in enumerate(size_per_rank):
cur_offset += cur_size
cur_offset = min(cur_offset, cols)
cur_cols = cur_offset - prev_offset
size_offset_ranks.append(([rows, cur_cols], [0, prev_offset], rank))
prev_offset = cur_offset

if cur_offset < cols:
raise ValueError(
f"Cannot fit tensor of {rows, cols} into sizes_ranks_placements = {size_per_rank}"
)

return _get_parameter_sharding(
param,
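
For reference, the `size_per_rank` branch added to `column_wise` above can be read as the following standalone helper; the function name and the demo call are illustrative, not part of the PR. Passing only `ranks` keeps the existing even-split behavior, while `size_per_rank` lets each rank take a different number of columns.

```python
# Standalone sketch of the size_per_rank splitting added to column_wise above.
# Produces (shard_size, shard_offset, rank) triples; the helper name is illustrative.
from typing import List, Tuple

def uneven_column_shards(
    rows: int, cols: int, size_per_rank: List[int]
) -> List[Tuple[List[int], List[int], int]]:
    size_offset_ranks = []
    cur_offset = 0
    prev_offset = 0
    for rank, cur_size in enumerate(size_per_rank):
        cur_offset += cur_size
        cur_offset = min(cur_offset, cols)  # clamp so the final shard never overruns the table
        cur_cols = cur_offset - prev_offset
        size_offset_ranks.append(([rows, cur_cols], [0, prev_offset], rank))
        prev_offset = cur_offset
    if cur_offset < cols:
        raise ValueError(
            f"Cannot fit tensor of {(rows, cols)} into sizes_ranks_placements = {size_per_rank}"
        )
    return size_offset_ranks

# A 64 x 512 table split as 63 + 128 + 65 + 256 columns across ranks 0..3:
for size, offset, rank in uneven_column_shards(64, 512, [63, 128, 65, 256]):
    print(rank, size, offset)
# 0 [64, 63] [0, 0]
# 1 [64, 128] [0, 63]
# 2 [64, 65] [0, 191]
# 3 [64, 256] [0, 256]
```
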
123 changes: 123 additions & 0 deletions torchrec/distributed/tests/test_infer_shardings.py
@@ -470,6 +470,129 @@ def test_cw(
ShardingType.COLUMN_WISE.value,
)

@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs available",
)
# pyre-fixme[56]: Pyre was not able to infer the type of argument `hypothesis.strategies.booleans()` to decorator factory `hypothesis.given`.
@given(
weight_dtype=st.sampled_from([torch.qint8]),
device_type=st.sampled_from(["cuda"]),
)
@settings(max_examples=4, deadline=None)
def test_uneven_cw(self, weight_dtype: torch.dtype, device_type: str) -> None:
num_embeddings = 64
emb_dim = 512
dim_1 = 63
dim_2 = 128
dim_3 = 65
dim_4 = 256
local_size = 4
world_size = 4
batch_size = 4
local_device = torch.device(f"{device_type}:0")
mi = create_test_model(
num_embeddings,
emb_dim,
world_size,
batch_size,
dense_device=local_device,
sparse_device=local_device,
quant_state_dict_split_scale_bias=True,
weight_dtype=weight_dtype,
)
non_sharded_model = mi.quant_model
expected_ranks: List[int] = [0, 1, 2, 3]
expected_shards = [
[
(
(0, 0, num_embeddings, dim_1),
placement(device_type, expected_ranks[0], world_size),
),
(
(0, dim_1, num_embeddings, dim_2),
placement(device_type, expected_ranks[1], world_size),
),
(
(0, dim_1 + dim_2, num_embeddings, dim_3),
placement(device_type, expected_ranks[2], world_size),
),
(
(0, dim_1 + dim_2 + dim_3, num_embeddings, dim_4),
placement(device_type, expected_ranks[3], world_size),
),
]
]
sharder = TestQuantEBCSharder(
sharding_type=ShardingType.COLUMN_WISE.value,
kernel_type=EmbeddingComputeKernel.QUANT.value,
shardable_params=[table.name for table in mi.tables],
)
module_plan = construct_module_sharding_plan(
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no
# attribute `sparse`.
non_sharded_model._module.sparse.ebc,
per_param_sharding={
"table_0": column_wise(size_per_rank=[dim_1, dim_2, dim_3, dim_4]),
},
# pyre-ignore
sharder=sharder,
local_size=local_size,
world_size=world_size,
device_type=device_type,
)
plan = ShardingPlan(plan={"_module.sparse.ebc": module_plan})

sharded_model = shard_qebc(
mi=mi,
sharding_type=ShardingType.COLUMN_WISE,
device=local_device,
expected_shards=expected_shards,
plan=plan,
)

inputs = [
model_input_to_forward_args(inp.to(local_device))
for inp in prep_inputs(mi, world_size, batch_size, long_indices=False)
]
sharded_model.load_state_dict(non_sharded_model.state_dict())
sharded_output = sharded_model(*inputs[0])
non_sharded_output = non_sharded_model(*inputs[0])
torch.testing.assert_close(non_sharded_output, sharded_output)

gm: torch.fx.GraphModule = symbolic_trace(sharded_model)
gm_script = torch.jit.script(gm)
gm_script_output = gm_script(*inputs[0])
torch.testing.assert_close(sharded_output, gm_script_output)

weights_spec: Dict[str, WeightSpec] = sharded_tbes_weights_spec(sharded_model)

for table_name, expected_shard in zip(["table_0"], expected_shards):
unsharded_weight_fqn = (
f"_module.sparse.ebc.embedding_bags.{table_name}.weight"
)
for (offset_r, offset_c, size_r, size_c), rank in expected_shard:
tbe_idx = int(rank.split(":")[-1])
sharded_weight_fqn: str = (
f"_module.sparse.ebc.tbes.{tbe_idx}.0.{table_name}.weight"
)

assert sharded_weight_fqn in weights_spec
wspec = weights_spec[sharded_weight_fqn]
assert wspec.fqn == unsharded_weight_fqn
assert wspec.shard_sizes == [size_r, size_c]
assert wspec.shard_offsets == [offset_r, offset_c]
assert wspec.sharding_type == ShardingType.COLUMN_WISE.value

for qcomp in ["qscale", "qbias"]:
sharded_weight_qcomp_fqn: str = f"{sharded_weight_fqn}_{qcomp}"
assert sharded_weight_qcomp_fqn in weights_spec
wqcomp_spec = weights_spec[sharded_weight_qcomp_fqn]
assert wqcomp_spec.fqn == f"{unsharded_weight_fqn}_{qcomp}"
assert wqcomp_spec.shard_sizes == [size_r, 2]
assert wqcomp_spec.shard_offsets == [0, 0]
assert wqcomp_spec.sharding_type == ShardingType.COLUMN_WISE.value

@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs available",
Expand Down
Loading