Conversation

wujingyue (Collaborator) commented Sep 5, 2025

@wujingyue so I don't forget

@wujingyue wujingyue self-assigned this Sep 5, 2025

github-actions bot commented Sep 5, 2025

Description

  • Reproduce stride mismatch issue with DTensor and nvFuser

  • Add test case for non-contiguous DTensor with replication

  • Validate contiguity computation in multi-device scheduling

  • Demonstrate execution failure for stride-contiguity mismatch


Changes walkthrough 📝

Relevant files
Bug fix
repro.py — Add repro for DTensor-nvFuser stride issue

  • Added a repro script for the DTensor-nvFuser stride mismatch
  • Created a strided tensor and converted it to a replicated DTensor
  • Implemented a multidevice schedule with device-mesh setup
  • Execution fails with a stride-contiguity mismatch error
  • +107/-0
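One way a layout like the one in the error message below (sizes [2, 5, 6], strides [1, 12, 2]) can arise is by permuting a contiguous buffer. The sketch below computes strides for such a view in plain Python (no torch); the repro's actual tensor construction may differ, so treat this purely as an illustration.

```python
def permuted_strides(sizes, perm):
    """Sizes and strides of a row-major contiguous buffer of `sizes`,
    viewed through the dimension permutation `perm`.

    Illustrative only: a contiguous (5, 6, 2) buffer permuted to
    (2, 5, 6) yields strides (1, 12, 2) -- dense data, permuted order,
    matching the layout in the error message below.
    """
    # Row-major strides: innermost dimension has stride 1.
    strides = [0] * len(sizes)
    acc = 1
    for d in reversed(range(len(sizes))):
        strides[d] = acc
        acc *= sizes[d]
    # Apply the permutation to both sizes and strides.
    return [sizes[p] for p in perm], [strides[p] for p in perm]
```

For example, `permuted_strides([5, 6, 2], [2, 0, 1])` gives sizes `[2, 5, 6]` and strides `[1, 12, 2]`.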

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 No relevant tests
    ⚡ Recommended focus areas for review

    Possible Issue

    The PR introduces a repro for a stride mismatch error in nvFuser when executing with DTensor inputs. The error occurs during execution, indicating a potential issue in how contiguity and stride information are handled in the fusion definition or scheduling logic.

    # Currently this fails.
    # RuntimeError: Stride mismatch with contiguity info.  allocation domain: iS2{2}, iS0{5}, iS1{6}: sizes: [2, 5, 6]: strides: [1, 12, 2]; contiguity: f, f, t; dim: 2; expected stride: 1; actual stride: 2
    actual = nvfd.execute_with_dtensors(fd, [in_dtensor])
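The failing check can be illustrated with a small stride-consistency sketch in plain Python (this is not nvFuser's actual implementation): walking the allocation domain from innermost to outermost, every dimension flagged contiguous must have exactly the stride implied by the inner extents.

```python
def check_stride_contiguity(sizes, strides, contiguity):
    """Verify that strides agree with the declared contiguity flags.

    sizes/strides/contiguity are given in allocation order (outermost
    first), mirroring the error message above. Illustrative sketch only,
    not nvFuser's real validation code.
    """
    expected = 1
    # Walk dimensions from innermost (last) to outermost (first).
    for dim in reversed(range(len(sizes))):
        if contiguity[dim] and strides[dim] != expected:
            raise RuntimeError(
                f"Stride mismatch with contiguity info. dim: {dim}; "
                f"expected stride: {expected}; actual stride: {strides[dim]}"
            )
        # The next outer dimension is measured against the actual extent.
        expected = strides[dim] * sizes[dim]
```

With the values from the error message, sizes `[2, 5, 6]`, strides `[1, 12, 2]`, and contiguity `(f, f, t)`, this check fails at dim 2 with expected stride 1 versus actual stride 2, matching the reported error.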
    Missing Test Validation

    The PR does not include any new tests or updates to existing tests to validate the behavior or fix for the reported stride mismatch issue. This reduces confidence in the correctness and robustness of the solution.

    torch.testing.assert_close(actual[0], expected)
    Contiguity Handling

    The computation of contiguity and stride order using compute_contiguity may not correctly reflect the actual memory layout of the DTensor's local tensor, especially when the global tensor has non-contiguous strides. This could lead to incorrect fusion scheduling and execution errors.

    contiguity, stride_order = compute_contiguity(in_dtensor.shape, in_dtensor.stride())
    
    print(contiguity, "CONTIGUITY")
    print(stride_order, "STRIDE ORDER")
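For reference, contiguity and stride order can be derived from sizes and strides roughly as follows. This is a hedged pure-Python sketch of what a helper like `compute_contiguity` might do, not nvFuser's actual implementation; in particular, the `stride_order` convention used here (dimension indices sorted from slowest- to fastest-varying) is an assumption.

```python
def compute_contiguity_sketch(sizes, strides):
    """Illustrative sketch: derive per-dimension contiguity flags and a
    stride order from sizes and strides. Not nvFuser's real helper.

    A dimension counts as contiguous when its stride equals the next
    faster dimension's stride times that dimension's size (innermost: 1).
    """
    ndim = len(sizes)
    # Assumed convention: dims ordered from largest stride to smallest.
    stride_order = sorted(range(ndim), key=lambda d: strides[d], reverse=True)
    contiguity = [False] * ndim
    expected = 1
    for d in reversed(stride_order):  # fastest-varying dimension first
        contiguity[d] = strides[d] == expected
        expected = strides[d] * sizes[d]
    return contiguity, stride_order
```

Note that for the layout in the error above (sizes `[2, 5, 6]`, strides `[1, 12, 2]`), this sketch reports every dimension contiguous once dims are sorted by stride, while the fusion's allocation domain carries contiguity `(f, f, t)` -- which is exactly the kind of disagreement the repro surfaces.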
