@liqiangxl
Collaborator

@liqiangxl liqiangxl commented Oct 16, 2025

Add PERMISSIVE_RESIZE graph support to IdModel and incompatible reshape detection to fix issue #5358

Summary

This PR adds PERMISSIVE_RESIZE graph building capability to IdModel and implements a topology checker to detect incompatible reshape patterns that cannot be fused together.

Changes

1. IdModel: PERMISSIVE_RESIZE Graph Support

  • Added buildPermissiveResizeGraph() method to IdModel class
  • Implementation: Starts with PERMISSIVE graph, then for each consumer tensor's logical domain:
    • If an ID is defined by a Resize operation, maps resize->in() directly to the resize output ID. This mapping ignores extent (e.g., domains with extents 36, 24, and 12 can all be mapped together)
    • Propagates the mappings to downstream IDs through maybeMapThroughExprs
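
The two steps above can be illustrated with a toy model. This is a minimal sketch, not nvFuser's implementation: ToyGraph, ToyResize, ToySplit, and buildToyPermissiveResize are hypothetical names, IDs are plain integer handles, and "matching operation" is reduced to "same split factor on mapped inputs". The real graph builder considers more expression attributes than this.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Toy union-find over integer ID handles (hypothetical; not nvFuser's ValGraph).
class ToyGraph {
 public:
  explicit ToyGraph(std::size_t n) : parent_(n) {
    std::iota(parent_.begin(), parent_.end(), 0);
  }
  int find(int x) {
    while (parent_[x] != x) {
      parent_[x] = parent_[parent_[x]];  // path halving
      x = parent_[x];
    }
    return x;
  }
  void map(int a, int b) { parent_[find(a)] = find(b); }
  bool mapped(int a, int b) { return find(a) == find(b); }

 private:
  std::vector<int> parent_;
};

// Toy stand-ins for Resize and outer Split expressions.
struct ToyResize { int in; int out; };
struct ToySplit { int in; int64_t factor; int outer; int inner; };

// Step 1: map each resize input to its output, ignoring extents.
// Step 2: repeatedly map the outputs of splits whose inputs are already
// mapped and whose factors match, until nothing changes.
ToyGraph buildToyPermissiveResize(
    std::size_t num_ids,
    const std::vector<ToyResize>& resizes,
    const std::vector<ToySplit>& splits) {
  ToyGraph g(num_ids);
  for (const ToyResize& r : resizes) {
    g.map(r.in, r.out);
  }
  bool changed = true;
  while (changed) {
    changed = false;
    for (std::size_t i = 0; i < splits.size(); ++i) {
      for (std::size_t j = i + 1; j < splits.size(); ++j) {
        const ToySplit& a = splits[i];
        const ToySplit& b = splits[j];
        if (a.factor != b.factor || !g.mapped(a.in, b.in)) {
          continue;
        }
        if (g.mapped(a.outer, b.outer) && g.mapped(a.inner, b.inner)) {
          continue;
        }
        g.map(a.outer, b.outer);
        g.map(a.inner, b.inner);
        changed = true;
      }
    }
  }
  return g;
}
```

With ID 0 of extent 36 resized into IDs 1 (extent 24) and 2 (extent 12), all three land in one set. If IDs 1 and 2 are then both split by factor 2, the split outputs are mapped pairwise even though their extents differ. A later split by 3 on one side and by 2 on the other leaves those outputs unmapped.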

2. Incompatible Reshape Detection

  • Added SchedulerTopologyChecker::hasIncompatibleReshapes() function
  • Detection logic: Returns true when rfactor IDs are mapped together (same ValGroup) but have different transformations (different ExprGroups)
  • Use case: Detects when mapped domains are reshaped incompatibly
    • Example:
    auto tv0 = makeConcreteTensor({36});
    auto tv1 = slice(tv0, std::vector<int64_t>{0}, std::vector<int64_t>{24});
    auto tv2 = slice(tv0, std::vector<int64_t>{12}, std::vector<int64_t>{24});
    auto tv3 = reshape(tv1, {24}, {2, 3, 4});
    auto tv4 = reshape(tv2, {12}, {2, 2, 3});
    
    In this case, the graph contains the following two additional disjoint val sets, each holding IDs with different extents. The IDs in the first set are mapped through resize. The IDs in the second set are mapped through exprs (outer split by factor 2).
Disjoint val set: { iS3{36}rf; iS1{36}rf; iS0{36}; iS6{24}rf; iS2{24}rf; iS12{12}rf; iS4{12}rf }
Disjoint val set: { iS8{12}rf; iS14{6}rf }

The IDs iS6{24}rf and iS12{12}rf in the first disjoint set are used in the same expr set. However, the IDs in the second set (iS8{12}rf and iS14{6}rf) are used in different expr sets, which means they undergo different transforms.

Disjoint expr set: { 
  Outer split: iS6{24}rf by factor 2 -> iS7{2}rf, iS8{12}rf;
  Outer split: iS12{12}rf by factor 2 -> iS13{2}rf, iS14{6}rf}

Disjoint expr set: { Outer split: iS8{12}rf by factor 3 -> iS9{3}rf, iS10{4}rf }
Disjoint expr set: { Outer split: iS14{6}rf by factor 2 -> iS15{2}rf, iS16{3}rf } 
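
The detection rule can be sketched on this example with plain integers standing in for ValGroup and ExprGroup. This is a simplified illustration under stated assumptions, not the PR's code: ToyUse and hasToyIncompatibleReshapes are hypothetical names, and the input is assumed to list only non-Resize uses of rfactor IDs.

```cpp
#include <cassert>
#include <map>
#include <vector>

// One non-Resize use of an rfactor ID, reduced to two group handles.
struct ToyUse {
  int val_group;   // disjoint val set that contains the ID
  int expr_group;  // disjoint expr set that contains the using expression
};

// Returns true if any val group is used by more than one expr group,
// i.e. mapped IDs are transformed in different ways.
bool hasToyIncompatibleReshapes(const std::vector<ToyUse>& uses) {
  std::map<int, int> common_use_group;  // val group -> first expr group seen
  for (const ToyUse& use : uses) {
    auto result = common_use_group.emplace(use.val_group, use.expr_group);
    if (!result.second && result.first->second != use.expr_group) {
      return true;
    }
  }
  return false;
}
```

For the sets above, number the val set holding iS6 and iS12 as group 0 and the set holding iS8 and iS14 as group 1. Both uses from group 0 fall into expr group 0 (the splits by factor 2), while the uses from group 1 fall into two different expr groups (split by 3 and split by 2), so the check reports an incompatibility.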

Technical Details

PERMISSIVE_RESIZE connects resize inputs and outputs regardless of extent changes and propagates those mappings through transformations based on matching operations (e.g., same split factor), not matching extents. This can map many downstream domains together, revealing when domains with a shared origin later undergo incompatible transformations. The incompatibility check then prevents fusion in cases where domains are mapped via resize (same origin) but diverge in their subsequent transformations (making unified scheduling invalid).

@github-actions

github-actions bot commented Oct 16, 2025

Review updated until commit 0327f67

Description

  • Add PERMISSIVE_RESIZE graph support in IdModel

  • Detect incompatible reshape patterns during scheduling

  • Prevent fusion of reshapes with conflicting transformations

  • Enable correct segmentation for incompatible reshapes


Changes walkthrough 📝

Relevant files

Enhancement

id_model.cpp
Add PERMISSIVE_RESIZE graph building
csrc/id_model/id_model.cpp (+34/-0)

  • Added buildPermissiveResizeGraph() method to support PERMISSIVE_RESIZE mapping mode
  • Initializes graph from PERMISSIVE mode then maps resize input to output IDs
  • Propagates mappings regardless of extent, enabling broader ID grouping
  • Integrated into buildGraph() switch for new IdMappingMode

id_model.h
Declare PERMISSIVE_RESIZE graph method
csrc/id_model/id_model.h (+5/-0)

  • Declared buildPermissiveResizeGraph() method in IdModel class
  • Documented initialization from PERMISSIVE entries
  • Added support for mapping through resize and indexed domains

registry_utils.h
Declare incompatible reshape checker
csrc/scheduler/registry_utils.h (+6/-0)

  • Added declaration of hasIncompatibleReshapes() static method
  • Documented purpose: detect reshape conflicts in fusion
  • Complements existing hasCyclicReshape check

Bug fix

registry.cpp
Add incompatible reshape check in scheduler
csrc/scheduler/registry.cpp (+7/-0)

  • Added check for incompatible reshapes in checkCanSchedule
  • Rejects scheduling if hasIncompatibleReshapes() returns true
  • Provides debug feedback via canScheduleRejectReason

registry_utils.cpp
Detect incompatible reshape transformations
csrc/scheduler/registry_utils.cpp (+57/-14)

  • Implemented hasIncompatibleReshapes() to detect conflicting reshape patterns
  • Uses PERMISSIVE_RESIZE graph to find mapped rfactor IDs with different transformations
  • Returns true if same ValGroup has different ExprGroups (excluding Resize)
  • Prevents invalid fusion by detecting transformation conflicts

Tests

test_id_model.cpp
Test PERMISSIVE_RESIZE graph functionality
tests/cpp/test_id_model.cpp (+58/-5)

  • Added PermissiveResizeGraph test to validate new graph behavior
  • Tests mapping of sliced and reshaped tensors with different extents
  • Verifies correct propagation through equivalent transformations
  • Confirms ID mapping based on operation equivalence, not extent

test_reshape.cpp
Add reshape compatibility tests
tests/cpp/test_reshape.cpp (+367/-4)

  • Added multiple tests for compatible/incompatible reshape scenarios
  • Tests both same and different disjoint set cases
  • Validates segmentation behavior for incompatible patterns
  • Includes multi-step reshape and merge cases

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Possible Issue

The function hasIncompatibleReshapes constructs a new IdModel and builds the permissive resize graph instead of reusing an existing IdModel instance, which could affect performance or consistency. Consider reusing IdModel when possible, as hinted by the TODO comment.

    IdModel id_model(fusion);
    const auto& permissive_resize_graph = id_model.buildPermissiveResizeGraph();
    
    for (const ValGroup& val_group :
         permissive_resize_graph.disjointValSets().disjointSets()) {
      // Collect all rfactor IDs in this val group
      // Check for consistency if there are at least 2 rfactor IDs
      std::vector<IterDomain*> rfactor_ids;
      for (Val* val : *val_group) {
        auto id = val->as<IterDomain>();
        if (id->isRFactorProduct()) {
          rfactor_ids.push_back(id);
        }
      }
      if (rfactor_ids.size() < 2) {
        continue;
      }
    
      // For ids in the same val group, their usages should be in the same expr
      // group. Resize ops are skipped since they are not propagated during replay
      std::optional<ExprGroup> common_use_group;
      for (auto id : rfactor_ids) {
        if (!permissive_resize_graph.hasUses(
                permissive_resize_graph.toGroup(id))) {
          continue;
        }
        const auto& use_groups =
            permissive_resize_graph.getUses(permissive_resize_graph.toGroup(id));
        for (const auto& use_group : use_groups) {
          if (std::any_of(use_group->begin(), use_group->end(), [](Expr* expr) {
                return expr->isA<Resize>();
              })) {
            continue;
          }
          if (!common_use_group.has_value()) {
            common_use_group = use_group;
          } else if (common_use_group.value() != use_group) {
            return true;
          }
        }
      }
    }
    
    return false;

Debug Code

The use of std::cout for logging inside buildPermissiveResizeGraph may not be appropriate for production code and should be replaced with proper logging mechanisms.

    std::cout << "mapping "
              << id->definition()->as<Resize>()->in()->toString()
              << " to " << id->toString() << std::endl;

@liqiangxl
Collaborator Author

!test

@liqiangxl
Collaborator Author

!test

@liqiangxl
Collaborator Author

!test
