
Conversation

@Priya2698 Priya2698 commented Aug 29, 2025

Issue #5079.

For multidevice, inferring tensor shapes relies on the allocation domain. rFactor did not replay the allocation domain, which led to the wrong shape being inferred in the test case from the issue.
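A rough sketch of the failure mode in plain Python (not nvFuser's API; the extents here are made up): per-device buffer sizes come from the allocation domain, so if rFactor rebuilds only the logical domain and drops the sharded allocation domain, inference falls back to the global logical shape and over-allocates.

```python
# Conceptual sketch only: a tensor of logical shape [d, n] sharded on the
# first axis across `num_devices` devices has a per-device allocation of
# [1, n]. Shape inference must read the allocation domain, not the
# logical domain, to get the local size right.

num_devices = 4
n = 8
logical = [num_devices, n]  # global logical shape
allocation = [1, n]         # per-device allocation: sharded axis has local extent 1

def infer_local_size(domain):
    """Per-device buffer size is the product of the domain extents."""
    size = 1
    for extent in domain:
        size *= extent
    return size

# Correct: infer from the (sharded) allocation domain.
assert infer_local_size(allocation) == n

# The bug: inferring from the logical domain over-allocates by num_devices.
assert infer_local_size(logical) == num_devices * n
```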

@Priya2698

!test


github-actions bot commented Aug 29, 2025

Review updated until commit 8888217

Description

  • Fix shape inference in multidevice by replaying allocation domain

  • Replay rfactor transformations on allocation domain for correctness

  • Propagate static allocation IDs to preserve compute semantics

  • Add test for inner reduction in multidevice context


Changes walkthrough 📝

Relevant files

Bug fix

csrc/transform_rfactor.cpp — Extend rFactor replay to allocation domain

  • Added allocation_domain_ to track allocation domain during rFactor
  • Modified splitId and mergeId to handle allocation domain with static IDs
  • Updated ReplayRFactor constructor to compute static allocation IDs
  • Set allocation domain on producer and consumer in runReplay
  • +101/-55

Tests

tests/python/multidevice/test_multidevice.py — Add test for inner reduction on multidevice

  • Added new test test_inner_reduction for multidevice
  • Tests inner dimension reduction with device mesh
  • Verifies correct output using assert torch.allclose
  • +28/-0
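The semantics the new test exercises can be sketched in plain Python (hypothetical extents, not the actual test code): reducing over an inner dimension that is sharded across devices is an rFactor pattern, where each device reduces its local shard and the partial results are then combined across devices.

```python
# Conceptual sketch of a sharded inner reduction (not the real test).
num_devices = 2
row = list(range(8))            # one logical row, inner extent 8
shards = [row[:4], row[4:]]     # inner dim sharded across 2 devices

partial = [sum(s) for s in shards]  # per-device local reduction
result = sum(partial)               # cross-device reduction of partials

assert result == sum(row)  # matches the unsharded reduction
```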

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests

⚡ Recommended focus areas for review

Possible Issue

The updateRFactorDomain function unconditionally processes both the logical and allocation domains for splits and merges, but the original code had conditional checks to ensure only static IDs were processed. This may lead to unintended modifications of non-static allocation domains.

    if (Split* split = dynamic_cast<Split*>(expr)) {
      splitId(logical_domain_, split, static_logical_ids_);
      splitId(allocation_domain_, split, static_allocation_ids_);
    }
    if (Merge* merge = dynamic_cast<Merge*>(expr)) {
      mergeId(logical_domain_, merge, static_logical_ids_);
      mergeId(allocation_domain_, merge, static_allocation_ids_);
    }
    Logic Error

    The mergeId function checks that both inner and outer IDs have the same static status, but only uses the outer ID's status to decide whether to proceed. This could result in inconsistent handling of merge operations when one ID is static and the other is not, despite the NVF_ERROR check.

    NVF_ERROR(
        static_ids.contains(merge->inner()) ==
            static_ids.contains(merge->outer()),
        "If one input to a merge is a static id, the other must be as well.");
    if (!static_ids.contains(merge->outer())) {
      return;
    }
    auto outer_it = domain.erase(merge->outer()).second;
    domain.insert(outer_it, merge->out(), std::monostate());
    domain.erase(merge->inner());
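For intuition, here is the same update in plain Python (a sketch, not the C++; the domain is modeled as an ordered list and the monostate payloads are dropped): the merged output takes the outer input's position, the inner input is removed, and non-static merges return early without touching the domain.

```python
# Sketch of the mergeId update on an ordered domain.
def merge_id(domain, outer, inner, out, static_ids):
    # Both inputs must agree on static-ness; mixed inputs are an error.
    assert (outer in static_ids) == (inner in static_ids), \
        "If one input to a merge is a static id, the other must be as well."
    if outer not in static_ids:
        return  # non-static merges leave the domain untouched
    i = domain.index(outer)
    domain[i] = out       # out takes outer's position
    domain.remove(inner)  # inner is dropped

domain = ["a", "b", "c"]
merge_id(domain, "a", "b", "ab", static_ids={"a", "b"})
assert domain == ["ab", "c"]

domain2 = ["x", "y"]
merge_id(domain2, "x", "y", "xy", static_ids=set())
assert domain2 == ["x", "y"]  # untouched, mirroring the early return
```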
    Missing Validation

    The allocation domain is now being replayed and set unconditionally when present, but there is no validation that the transformed allocation domain maintains necessary properties or aligns with the logical domain structure.

    if (original_td->hasAllocation()) {
      std::vector<IterDomain*> transformed_original_allocation =
          replay_rfactor.allocation();
      std::vector<IterDomain*> new_producer_allocation_domain = replayDomain(
          transformed_original_allocation,
          original_to_producer_id_map,
          /*ignore_ids=*/{},
          /*propagate_padding=*/false,
          /*propagate_parallelization=*/false);
      producer_domain->setAllocationDomain(
          new_producer_allocation_domain,
          TensorDomain::getContiguityFilledWith(
              new_producer_allocation_domain, true));
    }
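The replayDomain step in the snippet above is essentially an ID translation; a plain-Python stand-in (hypothetical IDs, not the real signature) makes the mapping explicit:

```python
# Stand-in for the id-mapping step: each transformed allocation ID from
# the original tensor is translated into the producer's ID space via
# original_to_producer_id_map; ids in ignore_ids are skipped.
original_to_producer_id_map = {"i0": "p0", "r1": "p1"}

def replay_domain(ids, id_map, ignore_ids=frozenset()):
    return [id_map[i] for i in ids if i not in ignore_ids]

assert replay_domain(["i0", "r1"], original_to_producer_id_map) == ["p0", "p1"]
```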

@Priya2698 commented:

!test

@Priya2698 commented:

!test

Priya2698 added a commit that referenced this pull request Sep 3, 2025:

    This approach should be faster and simplifies the update logic.

    For PR #5090

Priya2698 added a commit that referenced this pull request Sep 4, 2025:

    Separating out replaying domain to a function to allow reuse.
    For Issue #5079, PR #5090 will also replay allocation.

Review comment on the following hunk:

    // Axes in the original_td that are in the history of the rfactored domains.
    // These will mark which iter domains must be preserved as static
    // transformations to preserve compute semantics.
    auto all_deps_of_logical = DependencyCheck::getAllValsBetween(
@Priya2698 commented:
    code movement: moved to ReplayRFactor constructor. We compute vals between maybeRoot/logical/allocation to rfactor_axes.

@Priya2698 commented:

!test --diff

@Priya2698 commented:

!test --diff

@Priya2698 Priya2698 marked this pull request as ready for review September 5, 2025 17:31