8 changes: 7 additions & 1 deletion csrc/scheduler/pointwise.cpp
@@ -295,7 +295,10 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
});

int64_t max_dtype_size_bit_for_vectorization = 0;
// ugly WAR.
Collaborator:
Please give more details rather than just saying it's ugly. Why is it ugly? What should have been done instead?

Also, this variable is only used hundreds of lines below. For better code readability, it should be moved down there.

bool has_sub_byte = false;
for (auto inp : vectorizable_inputs_outputs_entry.get()) {
has_sub_byte |= dataTypeSizeBit(inp->getDataType().value()) < 8;
max_dtype_size_bit_for_vectorization = std::max(
max_dtype_size_bit_for_vectorization,
dataTypeSizeBit(inp->getDataType().value(), index_type));
@@ -484,8 +487,11 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
}
}

// If we have sub-byte data types, we don't want to clamp the vectorization
// factor to 1; otherwise we could end up with an illegal array type of
// sub-byte length.
params->vectorization_factor = std::min(
- max_vect_factor,
+ has_sub_byte ? std::max(2l, max_vect_factor) : max_vect_factor,
@protonu (Collaborator), Oct 24, 2025:
Thanks - I need something like this as well. I thought I may need this to have a minimum of 4/8(bf16).

I guess these are all computed at compile time so there's no information about alignment yet. Can increasing the vectorization width (or having any vectorization at all) lead to incorrect behavior?

Collaborator (Author):
This shouldn't cause any more issues than would already have been there.

This only changes the heuristic's lower bound on max_vect_factor when fp4 is present; the actual alignment constraint is still applied on top of this by the outer std::min.
If we cannot vectorize at the proper width, the scheduler still won't do it, and codegen will fail later. (For fp4 types, we'll hit an assert in the runtime function.)
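
A minimal sketch of the clamping described above, in Python pseudocode rather than the actual C++ heuristic; pick_vectorization_factor and alignment_limited_factor are hypothetical names standing in for the logic in getPointwiseHeuristics and vectorize_helper::getVectorizationFactor:

# Sketch only; not nvFuser code.
def pick_vectorization_factor(
    max_vect_factor: int, has_sub_byte: bool, alignment_limited_factor: int
) -> int:
    # For a 4-bit dtype, a factor of 1 would imply a 4-bit array element,
    # which is not byte-addressable; a factor of 2 packs two fp4 values
    # into one byte.
    lower_bound = max(2, max_vect_factor) if has_sub_byte else max_vect_factor
    # Alignment still caps the result via the outer min.
    return min(lower_bound, alignment_limited_factor)

# Example: the size-based heuristic alone would give 1, alignment allows 4,
# so an fp4 input still ends up with a factor of at least 2.
assert pick_vectorization_factor(1, True, 4) == 2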

Collaborator:
Do you think we should assert on the params->vectorization_factor here in the case that we have a packed fp4 type and we aren't able to vectorize to 2 or more?

Collaborator (Author):
^ I would prefer not.

Generally speaking, erroring out earlier seems like a good idea, but vectorization should not be the answer.
It's really the inline that's failing us, i.e. we don't necessarily need the fp4 TV to be vectorized if we have correct inline. We really only need to have packed storage.

I don't know how to fix that properly yet, and I wanted it to work for now. That's why I was calling it a hacky patch. 😰

Collaborator:

Sorry, I don't understand what you meant by "the inline".

vectorize_helper::getVectorizationFactor(
runtime_info,
largest_out,
33 changes: 33 additions & 0 deletions tests/python/direct/test_narrow_precision.py
@@ -19,6 +19,7 @@
    linear_to_swizzled_128_4,
    round_up,
    activation_scale_to_nvfp4,
    to_fp4,
)

import pytest
@@ -274,3 +275,35 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
    )

    assert torch.allclose(o_decomposed_ref, o[0], atol=1e-2, rtol=1e-2)


@pytest.mark.skipif(
is_pre_blackwell(), reason="Only supported on blackwell and newer devices."
)
@pytest.mark.skipif(
not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0"
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float])
def test_fp4_vectorization(
    nvfuser_direct_test,
    dtype,
):
    inputs = [
        torch.ones(4, 8, dtype=dtype, device="cuda"),
        torch.ones(4, dtype=dtype, device="cuda"),
    ]

    def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
        T0 = fd.from_pytorch(inputs[0])
        T1 = fd.from_pytorch(inputs[1])
        T2 = fd.ops.cast(T0, DataType.Float)
        cast_T1 = fd.ops.cast(T1, DataType.Float)
        broadcast_T1 = fd.ops.broadcast(cast_T1, [False, True])
        T3 = fd.ops.div(T2, broadcast_T1)
        T4 = fd.ops.cast(T3, DataType.Float4_e2m1fn)
        T5 = fd.ops.reshape(T4, [32])
        fd.add_output(T5)

    o, _ = nvfuser_direct_test.exec_nvfuser(nvfuser_fusion_id0, inputs)

    ref_o = to_fp4(inputs[0].to(torch.float) / inputs[1].unsqueeze(-1)).reshape(-1)
Comment on lines +307 to +309
Collaborator:
Can somebody teach me how validation is done here? IIUC, exec_nvfuser just executes the fusion. There's ref_o. Shouldn't we compare o and ref_o?

Please don't use names like o. Single-letter names should be avoided except established common names like i for loop indices.

Collaborator:
And also, this is for vectorization, but I'm not seeing where the vectorization is validated.
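
As a reference for the question above, a hedged sketch of what an explicit check at the end of test_fp4_vectorization might look like, reusing the int8-view trick from direct_utils below and assuming to_fp4 returns a packed torch.float4_e2m1fn_x2 tensor like the fusion output:

# Sketch only; would be appended at the end of test_fp4_vectorization.
out_bits = o[0].view(torch.int8)  # reinterpret packed fp4 output as raw bytes
ref_bits = ref_o.view(torch.int8)
assert torch.equal(out_bits, ref_bits)

Whether the kernel actually vectorizes is a separate question that such a value comparison does not cover.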

10 changes: 8 additions & 2 deletions tests/python/direct_utils/utils.py
@@ -54,9 +54,15 @@ def check_captured_python_definition(reference_outputs, fd, inputs, device=None)
    # torch.allclose does not work with fp8 datatype, so cast to fp64.
Collaborator:
Might be useful to add to the comment here about the packed fp4.

    # However, casting complex values to real discards the imaginary
    # part, so skip complex dtypes.
-   if not ref_out.dtype.is_complex:
    # Similarly, the packed fp4 dtype cannot be compared directly either, so we
    # view it as int8 and compare the raw bytes.
    if ref_out.dtype == torch.float4_e2m1fn_x2:
Collaborator (Author):
@rdspring1 probably not the cleanest thing, just a naive python patch

        ref_out = ref_out.view(torch.int8)
    elif not ref_out.dtype.is_complex:
        ref_out = ref_out.to(torch.float64)
-   if not cap_out.dtype.is_complex:
    if cap_out.dtype == torch.float4_e2m1fn_x2:
        cap_out = cap_out.view(torch.int8)
    elif not cap_out.dtype.is_complex:
        cap_out = cap_out.to(torch.float64)
    if not torch.allclose(ref_out, cap_out, equal_nan=True):
        return False
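
For illustration, a small hedged example of why the int8 view is used for packed fp4 tensors, assuming a PyTorch build that exposes torch.float4_e2m1fn_x2; the byte values are arbitrary:

import torch

raw = torch.tensor([0x12, 0x34, 0x56, 0x78], dtype=torch.int8)
packed = raw.view(torch.float4_e2m1fn_x2)  # each byte holds two fp4 values
# torch.allclose does not work with fp4 (or fp8) dtypes, so compare the raw
# bytes after viewing back as int8.
assert torch.equal(packed.view(torch.int8), raw)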