|
494 | 494 | py::arg("rng_state") = std::nullopt, \ |
495 | 495 | py::arg("gen") = std::nullopt); |
496 | 496 |
|
497 | | -#define MHA_VARLEN_BWD_ASM_PYBIND \ |
498 | | - m.def("fmha_v3_varlen_bwd", \ |
499 | | - &aiter::torch_itfs::fmha_v3_varlen_bwd, \ |
500 | | - py::arg("dout"), \ |
501 | | - py::arg("q"), \ |
502 | | - py::arg("k"), \ |
503 | | - py::arg("v"), \ |
504 | | - py::arg("out"), \ |
505 | | - py::arg("softmax_lse"), \ |
506 | | - py::arg("cu_seqlens_q"), \ |
507 | | - py::arg("cu_seqlens_k"), \ |
508 | | - py::arg("max_seqlen_q"), \ |
509 | | - py::arg("max_seqlen_k"), \ |
510 | | - py::arg("dropout_p"), \ |
511 | | - py::arg("softmax_scale"), \ |
512 | | - py::arg("zero_tensors"), \ |
513 | | - py::arg("is_causal"), \ |
514 | | - py::arg("window_size_left"), \ |
515 | | - py::arg("window_size_right"), \ |
516 | | - py::arg("deterministic"), \ |
517 | | - py::arg("is_v3_atomic_fp32"), \ |
518 | | - py::arg("how_v3_bf16_cvt"), \ |
519 | | - py::arg("dq") = std::nullopt, \ |
520 | | - py::arg("dk") = std::nullopt, \ |
521 | | - py::arg("dv") = std::nullopt, \ |
522 | | - py::arg("alibi_slopes") = std::nullopt, \ |
523 | | - py::arg("rng_state") = std::nullopt, \ |
524 | | - py::arg("gen") = std::nullopt); |
| 497 | +#define MHA_VARLEN_BWD_ASM_PYBIND \ |
| 498 | + m.def("fmha_v3_varlen_bwd", \ |
| 499 | + &aiter::torch_itfs::fmha_v3_varlen_bwd, \ |
| 500 | + py::arg("dout"), \ |
| 501 | + py::arg("q"), \ |
| 502 | + py::arg("k"), \ |
| 503 | + py::arg("v"), \ |
| 504 | + py::arg("out"), \ |
| 505 | + py::arg("softmax_lse"), \ |
| 506 | + py::arg("cu_seqlens_q"), \ |
| 507 | + py::arg("cu_seqlens_k"), \ |
| 508 | + py::arg("max_seqlen_q"), \ |
| 509 | + py::arg("max_seqlen_k"), \ |
| 510 | + py::arg("dropout_p"), \ |
| 511 | + py::arg("softmax_scale"), \ |
| 512 | + py::arg("zero_tensors"), \ |
| 513 | + py::arg("is_causal"), \ |
| 514 | + py::arg("window_size_left"), \ |
| 515 | + py::arg("window_size_right"), \ |
| 516 | + py::arg("deterministic"), \ |
| 517 | + py::arg("is_v3_atomic_fp32"), \ |
| 518 | + py::arg("how_v3_bf16_cvt"), \ |
| 519 | + py::arg("dq") = std::nullopt, \ |
| 520 | + py::arg("dk") = std::nullopt, \ |
| 521 | + py::arg("dv") = std::nullopt, \ |
| 522 | + py::arg("alibi_slopes") = std::nullopt, \ |
| 523 | + py::arg("rng_state") = std::nullopt, \ |
| 524 | + py::arg("gen") = std::nullopt, \ |
| 525 | + py::arg("cu_seqlens_q_padded") = std::nullopt, \ |
| 526 | + py::arg("cu_seqlens_k_padded") = std::nullopt); |
525 | 527 |
|
526 | 528 | #define MHA_BWD_PYBIND \ |
527 | 529 | m.def("mha_bwd", \ |
|
612 | 614 | py::arg("alibi_slopes") = std::nullopt, \ |
613 | 615 | py::arg("gen") = std::nullopt); |
614 | 616 |
|
615 | | -#define MHA_VARLEN_BWD_PYBIND \ |
616 | | - m.def("mha_varlen_bwd", \ |
617 | | - &aiter::torch_itfs::mha_varlen_bwd, \ |
618 | | - py::arg("dout"), \ |
619 | | - py::arg("q"), \ |
620 | | - py::arg("k"), \ |
621 | | - py::arg("v"), \ |
622 | | - py::arg("out"), \ |
623 | | - py::arg("softmax_lse"), \ |
624 | | - py::arg("cu_seqlens_q"), \ |
625 | | - py::arg("cu_seqlens_k"), \ |
626 | | - py::arg("max_seqlen_q"), \ |
627 | | - py::arg("max_seqlen_k"), \ |
628 | | - py::arg("dropout_p"), \ |
629 | | - py::arg("softmax_scale"), \ |
630 | | - py::arg("zero_tensors"), \ |
631 | | - py::arg("is_causal"), \ |
632 | | - py::arg("window_size_left"), \ |
633 | | - py::arg("window_size_right"), \ |
634 | | - py::arg("deterministic"), \ |
635 | | - py::arg("dq") = std::nullopt, \ |
636 | | - py::arg("dk") = std::nullopt, \ |
637 | | - py::arg("dv") = std::nullopt, \ |
638 | | - py::arg("alibi_slopes") = std::nullopt, \ |
639 | | - py::arg("rng_state") = std::nullopt, \ |
640 | | - py::arg("gen") = std::nullopt); |
| 617 | +#define MHA_VARLEN_BWD_PYBIND \ |
| 618 | + m.def("mha_varlen_bwd", \ |
| 619 | + &aiter::torch_itfs::mha_varlen_bwd, \ |
| 620 | + py::arg("dout"), \ |
| 621 | + py::arg("q"), \ |
| 622 | + py::arg("k"), \ |
| 623 | + py::arg("v"), \ |
| 624 | + py::arg("out"), \ |
| 625 | + py::arg("softmax_lse"), \ |
| 626 | + py::arg("cu_seqlens_q"), \ |
| 627 | + py::arg("cu_seqlens_k"), \ |
| 628 | + py::arg("max_seqlen_q"), \ |
| 629 | + py::arg("max_seqlen_k"), \ |
| 630 | + py::arg("dropout_p"), \ |
| 631 | + py::arg("softmax_scale"), \ |
| 632 | + py::arg("zero_tensors"), \ |
| 633 | + py::arg("is_causal"), \ |
| 634 | + py::arg("window_size_left"), \ |
| 635 | + py::arg("window_size_right"), \ |
| 636 | + py::arg("deterministic"), \ |
| 637 | + py::arg("dq") = std::nullopt, \ |
| 638 | + py::arg("dk") = std::nullopt, \ |
| 639 | + py::arg("dv") = std::nullopt, \ |
| 640 | + py::arg("alibi_slopes") = std::nullopt, \ |
| 641 | + py::arg("rng_state") = std::nullopt, \ |
| 642 | + py::arg("gen") = std::nullopt, \ |
| 643 | + py::arg("cu_seqlens_q_padded") = std::nullopt, \ |
| 644 | + py::arg("cu_seqlens_k_padded") = std::nullopt); |
641 | 645 |
|
642 | 646 | #define MOE_CK_2STAGES_PYBIND \ |
643 | 647 | m.def("ck_moe_stage1", \ |
|
0 commit comments