
Conversation

crcrpar (Collaborator) commented Oct 13, 2025

  • Use register_sharding for sharding propagation of the fp4 grouped matmul (see the sketch after this list)
  • Use distribute_tensor(..., src_data_rank=None) so the local tensor is used as the DTensor's data, since the packed fp4 dtype doesn't seem quite compatible with DTensor
  • Update the desired input layouts accordingly
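
A minimal sketch of the register_sharding piece, following PyTorch's documented pattern for torch.distributed.tensor.experimental.register_sharding; the op handle and the argument names (activations, fp4_weight, b_sf, offsets) are placeholders for illustration, not the actual custom op signature in the test.

from torch.distributed.tensor import Replicate
from torch.distributed.tensor.experimental import register_sharding

grouped_mm_op = ...  # OpOverload of the fp4 grouped mm custom op (placeholder)


@register_sharding(grouped_mm_op)
def grouped_mm_sharding(activations, fp4_weight, b_sf, offsets):
    # Each acceptable strategy is (output placements, per-input placements);
    # non-tensor arguments would take None instead of a placement. The packed
    # fp4 weight and its scale factors are kept replicated, per the notes above.
    all_replicate = (
        [Replicate()],
        [Replicate(), Replicate(), Replicate(), Replicate()],
    )
    return [all_replicate]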

@github-actions

Description

  • Register sharding strategy for fp4 grouped matmul

  • Use Replicate instead of Shard for fp4 weights

  • Update input layouts for grouped linear layers

  • Fix distributed tensor handling for packed fp4 data


Changes walkthrough 📝

Relevant files

Enhancement: tests/python/multidevice/test_moe.py (+111/-21)
Add sharding registration and fix fp4 tensor distribution

  • Import register_sharding for custom op sharding registration
  • Define sharding strategies for fp4 grouped mm with Replicate for fp4 tensors
  • Update _partition_fn to use Replicate() and src_data_rank=None for fp4_weight and b_sf (see the sketch after this list)
  • Extend input layouts to support 4 inputs with proper sharding annotations
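
A minimal sketch of the _partition_fn change described in the walkthrough, assuming a module that owns fp4_weight and b_sf and a 1-D device_mesh; requires_grad=False is an assumption for the non-differentiable packed dtype, and the full function in the PR does more.

import torch.nn as nn
from torch.distributed.tensor import Replicate, distribute_tensor


def _partition_fn(name, module, device_mesh):
    # Keep the packed fp4 weight and its scale factors replicated.
    # src_data_rank=None wraps each rank's existing local tensor instead of
    # broadcasting from rank 0, sidestepping DTensor data movement for the
    # packed fp4 dtype.
    module.fp4_weight = nn.Parameter(
        distribute_tensor(
            module.fp4_weight, device_mesh, [Replicate()], src_data_rank=None
        ),
        requires_grad=False,  # assumption: the packed fp4 tensor is not trained
    )
    module.b_sf = nn.Parameter(
        distribute_tensor(
            module.b_sf, device_mesh, [Replicate()], src_data_rank=None
        ),
        requires_grad=False,  # assumption, as above
    )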

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Possible Issue

    The sharding strategy for dropme in nvfuser_grouped_mm_sharding uses Shard(2), which assumes the tensor has at least three dimensions; if dropme has fewer, this placement will fail at runtime.

    Shard(2),  # dropme sharded on output feature dimension (last dim)
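
    If the intent is "the last dimension", deriving the shard dim from the tensor's rank would avoid the hard-coded assumption; a hedged one-line alternative (names mirror the snippet above):

    Shard(dropme.ndim - 1),  # last dim; equal to Shard(2) only when dropme is 3-D
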
    Inconsistent Layout Update

    The input_layouts and desired_input_layouts are updated to include four inputs, but it should be verified that all call sites and consumers of these layouts are compatible with the new four-element tuple structure.

    self.input_layouts = input_layouts or (
        Shard(-1),
        Replicate(),
        Replicate(),
        Replicate(),
    )
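
    A hypothetical sanity check (not in the PR) that would catch a mismatch early, e.g. at the end of __init__ alongside the tuples above:

    assert len(self.input_layouts) == len(self.desired_input_layouts) == 4, (
        "input_layouts and desired_input_layouts must both describe 4 inputs"
    )
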
    Redundant Replication

    The _partition_fn methods replicate fp4_weight and b_sf using distribute_tensor with Replicate() and src_data_rank=None, but it should be confirmed whether this replication is necessary given that the data is already present on each rank and not being sharded.

    module.fp4_weight = nn.Parameter(
        distribute_tensor(
            module.fp4_weight, device_mesh, [Replicate()], src_data_rank=None
        ),
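
    For reference on the question above: with src_data_rank=None, distribute_tensor skips the broadcast from a source rank and simply wraps each rank's existing local tensor, so the call annotates the data as replicated rather than moving it. A rough sketch of that equivalence, reusing the names from the snippet (DTensor.from_local is shown only for illustration, it is not what the PR uses):

    from torch.distributed.tensor import DTensor, Replicate, distribute_tensor

    # No broadcast from rank 0; the tensor already on this rank becomes its
    # replica of the resulting DTensor.
    replicated = distribute_tensor(
        module.fp4_weight, device_mesh, [Replicate()], src_data_rank=None
    )

    # Roughly equivalent: wrap the local tensor directly, with no cross-rank
    # communication or consistency check.
    also_replicated = DTensor.from_local(
        module.fp4_weight, device_mesh, [Replicate()], run_check=False
    )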
