Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 43 additions & 9 deletions csrc/ir/nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4890,6 +4890,39 @@ std::string SdpaFwdOp::toInlineString(int indent_size) const {
NVF_CHECK(false, "Tensor op can not be printed inline");
}

namespace spda_meta {

// Meta-device stand-in for at::_scaled_dot_product_flash_attention, used when
// the inputs are meta tensors (shape-only evaluation, no real compute).
// NOTE(review): namespace name looks like a typo for "sdpa" — rename together
// with its call site when convenient.
//
// Returns a 9-tuple mirroring the ATen op's output signature:
//   (attn_output, logsumexp, cum_seq_q, cum_seq_k, max_seqlen_q, max_seqlen_k,
//    philox_seed, philox_offset, debug_attn_mask).
// Only the first two carry meaningful shapes here:
//   - attn_output: reuses `query` directly, since the flash-attention output
//     has the same sizes/strides as the query tensor.
//   - logsumexp: [batch, num_heads, seqlen_q] float tensor.
// The remaining tensor slots are empty placeholders, and both SymInt slots are
// filled with seqlen_q — max_seqlen_k is approximated by the query sequence
// length because only `query` is available here; confirm callers do not rely
// on the key sequence length from this path.
std::tuple<
    at::Tensor,
    at::Tensor,
    at::Tensor,
    at::Tensor,
    c10::SymInt,
    c10::SymInt,
    at::Tensor,
    at::Tensor,
    at::Tensor>
_scaled_dot_product_flash_attention_meta(const at::Tensor& query) {
  const auto sizes = query.sizes();
  // sizes() elements are int64_t; plain `int` would silently narrow.
  const int64_t batch_size = sizes[0];
  const int64_t num_heads = sizes[1];
  const int64_t seqlen_q = sizes[2];
  auto logsumexp = at::empty(
      {batch_size, num_heads, seqlen_q}, query.options().dtype(at::kFloat));
  return std::make_tuple(
      query,
      logsumexp,
      at::Tensor(),
      at::Tensor(),
      c10::SymInt(seqlen_q),
      c10::SymInt(seqlen_q),
      at::Tensor(),
      at::Tensor(),
      at::Tensor());
}

} // namespace spda_meta

std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
const ExpressionEvaluator& ee,
const std::vector<PolymorphicValue>& inputs) const {
Expand Down Expand Up @@ -4961,15 +4994,16 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
key_seq_len,
philox_seed,
philox_offset,
debug_attn_mask] =
at::_scaled_dot_product_flash_attention(
query,
key,
value,
dropout_p,
is_causal,
/*return_debug_mask=*/false,
scale);
debug_attn_mask] = query.is_meta()
? spda_meta::_scaled_dot_product_flash_attention_meta(query)
: at::_scaled_dot_product_flash_attention(
query,
key,
value,
dropout_p,
is_causal,
/*return_debug_mask=*/false,
scale);

// If the inputs were padded, slice the output to restore the original
// size
Expand Down
82 changes: 76 additions & 6 deletions tests/cpp/test_sdpa_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,12 @@ using AtenSdpaOut = std::tuple<
at::Tensor,
at::Tensor,
at::Tensor>;

using MetaSdpaOut = std::tuple<at::Tensor, at::Tensor>;

