
Conversation

@zasdfgbnm (Collaborator) commented Aug 28, 2025

In preparing for the fix of #4888, I am working on #5082, which requires using Expr::evaluate on meta tensors to infer the shapes and strides of fusion segments selected to be scheduled by the ExprEval scheduler. As a consequence of this change, all Expr::evaluate functions must support the meta device, and the returned output tensors' shapes and strides must match those on device type CUDA.

According to https://docs.pytorch.org/docs/stable/meta.html:

In some cases, not all device types (e.g., CPU and CUDA) have exactly the same output metadata for an operation; we typically prefer representing the CUDA behavior faithfully in this situation.

It is therefore generally safe to assume that we can use the meta device to infer the shapes and strides that device type CUDA would produce. Unfortunately, not all operators implement the meta device, and at::_scaled_dot_product_flash_attention is one such example.

In this PR, I add my own meta-device implementation of at::_scaled_dot_product_flash_attention.
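
For context, here is a minimal sketch (not from this PR) of how the meta device propagates output metadata without allocating storage or touching a GPU; at::matmul is used only as an illustrative op that already has a meta implementation:

    #include <ATen/ATen.h>

    // Meta tensors carry only metadata (shape, stride, dtype); no storage
    // is allocated and no kernel runs.
    at::Tensor a = at::empty({2, 3}, at::TensorOptions().device(at::kMeta));
    at::Tensor b = at::empty({3, 4}, at::TensorOptions().device(at::kMeta));

    // Ops with a meta implementation propagate output metadata:
    // c has shape {2, 4} and lives on the meta device.
    at::Tensor c = at::matmul(a, b);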

github-actions bot commented Aug 28, 2025

Review updated until commit 9c4e1cb

Description

  • Add meta device support for SDPA forward op

  • Implement meta tensor shape/stride inference

  • Extend test suite for meta device validation

  • Ensure CUDA semantics in meta implementation


Changes walkthrough 📝

Relevant files:

Enhancement: csrc/ir/nodes.cpp (Implement SDPA meta device evaluation, +43/-9)

  • Added spda_meta::_scaled_dot_product_flash_attention_meta for meta device support
  • Modified SdpaFwdOp::evaluate to dispatch to the meta implementation when the input is a meta tensor (a sketch of such a dispatch follows below)
  • Ensured output tensor shapes and strides match CUDA semantics

Tests: tests/cpp/test_sdpa_node.cpp (Add meta device validation in SDPA tests, +76/-6)

  • Introduced a MetaSdpaOut tuple for meta tensor outputs
  • Updated validateSdpaFwdOutputs to accept and compare meta tensor outputs
  • Added ExpressionEvaluator tests with meta tensors across multiple test cases
  • Verified shape and stride consistency between device and meta outputs
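
The dispatch bullet above might look roughly like the following; this is a hedged sketch, not the merged code, and the surrounding names (dropout_p, is_causal, scale) are assumptions. Both branches return the same nine-element tuple type, so the lambda is well-formed:

    // A hypothetical helper illustrating the dispatch; the merged change
    // lives inside SdpaFwdOp::evaluate and differs in its surrounding types.
    auto sdpa_forward = [&](const at::Tensor& query,
                            const at::Tensor& key,
                            const at::Tensor& value,
                            double dropout_p,
                            bool is_causal,
                            std::optional<double> scale) {
      if (query.is_meta()) {
        // Meta tensors carry no data; use the hand-written meta kernel
        // to infer output shapes/strides matching CUDA semantics.
        return spda_meta::_scaled_dot_product_flash_attention_meta(query);
      }
      return at::_scaled_dot_product_flash_attention(
          query, key, value, dropout_p, is_causal,
          /*return_debug_mask=*/false, scale);
    };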

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests

⚡ Recommended focus areas for review

Incomplete Meta Implementation

The meta implementation _scaled_dot_product_flash_attention_meta returns placeholder tensors for several outputs (e.g., cum_seq_q, cum_seq_k, philox_seed, philox_offset, debug_attn_mask) as empty tensors or zero-initialized values. This may not accurately reflect the actual CUDA behavior, potentially leading to incorrect shape or stride inference in downstream fusion segments.

    at::Tensor(),
    at::Tensor(),
    c10::SymInt(seqlen_q),
    c10::SymInt(seqlen_q),
    at::Tensor(),
    at::Tensor(),
    at::Tensor());

Missing Scale Handling

The meta implementation does not use the scale parameter that is passed to the real CUDA kernel. This could cause the meta and CUDA behaviors to diverge if scale ever affects output tensor properties, even indirectly.

    std::tuple<
        at::Tensor, at::Tensor, at::Tensor, at::Tensor,
        c10::SymInt, c10::SymInt,
        at::Tensor, at::Tensor, at::Tensor>
    _scaled_dot_product_flash_attention_meta(const at::Tensor& query) {
      const auto sizes = query.sizes();
      const int batch_size = sizes[0];
      int num_heads = sizes[1];
      int seqlen_q = sizes[2];
      auto logsumexp = at::empty(
          {batch_size, num_heads, seqlen_q}, query.options().dtype(at::kFloat));
      return std::make_tuple(
          query,
          logsumexp,
          at::Tensor(),
          at::Tensor(),
          c10::SymInt(seqlen_q),
          c10::SymInt(seqlen_q),
          at::Tensor(),
          at::Tensor(),
          at::Tensor());
    }

Limited Test Coverage for Meta Outputs

The test only validates the sizes and strides of the main output tensors (attn and log_sumexp) on the meta device. Other returned values, such as cum_seq_q, cum_seq_k, and the philox seed and offset, are not validated under meta execution, leaving potential issues undetected.

    auto [attn_meta, log_sumexp_meta] = aten_out_meta;
    EXPECT_EQ(attn.sizes(), attn_meta.sizes());
    EXPECT_EQ(log_sumexp.sizes(), log_sumexp_meta.sizes());
    EXPECT_EQ(attn.strides(), attn_meta.strides());
    EXPECT_EQ(log_sumexp.strides(), log_sumexp_meta.strides());
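
A hypothetical sketch of how that validation could be extended, assuming the test also surfaced the placeholder outputs; all _meta-suffixed names here are illustrative, not from this PR:

    // Hypothetical: compare placeholder outputs' defined-ness between the
    // CUDA run and the meta run (dense SDPA leaves cum_seq_q/cum_seq_k empty).
    EXPECT_EQ(cum_seq_q.defined(), cum_seq_q_meta.defined());
    EXPECT_EQ(cum_seq_k.defined(), cum_seq_k_meta.defined());
    if (cum_seq_q.defined() && cum_seq_q_meta.defined()) {
      EXPECT_EQ(cum_seq_q.sizes(), cum_seq_q_meta.sizes());
      EXPECT_EQ(cum_seq_q.strides(), cum_seq_q_meta.strides());
    }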

@zasdfgbnm marked this pull request as ready for review August 29, 2025 04:15
@zasdfgbnm (Collaborator, Author) commented:

!test

@zasdfgbnm requested review from jjsjann123 and naoyam August 29, 2025 04:25
@jjsjann123 (Collaborator) left a comment:

LGTM overall. Minor nitpick about the shape propagation.

    _scaled_dot_product_flash_attention_meta(const at::Tensor& query) {
    // Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
    // Query -> Query(Batch x Q_seq_len x Num_heads x Dim_per_head)
    at::Tensor q_t = query.transpose(1, 2);
@jjsjann123 commented on this diff:
Super nitpick: the PyTorch kernel implementation uses transposes because it is feeding the result to a kernel. I think we can directly retrieve the shape and call an explicit empty or empty_strided, which is probably easier to read.
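
A minimal sketch of what that suggestion might look like (not the code merged in this PR), assuming a contiguous (B, H, S_q, D) query; the flash-attention output keeps a (B, S_q, H, D) memory layout viewed logically as (B, H, S_q, D):

    const auto sizes = query.sizes();
    const int64_t B = sizes[0], H = sizes[1], S_q = sizes[2], D = sizes[3];
    // Strides of a contiguous (B, S_q, H, D) buffer, permuted to the
    // logical (B, H, S_q, D) order; equivalent to the transpose round-trip.
    at::Tensor attention = at::empty_strided(
        {B, H, S_q, D}, {H * S_q * D, D, H * D, 1}, query.options());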

@zasdfgbnm (Author) replied:

Ha! Good catch! I spent a lot of time simplifying PyTorch's implementation, but didn't realize that it just naively returns query.

@zasdfgbnm (Collaborator, Author) commented:

!test

@zasdfgbnm merged commit f41524b into main Sep 2, 2025 (47 of 51 checks passed)
@zasdfgbnm deleted the spda-meta branch September 2, 2025 20:43