Conversation

zasdfgbnm (Collaborator)

No description provided.

@zasdfgbnm (Collaborator, Author)

!test

@zasdfgbnm (Collaborator, Author)

Fixes #4888

github-actions bot commented Aug 27, 2025

Review updated until commit 5f89380

Description

  • Use ATen ops for ExprEval segment output shapes and strides

  • Compute and update contiguity for non-broadcast and non-reduction dims

  • Fix contiguity validation by passing TensorView to validation function

  • Add helper to infer contiguity from sizes and strides (see the sketch below)

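As context for the last bullet, here is a minimal, self-contained sketch of how contiguity can be inferred from sizes and strides (hypothetical; the PR's actual `_computeContiguity` helper in csrc/runtime/fusion_cache_utils.cpp may differ, e.g. in how broadcast dimensions are handled):

    #include <cstdint>
    #include <optional>
    #include <vector>

    // A dimension is marked contiguous when its stride equals the stride it
    // would have if it were densely packed against the next inner dimension.
    std::vector<std::optional<bool>> inferContiguity(
        const std::vector<int64_t>& sizes,
        const std::vector<int64_t>& strides) {
      std::vector<std::optional<bool>> contiguity(sizes.size());
      int64_t expected_stride = 1; // dense stride for the innermost dim
      for (size_t i = sizes.size(); i-- > 0;) {
        contiguity[i] = (strides[i] == expected_stride);
        // The next outer dimension is contiguous relative to this one when
        // its stride equals this dimension's actual stride times its size.
        expected_stride = strides[i] * sizes[i];
      }
      return contiguity;
    }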

Changes walkthrough 📝

Relevant files:

Bug fix

nodes.cpp — Improve contiguity validation logic (csrc/ir/nodes.cpp, +2/-1)
  • Added check for broadcast or reduction domains when setting contiguity
  • Improved error message clarity for contiguity validation

tensor_metadata.cpp — Fix contiguity validation via TensorView (csrc/tensor_metadata.cpp, +6/-3)
  • Updated validateAllocationSizesAndStrides to accept a TensorView instead of a contiguity vector
  • Fetch contiguity directly from the TensorView in validation
  • Improved error message with the TensorView name

Enhancement

fusion_cache_utils.cpp — Add contiguity computation and update logic (csrc/runtime/fusion_cache_utils.cpp, +75/-1)
  • Added _computeContiguity helper to infer contiguity from sizes and strides
  • Enhanced updateWithSegmentOutputs to conditionally update contiguity
  • Integrated ATen tensor metadata for ExprEval segment outputs
  • Added debug prints for tensor properties during contiguity update

fusion_kernel_runtime.cpp — Use ATen ops for ExprEval output inference (csrc/runtime/fusion_kernel_runtime.cpp, +22/-5)
  • Use ExpressionEvaluator to compute output sizes for ExprEval segments
  • Route ExprEval outputs through ATen-style stride computation (sketched after this list)
  • Pass the update_contiguity flag when updating segment outputs

fusion_cache_utils.h — Extend updateWithSegmentOutputs signature (csrc/runtime/fusion_cache_utils.h, +2/-1)
  • Added optional update_contiguity parameter to updateWithSegmentOutputs
  • Defaulted the new parameter to false for backward compatibility

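The "ATen-style stride computation" mentioned for fusion_kernel_runtime.cpp refers to deriving row-major contiguous strides from output sizes, as at::empty would. A minimal sketch of that idea (illustrative only; the routine used in the PR may differ):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Contiguous (row-major) strides: each stride is the product of all
    // inner sizes, treating zero-sized dims as size 1, as ATen does.
    std::vector<int64_t> contiguousStrides(const std::vector<int64_t>& sizes) {
      std::vector<int64_t> strides(sizes.size(), 1);
      int64_t acc = 1;
      for (size_t i = sizes.size(); i-- > 0;) {
        strides[i] = acc;
        acc *= std::max<int64_t>(sizes[i], 1);
      }
      return strides;
    }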
PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 No relevant tests
⚡ Recommended focus areas for review

Debug Logging

The code includes active (uncommented) std::cout statements for debugging, which should not be present in production code.

    std::cout << tv->toString() << std::endl;
    tv->printTransforms();
    const at::Tensor& tensor = group_runtime_outputs[group_out_i].as<at::Tensor>();
    const std::vector<int64_t> sizes = tensor.sizes().vec();
    const std::vector<int64_t> strides = tensor.strides().vec();
    // const auto [sizes, strides] = inferAndValidateAllocationSizesAndStrides(tensor, tv, ExpressionEvaluator());
    std::vector<std::optional<bool>> contiguity = _computeContiguity(sizes, strides);
    std::cout << "sizes: " << sizes << std::endl;
    std::cout << "strides: " << strides << std::endl;
    // std::cout << "contiguity: " << contiguity << std::endl;
Possible Issue

The logic for handling reduction dimensions in the contiguity computation may advance the index incorrectly; the commented-out lines suggest uncertainty in the implementation.

    int64_t index_with_reduction = 0;
    for (const auto id : tv->domain()->maybeAllocation()) {
      if (id->isReduction()) {
        contiguity_with_reduction.push_back(std::nullopt);
      } else {
        contiguity_with_reduction.push_back(contiguity[index_with_reduction++]);
        // (void)index_with_reduction;
        // contiguity_with_reduction.push_back(false);
      }
    }
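For reference, a self-contained version of the expansion above could look like this (hypothetical names; `is_reduction` stands in for the `id->isReduction()` checks over the allocation domain):

    #include <optional>
    #include <vector>

    // Reduction positions carry no allocation stride, so they get
    // std::nullopt; every other position consumes the reduction-free
    // contiguity vector in order.
    std::vector<std::optional<bool>> expandWithReductions(
        const std::vector<bool>& is_reduction,
        const std::vector<std::optional<bool>>& contiguity) {
      std::vector<std::optional<bool>> out;
      out.reserve(is_reduction.size());
      size_t idx = 0; // index into the reduction-free vector
      for (bool is_red : is_reduction) {
        if (is_red) {
          out.push_back(std::nullopt);
        } else {
          out.push_back(contiguity.at(idx++));
        }
      }
      return out;
    }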
Function Signature Change

The function validateAllocationSizesAndStrides now takes a TensorView* instead of a contiguity vector, which may affect how contiguity is validated and could introduce inconsistencies if not handled properly across call sites.

    void validateAllocationSizesAndStrides(
        const std::vector<IterDomain*>& alloc_dom,
        TensorView* tv,
        c10::IntArrayRef sizes,
        c10::IntArrayRef strides) {
      const std::vector<std::optional<bool>>& contiguity = tv->getContiguity();
      NVF_ERROR(alloc_dom.size() == contiguity.size());

zasdfgbnm added a commit that referenced this pull request Sep 2, 2025

In preparation for the fix of #4888, I am working on #5082, which requires using `Expr::evaluate` on meta tensors to infer the shapes and strides of fusion segments selected to be scheduled by the `ExprEval` scheduler. As a consequence of this change, all `Expr::evaluate` functions must support the meta device, and the returned output tensor's shape and strides must match those on device type CUDA.

According to https://docs.pytorch.org/docs/stable/meta.html:
> In some cases, not all device types (e.g., CPU and CUDA) have exactly the same output metadata for an operation; we typically prefer representing the CUDA behavior faithfully in this situation.

It is generally safe to assume that the meta device can be used to infer the shapes and strides of device type CUDA. Unfortunately, not all operators implement the meta device, and `at::_scaled_dot_product_flash_attention` is one such example.

In this PR, I am adding my own `at::_scaled_dot_product_flash_attention` implementation for the meta device.
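
A minimal sketch of the meta-device inference this commit relies on (illustrative, not code from the commit):

    #include <ATen/ATen.h>
    #include <iostream>

    int main() {
      // Meta tensors carry only metadata; kernels on them compute output
      // shapes and strides without touching device memory. PyTorch documents
      // meta metadata as faithfully representing the CUDA behavior.
      at::Tensor a = at::empty({8, 16}, at::device(at::kMeta));
      at::Tensor b = at::empty({16, 4}, at::device(at::kMeta));
      at::Tensor out = at::matmul(a, b); // runs the meta kernel only
      std::cout << out.sizes() << " " << out.strides() << std::endl;
      // Ops lacking a meta kernel, such as
      // at::_scaled_dot_product_flash_attention before this commit, fail
      // here instead of returning metadata.
      return 0;
    }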
@zasdfgbnm (Collaborator, Author)

!test