auto validateSdpaFwdOutputs = [](KernelArgumentHolder nvf_out,
AtenSdpaOut aten_out) {
AtenSdpaOut aten_out,
MetaSdpaOut aten_out_meta) {
auto
[attn,
log_sumexp,
Expand All @@ -64,6 +68,12 @@ auto validateSdpaFwdOutputs = [](KernelArgumentHolder nvf_out,
// garbage values for this case, so we skip validating those values.
NVF_CHECK(at::allclose(nvf_out[0].as<at::Tensor>(), attn));
NVF_CHECK(at::allclose(nvf_out[1].as<at::Tensor>(), log_sumexp));

auto [attn_meta, log_sumexp_meta] = aten_out_meta;
EXPECT_EQ(attn.sizes(), attn_meta.sizes());
EXPECT_EQ(log_sumexp.sizes(), log_sumexp_meta.sizes());
EXPECT_EQ(attn.strides(), attn_meta.strides());
EXPECT_EQ(log_sumexp.strides(), log_sumexp_meta.strides());
};

// Check SDPAFwdOp mapping in IdModel and ComputeAtMap.
Expand Down Expand Up @@ -258,7 +268,17 @@ TEST_F(SDPATest, NonCausalAttnConcrete) {

FusionExecutorCache executor_cache(std::move(fusion));
auto nvf_out = executor_cache.runFusionWithInputs({q, k, v});
validateSdpaFwdOutputs(nvf_out, aten_out);

ExpressionEvaluator ee;
ee.bind(executor_cache.fusion()->inputs().at(0), q.to(at::kMeta));
ee.bind(executor_cache.fusion()->inputs().at(1), k.to(at::kMeta));
ee.bind(executor_cache.fusion()->inputs().at(2), v.to(at::kMeta));
auto a = ee.evaluate(executor_cache.fusion()->outputs().at(0));
MetaSdpaOut aten_out_meta = {
ee.evaluate(executor_cache.fusion()->outputs().at(0)).as<at::Tensor>(),
ee.evaluate(executor_cache.fusion()->outputs().at(1)).as<at::Tensor>(),
};
validateSdpaFwdOutputs(nvf_out, aten_out, aten_out_meta);
}

TEST_F(SDPATest, NonCausalAttnSymbolic) {
Expand Down Expand Up @@ -305,7 +325,16 @@ TEST_F(SDPATest, NonCausalAttnSymbolic) {

FusionExecutorCache executor_cache(std::move(fusion));
auto nvf_out = executor_cache.runFusionWithInputs({q, k, v});
validateSdpaFwdOutputs(nvf_out, aten_out);

ExpressionEvaluator ee;
ee.bind(executor_cache.fusion()->inputs().at(0), q.to(at::kMeta));
ee.bind(executor_cache.fusion()->inputs().at(1), k.to(at::kMeta));
ee.bind(executor_cache.fusion()->inputs().at(2), v.to(at::kMeta));
MetaSdpaOut aten_out_meta = {
ee.evaluate(executor_cache.fusion()->outputs().at(0)).as<at::Tensor>(),
ee.evaluate(executor_cache.fusion()->outputs().at(1)).as<at::Tensor>(),
};
validateSdpaFwdOutputs(nvf_out, aten_out, aten_out_meta);
}

TEST_F(SDPATest, CausalAttn) {
Expand Down Expand Up @@ -351,7 +380,16 @@ TEST_F(SDPATest, CausalAttn) {

FusionExecutorCache executor_cache(std::move(fusion));
auto nvf_out = executor_cache.runFusionWithInputs({q, k, v});
validateSdpaFwdOutputs(nvf_out, aten_out);

ExpressionEvaluator ee;
ee.bind(executor_cache.fusion()->inputs().at(0), q.to(at::kMeta));
ee.bind(executor_cache.fusion()->inputs().at(1), k.to(at::kMeta));
ee.bind(executor_cache.fusion()->inputs().at(2), v.to(at::kMeta));
MetaSdpaOut aten_out_meta = {
ee.evaluate(executor_cache.fusion()->outputs().at(0)).as<at::Tensor>(),
ee.evaluate(executor_cache.fusion()->outputs().at(1)).as<at::Tensor>(),
};
validateSdpaFwdOutputs(nvf_out, aten_out, aten_out_meta);
}

TEST_F(SDPATest, PairwiseLogicalDomainMap) {
Expand Down Expand Up @@ -828,7 +866,23 @@ TEST_F(SDPATest, Sharded_SdpaFwd) {
FusionExecutorCache executor_cache(std::move(fusion));
auto nvf_out = executor_cache.runFusionWithInputs(
{q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)});
validateSdpaFwdOutputs(nvf_out, aten_out);

ExpressionEvaluator ee;
ee.bind(
executor_cache.fusion()->inputs().at(0), q.to(at::kMeta).unsqueeze(0));
ee.bind(
executor_cache.fusion()->inputs().at(1), k.to(at::kMeta).unsqueeze(0));
ee.bind(
executor_cache.fusion()->inputs().at(2), v.to(at::kMeta).unsqueeze(0));
MetaSdpaOut aten_out_meta = {
ee.evaluate(executor_cache.fusion()->outputs().at(0))
.as<at::Tensor>()
.squeeze(0),
ee.evaluate(executor_cache.fusion()->outputs().at(1))
.as<at::Tensor>()
.squeeze(0),
};
validateSdpaFwdOutputs(nvf_out, aten_out, aten_out_meta);
}

// TODO: Remove/update when https://github.com/NVIDIA/Fuser/issues/2563 is
Expand Down Expand Up @@ -1019,7 +1073,23 @@ TEST_F(SDPATest, ComputeAt) {
FusionExecutorCache executor_cache(std::move(fusion));
auto nvf_out = executor_cache.runFusionWithInputs(
{q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)});
validateSdpaFwdOutputs(nvf_out, aten_out);

ExpressionEvaluator ee;
ee.bind(
executor_cache.fusion()->inputs().at(0), q.to(at::kMeta).unsqueeze(0));
ee.bind(
executor_cache.fusion()->inputs().at(1), k.to(at::kMeta).unsqueeze(0));
ee.bind(
executor_cache.fusion()->inputs().at(2), v.to(at::kMeta).unsqueeze(0));
MetaSdpaOut aten_out_meta = {
ee.evaluate(executor_cache.fusion()->outputs().at(0))
.as<at::Tensor>()
.squeeze(0),
ee.evaluate(executor_cache.fusion()->outputs().at(1))
.as<at::Tensor>()
.squeeze(0),
};
validateSdpaFwdOutputs(nvf_out, aten_out, aten_out_meta);
}

} // namespace nvfuser