
@rdspring1 rdspring1 added the labels Python API (Issues related to the Python API) and Direct Bindings (Python extension with direct mapping to NvFuser CPP objects) Jul 28, 2025
@github-actions

Description

  • Added slice operation to direct bindings

  • Implemented slice_fn in ops.cpp

  • Mapped SliceOp to Python frontend in python_translate.cpp

  • Added tests for slice operation in test_python_frontend.py

  • Updated opinfos.py to support direct bindings for slice


Changes walkthrough 📝

Relevant files

Enhancement

  • ops.cpp (python/python_direct/ops.cpp, +107/-0): Implement and bind slice operation
      • Implemented slice_fn template function
      • Added bindings for slice operation in bindMetadataOps

  • python_translate.cpp (python/python_direct/python_translate.cpp, +35/-0): Map SliceOp to Python frontend
      • Mapped SliceOp to Python frontend

  • opinfos.py (tests/python/opinfo/opinfos.py, +1/-0): Update slice_opinfo for direct bindings
      • Updated slice_opinfo to support direct bindings

Tests

  • test_python_frontend.py (tests/python/direct/test_python_frontend.py, +33/-0): Add slice operation tests
      • Added test cases for slice operation

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Stride Limitation

    The implementation accepts only constant strides equal to 1: the final NVF_CHECK below rejects any stride that is non-constant or differs from 1. This limitation should be documented and considered for a future enhancement.

        start_idx->evaluate().as<int64_t>());
    NVF_CHECK(
        !start_idx->isConstInt() || !end_idx->isConstInt() ||
            end_idx->evaluate().as<int64_t>() >=
                start_idx->evaluate().as<int64_t>(),
        "Slice operation end_indices must be greater than or equal to "
        "start_indices. Start Indices: ",
        start_idx->evaluate().as<int64_t>(),
        " End Indices: ",
        end_idx->evaluate().as<int64_t>());
    NVF_CHECK(
        stride_idx->isConstInt() && stride_idx->evaluate().as<int64_t>() == 1,
        "nvFuser Limitation: All slice operation strides must be of const "
        "size 1.");
    vec_slice.push_back({start_idx, end_idx, stride_idx});
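
The two checks that are fully visible in the excerpt above can be sketched in plain Python to make their conditions easier to read. This is only an illustration, not nvFuser code: `check_slice_args` is a hypothetical helper, `None` is an assumed stand-in for a symbolic (non-constant) value, and the check that is truncated at the top of the excerpt is deliberately left out.

```python
from typing import Optional, Sequence


def check_slice_args(
    start: Sequence[Optional[int]],
    end: Sequence[Optional[int]],
    strides: Sequence[Optional[int]],
) -> None:
    """Hypothetical Python mirror of the two checks visible in the excerpt.

    None stands in for a symbolic (non-constant) value.
    """
    for s, e, st in zip(start, end, strides):
        # The ordering check only applies when both bounds are constants,
        # mirroring the isConstInt() guards in the C++ condition.
        if s is not None and e is not None and e < s:
            raise ValueError(f"end index {e} must be >= start index {s}")
        # The stride must be a known constant equal to 1; a symbolic
        # stride (None here) is rejected as well.
        if st != 1:
            raise ValueError("all slice strides must be the constant 1")


# Constant bounds with unit strides pass.
check_slice_args((0, 1, 2), (2, 5, 10), (1, 1, 1))
# Symbolic bounds skip the ordering check but still need unit strides.
check_slice_args((None, 1), (3, None), (1, 1))
```

Under these assumptions, a stride of 2 or an unknown stride raises, while an unknown start or end index is accepted without an ordering check.
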
    Test Coverage

    The tests cover slicing with static indices, indices derived from the tensor shape, and dynamic index vectors, but all three cases assert the same result. Additional cases for edge conditions and error handling would strengthen coverage.

    x = torch.randn((2, 5, 10), dtype=torch.float32, device="cuda:0")
    
    offset = (0, 1, 2)
    
    def fusion_func(fd: FusionDefinition) -> None:
        T0 = fd.define_tensor(
            shape=[-1, -1, -1],
            contiguity=[True, True, True],
            dtype=DataType.Float,
            is_cpu=False,
            stride_order=[2, 1, 0],
        )
        T1 = fd.ops.slice(
            T0, start_indices=offset, end_indices=(2, 5, 10), strides=(1, 1, 1)
        )
        fd.add_output(T1)
        V_start = list(offset)
        V_end = T0.shape()
        T2 = fd.ops.slice(T0, V_start, V_end)
        fd.add_output(T2)
        dynamic_start = fd.define_vector(3)
        dynamic_end = fd.define_vector(3)
        T3 = fd.ops.slice(T0, dynamic_start, dynamic_end)
        fd.add_output(T3)
    
    inputs = [x, *offset, *x.shape]
    
    nvf_out, _ = nvfuser_direct_test.exec_nvfuser(fusion_func, inputs)
    for out in nvf_out:
        nvfuser_direct_test.assertTrue(out.allclose(x[:, 1:, 2:]))
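
As a rough aid for reading the test above, the following pure-Python sketch shows why start indices (0, 1, 2) and end indices (2, 5, 10) correspond to `x[:, 1:, 2:]` on a (2, 5, 10) tensor, and what an edge case might look like. The helpers `to_python_slices` and `sliced_shape` are hypothetical and assume stride 1; they are not part of nvFuser or the PR.

```python
from typing import Sequence, Tuple


def to_python_slices(start: Sequence[int], end: Sequence[int]) -> Tuple[slice, ...]:
    """Build one stride-1 Python slice per dimension (hypothetical helper)."""
    return tuple(slice(s, e) for s, e in zip(start, end))


def sliced_shape(start: Sequence[int], end: Sequence[int]) -> Tuple[int, ...]:
    """Output extent per dimension for stride 1: end - start."""
    return tuple(e - s for s, e in zip(start, end))


shape = (2, 5, 10)
offset = (0, 1, 2)

# Same bounds as the test: x[0:2, 1:5, 2:10], which equals x[:, 1:, 2:].
print(to_python_slices(offset, shape))
print(sliced_shape(offset, shape))  # (2, 4, 8)

# A possible edge case: start == end passes the "end >= start" check
# and would give an empty extent in that dimension.
print(sliced_shape((0, 5, 2), shape))  # (2, 0, 8)

# The dynamic vectors receive their scalars flattened after the tensor,
# matching inputs = [x, *offset, *x.shape] in the test.
scalar_inputs = [*offset, *shape]
print(scalar_inputs)  # [0, 1, 2, 2, 5, 10]
```

Whether nvFuser handles the empty-extent case correctly is not shown in this PR; the sketch only indicates which inputs such a test might use.
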
    Manual Normalization

    The manual_normalization parameter is set to true in PythonTranslator::handle(const SliceOp* sop). This should be reviewed to ensure it aligns with the intended behavior and is correctly documented.

          isinf,
          R"(
    Element-wise infinity check.
    
    Parameters
    ----------
    x : Val or TensorView

    @rdspring1
    Collaborator Author

    !test

    @rdspring1 rdspring1 marked this pull request as ready for review July 28, 2025 23:47
    @rdspring1
    Collaborator Author

    !build

    Collaborator

    @jjsjann123 jjsjann123 left a comment

    LGTM

    error_input_generator=slice_error_generator,
    reference=jax.lax.slice if JAX_AVAILABLE else None,
    reference_type=ReferenceType.Jax,
    supports_direct_bindings=True,
    Collaborator

    Neat~

    stride_order=[2, 1, 0],
    )
    T1 = fd.ops.slice(
    T0, start_indices=offset, end_indices=(2, 5, 10), strides=(1, 1, 1)
    Collaborator

    nitpick: felt a bit strange that we have offset but not (2, 5, 10) as a named variable.

    Collaborator Author

    offset is used for V_start.

    rdspring1 added 2 commits that referenced this pull request Jul 30, 2025
    Base automatically changed from direct_tp13 to main July 30, 2025 22:12
    @rdspring1
    Collaborator Author

    !build

    @rdspring1 rdspring1 merged commit 186a943 into main Jul 31, 2025
    16 of 17 checks passed
    @rdspring1 rdspring1 deleted the direct_tp14 branch July 31, 2025 01:47
    rdspring1 added 4 commits that referenced this pull request Aug 3, 2025
    rdspring1 added 2 commits that referenced this pull request Aug 4, 2025


    2 participants