
Conversation

@jjsjann123 (Collaborator) commented Oct 23, 2025

Array<__e2m1, size> requires size % 2 == 0 since it's a packed dtype.

The existing vectorization heuristics in the pointwise scheduler didn't take this into account when capping the vectorization_factor. This PR is just a quick patch that adds a lower bound of 2 when a packed dtype is seen on vectorized inputs/outputs.
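
To make the constraint concrete, here is a toy, self-contained illustration of why the packed layout forces an even length (this is not nvFuser's actual runtime Array type, just a sketch of the same idea):

    #include <cstdint>

    // Toy stand-in for a byte-packed fp4 array: two 4-bit e2m1 values share
    // one byte, so the element count must be even to map onto whole bytes.
    template <int size>
    struct PackedE2m1Array {
      static_assert(size % 2 == 0, "packed e2m1 arrays need an even length");
      std::uint8_t data[size / 2];
    };

    // PackedE2m1Array<4> ok;   // 4 fp4 elements -> 2 bytes
    // PackedE2m1Array<3> bad;  // fails to compile: odd length cannot be packed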

github-actions bot commented Oct 23, 2025

Review updated until commit d88ee26

Description

  • Fix vectorization factor for sub-byte dtypes like fp4

  • Ensure vectorization factor is at least 2 for packed types

  • Add test cases for fp4 vectorization in pointwise scheduler

  • Handle fp4 dtype comparison in test utilities via int8 view


Changes walkthrough 📝

Relevant files

Bug fix: csrc/scheduler/pointwise.cpp (+7/-1)
Enforce min vectorization factor for sub-byte dtypes

  • Detect sub-byte dtypes in vectorizable inputs/outputs
  • Set minimum vectorization factor to 2 if a sub-byte type is present
  • Prevent illegal array creation with sub-byte length
  • Update heuristics to respect packed dtype alignment

Tests: tests/python/direct/test_narrow_precision.py (+33/-0)
Add fp4 vectorization test cases

  • Add test for fp4 vectorization in pointwise fusion
  • Test with bfloat16 and float input dtypes
  • Validate correct output shape and values
  • Skip test on non-Blackwell architectures

Tests: tests/python/direct_utils/utils.py (+8/-2)
Support fp4 tensor comparison in test utils

  • Add support for comparing float4_e2m1fn_x2 tensors
  • View fp4 tensors as int8 for comparison
  • Preserve complex dtype handling logic
  • Avoid float64 cast for packed fp4 types

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Potential Incorrect Vectorization Bound

The condition has_sub_byte ? std::max(2l, max_vect_factor) : max_vect_factor may not correctly enforce the requirement that the vectorization factor be even for packed sub-byte types like __e2m1. Using std::max(2l, ...) still allows odd values (e.g., 3), which could result in misaligned or invalid array packing when the size is odd. Consider ensuring the vectorization factor is even when dealing with sub-byte types.

    params->vectorization_factor = std::min(
        has_sub_byte ? std::max(2l, max_vect_factor) : max_vect_factor,
        vectorize_helper::getVectorizationFactor(
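
One way to realize that suggestion, sketched as a standalone helper (hypothetical name, not code from this PR): clamp to at least 2, then round an odd candidate down to the next even value.

    #include <algorithm>
    #include <cstdint>

    // Hedged sketch of the reviewer-suggested tightening: besides the lower
    // bound of 2, force the factor to be even so a packed fp4 vector never
    // ends up with an odd length.
    std::int64_t clampFactorForPackedFp4(std::int64_t factor, bool has_sub_byte) {
      if (!has_sub_byte) {
        return factor;
      }
      std::int64_t clamped = std::max<std::int64_t>(2, factor);
      return clamped - (clamped % 2);  // 2 -> 2, 3 -> 2, 4 -> 4, 5 -> 4
    }
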
Test Coverage Limitation

The test test_fp4_vectorization uses a fixed input size (4, 8) that results in a total of 32 elements after reshape, which is divisible by 2. Additional test cases with shapes that could expose edge cases (e.g., an odd number of packed elements) should be considered to validate correctness under various vectorization factors.

    def test_fp4_vectorization(
        nvfuser_direct_test,
        dtype,
    ):
        inputs = [
            torch.ones(4, 8, dtype=dtype, device="cuda"),
            torch.ones(4, dtype=dtype, device="cuda"),
        ]
    
        def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
            T0 = fd.from_pytorch(inputs[0])
            T1 = fd.from_pytorch(inputs[1])
            T2 = fd.ops.cast(T0, DataType.Float)
            cast_T1 = fd.ops.cast(T1, DataType.Float)
            broadcast_T1 = fd.ops.broadcast(cast_T1, [False, True])
            T3 = fd.ops.div(T2, broadcast_T1)
            T4 = fd.ops.cast(T3, DataType.Float4_e2m1fn)
            T5 = fd.ops.reshape(T4, [32])
            fd.add_output(T5)
    
        o, _ = nvfuser_direct_test.exec_nvfuser(nvfuser_fusion_id0, inputs)
    
        ref_o = to_fp4(inputs[0].to(torch.float) / inputs[1].unsqueeze(-1)).reshape(-1)

Sub-Byte Detection Logic

The logic to detect sub-byte types relies solely on dataTypeSizeBit < 8, which may include future sub-byte types beyond Float4_e2m1fn. However, not all sub-byte types necessarily require even-length arrays. The heuristic should ideally be specific to packed types like Float4_e2m1fn rather than all sub-byte types to avoid over-constraining vectorization.

    bool has_sub_byte = false;
    for (auto inp : vectorizable_inputs_outputs_entry.get()) {
      has_sub_byte |= dataTypeSizeBit(inp->getDataType().value()) < 8;
      max_dtype_size_bit_for_vectorization = std::max(
          max_dtype_size_bit_for_vectorization,
          dataTypeSizeBit(inp->getDataType().value(), index_type));
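
A standalone sketch of the narrower predicate described above, using a stand-in enum because the exact nvFuser DataType spelling is not shown in this excerpt:

    // Stand-in enum for illustration; the real check would query nvFuser's
    // DataType for whichever packed sub-byte types it defines.
    enum class ToyDtype { Float4_e2m1fn, BFloat16, Float, Double };

    // Only byte-packed sub-byte types force an even vector length; a
    // hypothetical unpacked sub-byte type would not need this constraint.
    bool requiresEvenVectorLength(ToyDtype dtype) {
      return dtype == ToyDtype::Float4_e2m1fn;
    }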

@jjsjann123 (Collaborator Author): !test

@jjsjann123 marked this pull request as ready for review October 23, 2025 22:20
@jjsjann123 requested review from liqiangxl, protonu, rdspring1 and zasdfgbnm and removed request for liqiangxl October 23, 2025 22:20

@jjsjann123 (Collaborator Author) commented on tests/python/direct_utils/utils.py:

    # However, casting complex values to real discards the imaginary
    # part, so skip complex dtypes.
    if not ref_out.dtype.is_complex:
    if ref_out.dtype == torch.float4_e2m1fn_x2:

@rdspring1 probably not the cleanest thing, just a naive Python patch.

@jjsjann123 (Collaborator Author): !test

@jjsjann123 (Collaborator Author): !test

@protonu (Collaborator) commented on csrc/scheduler/pointwise.cpp, Oct 24, 2025:

    // sub-byte length.
    params->vectorization_factor = std::min(
        max_vect_factor,
        has_sub_byte ? std::max(2l, max_vect_factor) : max_vect_factor,

Thanks - I need something like this as well. I thought I might need this to have a minimum of 4/8 (bf16).

I guess these are all computed at compile time, so there's no information about alignment yet. Can increasing the vectorization width (or having any vectorization at all) lead to incorrect behavior?

@jjsjann123 (Collaborator Author) replied:

This shouldn't cause more issues than what would already have been there.

This only modifies the heuristic's lower bound on max_vect_factor when fp4 is present. Actual alignment is still constrained on top of this by the outer std::min.
If we cannot vectorize at the proper width, the scheduler still won't, and codegen would fail later. (For fp4 types, we'll hit an assert in the runtime function.)

A collaborator replied:

Do you think we should assert on params->vectorization_factor here in the case that we have a packed fp4 type and we aren't able to vectorize to 2 or more?

@jjsjann123 (Collaborator Author) replied:

^ I would prefer not to.

Generally speaking, an earlier error seems like a good idea, but vectorization should not be the answer.
It's really inlining that's failing us; i.e., we don't necessarily need the fp4 TV to be vectorized if we have correct inlining. We really only need packed storage.

I don't know how to fix that properly yet, and I wanted it to work for now. That's why I was calling it a hacky patch. 😰


A collaborator commented on tests/python/direct_utils/utils.py:

    # Check that the values of all outputs match
    for ref_out, cap_out in zip(reference_outputs, captured_outputs):
        # torch.allclose does not work with fp8 datatype, so cast to fp64.

Might be useful to add a note to the comment here about the packed fp4 case.

@jjsjann123 (Collaborator Author): !test

@jjsjann123 (Collaborator Author): !test

@jjsjann123 (Collaborator Author) commented:

Looks like CI is broken again, but tests cleared earlier (only a comment was added since 2917c9d).

I'll merge this as-is after the build clears.

@jjsjann123 changed the title from "Hacky patch to ensure nvfp4 for pointwise scheduler" to "Hacky patch to support correct vectorization factor of nvfp4 for pointwise scheduler" on Oct 27, 2025

@jjsjann123 (Collaborator Author): !test

@jjsjann123 merged commit af1d3fc into main on Oct 28, 2025 (64 of 66 checks passed)
@jjsjann123 deleted the jj/fp4_subbyte_type_pointwise_scheduler branch October 28, 2025 06:17