Skip to content

Conversation

@rdspring1 rdspring1 added Python API Issues related to the Python API Direct Bindings Python extension with direct mapping to NvFuser CPP objects. labels Jul 28, 2025
Copy link

github-actions bot commented Jul 28, 2025

Review updated until commit febeff6

Description

  • Added scatter operation in direct bindings with error handling

  • Mapped ScatterOp to Python frontend for direct binding usage

  • Enabled scatter operation testing with supports_direct_bindings=True


Changes walkthrough 📝

Relevant files
Enhancement
ops.cpp
Add scatter function with error handling                                 

python/python_direct/ops.cpp

  • Added scatter function to bindIndexingOps with parameters arg1, index,
    src, and dim
  • Implemented dimension count validation via NVF_CHECK
  • Added dimension range validation for dim parameter
  • +69/-20 
    python_translate.cpp
    Map ScatterOp to Python frontend                                                 

    python/python_direct/python_translate.cpp

  • Added handle method for ScatterOp class
  • Connected C++ ScatterOp to Python frontend via fd.ops.scatter
  • Used generateKwargsOperation for Python binding generation
  • +14/-0   
    Tests
    opinfos.py
    Enable scatter operation testing                                                 

    tests/python/opinfo/opinfos.py

  • Added scatter_opinfo to shape_ops list
  • Enabled direct bindings support for scatter tests
  • Updated scatter_wrapper with symbolic parameter declarations
  • +1/-0     

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 No relevant tests
    ⚡ Recommended focus areas for review

    Possible Issue

    The error message in NVF_CHECK for dimension validation uses incorrect format specifiers when printing tensor dimensions. The variables arg1->nDims(), index->nDims(), and src->nDims() are of type size_t, but the message uses %d which may cause undefined behavior on 64-bit systems.

    NVF_CHECK(
        arg1->nDims() == index->nDims() && arg1->nDims() == src->nDims(),
        "Tensor arguments have different dimensions ",
        arg1->nDims(),
        ", ",
        index->nDims(),
        " and ",
        src->nDims());
    auto num_dims = (int64_t)arg1->nDims();
    NVF_CHECK(
        dim >= -num_dims && dim < num_dims,
        "Tensor arguments have dimension ",
        num_dims,
        " so dim argument must satisfy ",
        -num_dims,
        " <= dim < ",
        num_dims,
        ", but received ",
        dim);
    Missing Test

    The PR adds a new scatter operation but does not include corresponding test cases to validate its correctness or performance. The opinfo in tests/python/opinfo/opinfos.py only enables direct bindings without concrete test implementations.

      ops.def(
          "scatter",
          [](TensorView* arg1, TensorView* index, TensorView* src, int64_t dim)
              -> TensorView* {
            NVF_CHECK(
                arg1->nDims() == index->nDims() && arg1->nDims() == src->nDims(),
                "Tensor arguments have different dimensions ",
                arg1->nDims(),
                ", ",
                index->nDims(),
                " and ",
                src->nDims());
            auto num_dims = (int64_t)arg1->nDims();
            NVF_CHECK(
                dim >= -num_dims && dim < num_dims,
                "Tensor arguments have dimension ",
                num_dims,
                " so dim argument must satisfy ",
                -num_dims,
                " <= dim < ",
                num_dims,
                ", but received ",
                dim);
            return scatter(arg1, dim, index, src);
          },
          py::arg("arg1"),
          py::arg("index"),
          py::arg("src"),
          py::arg("dim"),
          R"(
    Scatter a tensor.
    
    Parameters
    ----------
    arg1 : TensorView
        The tensor to scatter into.
    index : TensorView
        The tensor containing the indices.
    src : TensorView
        The source tensor to scatter from.
    dim : int
        The dimension to scatter along.
    
    Returns
    -------
    TensorView
        The scattered tensor.
    )",
          py::return_value_policy::reference);

    @rdspring1 rdspring1 marked this pull request as ready for review July 28, 2025 23:47
    @rdspring1 rdspring1 requested a review from jjsjann123 July 29, 2025 16:37
    rdspring1 added a commit that referenced this pull request Jul 30, 2025
    rdspring1 added a commit that referenced this pull request Jul 30, 2025
    rdspring1 added a commit that referenced this pull request Jul 31, 2025
    Copy link
    Collaborator

    @jjsjann123 jjsjann123 left a comment

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    LGTM overall.

    ArgumentType.Symbolic,
    ArgumentType.Constant,
    ),
    supports_direct_bindings=True,
    Copy link
    Collaborator

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    nitpick: we don't have a functional scatter at this moment. @naoyam is adding limited support, which might not handle all the scatter cases in opinfo. #4742

    rdspring1 added a commit that referenced this pull request Aug 3, 2025
    rdspring1 added a commit that referenced this pull request Aug 3, 2025
    rdspring1 added a commit that referenced this pull request Aug 3, 2025
    rdspring1 added a commit that referenced this pull request Aug 3, 2025
    rdspring1 added a commit that referenced this pull request Aug 4, 2025
    Base automatically changed from direct_tp19 to main August 4, 2025 00:51
    @rdspring1
    Copy link
    Collaborator Author

    !test

    @rdspring1 rdspring1 merged commit 517b38e into main Aug 4, 2025
    54 of 55 checks passed
    @rdspring1 rdspring1 deleted the direct_tp20 branch August 4, 2025 05:44
    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

    Labels

    Direct Bindings Python extension with direct mapping to NvFuser CPP objects. Python API Issues related to the Python API

    Projects

    None yet

    Development

    Successfully merging this pull request may close these issues.

    2 participants