@rdspring1 added the Python API (Issues related to the Python API) and Direct Bindings (Python extension with direct mapping to NvFuser CPP objects) labels on Jul 28, 2025

github-actions bot commented Jul 28, 2025

Review updated until commit 1059382

Description

  • Added full operation to Python direct bindings

  • Implemented shape handling for list/tuple inputs

  • Added test case for scalar-only fusion scenarios


Changes walkthrough 📝

Relevant files

Enhancement

ops.cpp: Implement full op with shape flexibility
python/python_direct/ops.cpp (+40/-0)

  • Added template function full_op_fn for shape conversion
  • Bound full operation with list/tuple shape support
  • Added docstring for full operation

python_translate.cpp: Map FullOp to Python frontend
python/python_direct/python_translate.cpp (+19/-0)

  • Added FullOp handler for Python frontend mapping
  • Dispatched fill value creation for dynamic support
  • Generated kwargs for shape, fill_value, and dtype

opinfos.py: Enable direct bindings for full op
tests/python/opinfo/opinfos.py (+1/-0)

  • Updated full_opinfo to support direct bindings
  • Added supports_direct_bindings=True flag

Tests

test_python_frontend.py: Test full op with scalar inputs
tests/python/direct/test_python_frontend.py (+20/-0)

  • Added test case test_scalar_only_inputs
  • Verified full op with scalar addition in fusion
  • Compared NVFuser output with PyTorch equivalent

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Shape Handling

    The SequenceAsVector function converts Python sequences to a vector of Val*. Ensure it properly handles edge cases like empty shapes, non-integer elements, or mixed types (e.g., tuple of lists). Testing for dynamic shapes (e.g., symbolic dimensions) is also critical.

    template <class ShapeType>
    TensorView* full_op_fn(
        ShapeType generic_output_shape,
        Val* fill_value,
        PrimDataType dtype) {
      std::vector<Val*> output_shape = SequenceAsVector(generic_output_shape);
      return full(output_shape, fill_value, dtype);
    }
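
The edge cases listed above (empty shapes, non-integer elements, nested sequences, negative dimensions) can be made concrete with a small pure-Python sketch. The helper name normalize_shape is hypothetical and is not part of nvFuser; it only models the checks a list/tuple shape might go through for concrete integer dimensions, and it does not cover symbolic (Val*) dimensions, which the real binding also has to accept.

```python
from numbers import Integral
from typing import List, Sequence


def normalize_shape(shape: Sequence[int]) -> List[int]:
    """Hypothetical helper: validate a list/tuple shape and return a list of ints."""
    if not isinstance(shape, (list, tuple)):
        raise TypeError(f"shape must be a list or tuple, got {type(shape).__name__}")
    dims = []
    for i, dim in enumerate(shape):
        # bool is a subclass of int in Python, so reject it explicitly.
        if isinstance(dim, bool) or not isinstance(dim, Integral):
            raise TypeError(f"shape[{i}] must be an integer, got {type(dim).__name__}")
        if dim < 0:
            raise ValueError(f"shape[{i}] must be non-negative, got {dim}")
        dims.append(int(dim))
    return dims
```

Under these assumptions, normalize_shape([2, 2]) and normalize_shape((2, 2)) both return [2, 2], an empty shape returns an empty list, a tuple of lists such as ([2], [2]) raises TypeError, and [2, -1] raises ValueError. Whether the real binding should reject or accept each of these cases is a decision for the PR, not something this sketch settles.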

    Fill Value Validity

    The dispatch(fop->getFillValue()) call may not validate dynamic fill values (e.g., symbolic scalars). Confirm that all valid input types for fill_value are supported and errors are handled gracefully.

    void handle(const FullOp* fop) final {
      NVF_ERROR(fop != nullptr);
      TensorView* out_tv = fop->output(0)->as<TensorView>();
      visited_vals_.insert(out_tv);
    
      // Fill value can be dynamic so create it
      dispatch(fop->getFillValue());
    
      static const std::vector<std::string> argument_names = {
          "shape", "fill_value", "dtype"};
      printer_.generateKwargsOperation(
          "fd.ops.full",
          std::make_tuple(),
          argument_names,
          std::make_tuple(getShape(out_tv), fop->getFillValue(), out_tv->dtype()),
          {out_tv});
    }
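
To illustrate what the handler above is expected to emit, here is a minimal pure-Python sketch of a keyword-argument printer. The function format_kwargs_op is a hypothetical stand-in, not nvFuser's generateKwargsOperation; the value strings (t3, c0, DataType.Float) are taken from the test in this PR.

```python
from typing import Sequence


def format_kwargs_op(
    op_name: str, output: str, names: Sequence[str], values: Sequence[str]
) -> str:
    """Hypothetical stand-in for a kwargs printer: emit `out = op(name=value, ...)`."""
    if len(names) != len(values):
        raise ValueError("each keyword name needs exactly one value")
    kwargs = ", ".join(f"{name}={value}" for name, value in zip(names, values))
    return f"{output} = {op_name}({kwargs})"


line = format_kwargs_op(
    "fd.ops.full",
    "t3",
    ["shape", "fill_value", "dtype"],
    ["[2, 2]", "c0", "DataType.Float"],
)
print(line)  # t3 = fd.ops.full(shape=[2, 2], fill_value=c0, dtype=DataType.Float)
```

The printed line refers to c0 by name, which is why the fill value has to be emitted before the full call: if its definition were not dispatched first, the generated Python would reference an undefined variable.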

    Test Coverage

    The test test_scalar_only_inputs uses a fixed shape [2,2]. Consider adding tests for dynamic shapes, varying dtypes, and invalid inputs (e.g., negative dimensions) to ensure robustness.

    def test_scalar_only_inputs(nvfuser_direct_test):
        # We don't allow scalar outputs, currently,
        # so a tensor has to be returned
        def fusion_func(fd: FusionDefinition):
            s0 = fd.define_scalar()
            s1 = fd.define_scalar()
            s2 = fd.ops.add(s0, s1)
            c0 = fd.define_scalar(1.0, DataType.Float)
            t3 = fd.ops.full(shape=[2, 2], fill_value=c0, dtype=DataType.Float)
            t4 = fd.ops.mul(t3, s2)
            fd.add_output(t4)
    
        with FusionDefinition() as fd:
            fusion_func(fd)
    
        nvf_out, _ = nvfuser_direct_test.exec_nvfuser(fusion_func, [2.0, 3.0])
        eager_out = torch.full([2, 2], 1.0) * 5.0
        nvfuser_direct_test.assertEqual(eager_out, nvf_out[0])
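
When extending this test to other shapes, a dependency-free reference model can help pin down the expected values before comparing against PyTorch. The sketch below is an assumption-laden illustration: reference_full and expected_output are hypothetical names, and nested Python lists stand in for tensors.

```python
from typing import Any, Sequence


def reference_full(shape: Sequence[int], fill_value: float) -> Any:
    """Build a nested list of the given shape with every element set to fill_value."""
    if len(shape) == 0:
        # A zero-dimensional shape degenerates to the scalar itself.
        return fill_value
    return [reference_full(shape[1:], fill_value) for _ in range(shape[0])]


def expected_output(shape: Sequence[int], fill_value: float, s0: float, s1: float) -> Any:
    """Mirror the fusion in the test: full(shape, fill_value) * (s0 + s1)."""
    return reference_full(shape, fill_value * (s0 + s1))


# The case exercised by test_scalar_only_inputs: inputs 2.0 and 3.0, fill value 1.0.
print(expected_output([2, 2], 1.0, 2.0, 3.0))  # [[5.0, 5.0], [5.0, 5.0]]
```

The same helper gives expected values for shapes the current test does not cover, such as (3,) or a shape with a zero-sized dimension like [0, 2], which yields an empty list.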

    Collaborator

    @jjsjann123 left a comment

    LGTM with minor nitpick on comment.

    @rdspring1
    Collaborator Author

    !build

    rdspring1 added a commit that referenced this pull request Aug 3, 2025
    Base automatically changed from direct_tp15 to main August 3, 2025 19:52
    @rdspring1
    Collaborator Author

    !build

    @rdspring1 merged commit 2b38ad4 into main on Aug 3, 2025
    17 checks passed
    @rdspring1 rdspring1 deleted the direct_tp16 branch August 3, 2025 20:46
    rdspring1 added a commit that referenced this pull request Aug 3, 2025
    rdspring1 added a commit that referenced this pull request Aug 3, 2025
    rdspring1 added a commit that referenced this pull request Aug 4, 2025
    rdspring1 added a commit that referenced this pull request Aug 4, 2025