From 2ee3fba1cd0fb72dc293f5ed73459c80d7509c77 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 22 Oct 2025 15:32:23 -0700 Subject: [PATCH 1/7] bump vec size for fp4 types --- csrc/scheduler/pointwise.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index e9f4fe2e90a..98c26c52ed0 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -295,7 +295,10 @@ std::unique_ptr getPointwiseHeuristics( }); int64_t max_dtype_size_bit_for_vectorization = 0; + // ugly WAR. + bool has_sub_byte = false; for (auto inp : vectorizable_inputs_outputs_entry.get()) { + has_sub_byte |= dataTypeSizeBit(inp->getDataType().value()) < 8; max_dtype_size_bit_for_vectorization = std::max( max_dtype_size_bit_for_vectorization, dataTypeSizeBit(inp->getDataType().value(), index_type)); @@ -484,8 +487,9 @@ std::unique_ptr getPointwiseHeuristics( } } + // If we have sub-byte data types, we wouldn't want to clamp vectorization factor to 1, otherwise we could end up with illegal array type with sub-byte length. params->vectorization_factor = std::min( - max_vect_factor, + has_sub_byte ? std::max(2l, max_vect_factor) : max_vect_factor, vectorize_helper::getVectorizationFactor( runtime_info, largest_out, From 618e0aca7a71f828e467022b5016f9b1034cc8ce Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 23 Oct 2025 14:23:41 -0700 Subject: [PATCH 2/7] adding python tests --- tests/python/direct/test_narrow_precision.py | 34 ++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/python/direct/test_narrow_precision.py b/tests/python/direct/test_narrow_precision.py index 0285258d6f6..9efec1ea4d7 100644 --- a/tests/python/direct/test_narrow_precision.py +++ b/tests/python/direct/test_narrow_precision.py @@ -19,6 +19,7 @@ linear_to_swizzled_128_4, round_up, activation_scale_to_nvfp4, + to_fp4, ) import pytest @@ -274,3 +275,36 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None: ) assert torch.allclose(o_decomposed_ref, o[0], atol=1e-2, rtol=1e-2) + +@pytest.mark.skipif( + is_pre_blackwell(), reason="Only supported on blackwell and newer devices." +) +@pytest.mark.skipif( + not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0" +) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float]) +def test_fp4_vectorization( + nvfuser_direct_test, + dtype, +): + + inputs = [ + torch.ones(4, 8, dtype=dtype, device="cuda"), + torch.ones(4, dtype=dtype, device="cuda"), + ] + + def nvfuser_fusion_id0(fd: FusionDefinition) -> None: + T0 = fd.from_pytorch(inputs[0]) + T1 = fd.from_pytorch(inputs[0]) + T2 = fd.ops.cast(T0, DataType.Float) + cast_T1 = fd.ops.cast(T1, DataType.Float) + broadcast_T1 = fd.ops.broadcast(cast_T1, [False, True]) + T3 = fd.ops.div(T2, broadcast_T1) + T4 = fd.ops.cast(T3, DataType.Float4_e2m1fn) + T5 = fd.ops.reshape(T4, [m * k]) + fd.add_output(T5) + + o, _ = nvfuser_direct_test.exec_nvfuser(nvfuser_fusion_id0, inputs) + + ref_o = to_fp4(inputs[0].to(torch.float) / inputs[1].unsqueeze(-1)).reshape(-1) + breakpoint() \ No newline at end of file From ab8f027d90b9ac62d723cf95dfe1d86b3361d6af Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 23 Oct 2025 15:07:27 -0700 Subject: [PATCH 3/7] fixing fp4 type for validation --- tests/python/direct/test_narrow_precision.py | 4 ++-- tests/python/direct_utils/utils.py | 8 ++++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/python/direct/test_narrow_precision.py b/tests/python/direct/test_narrow_precision.py index 9efec1ea4d7..d1f38a71f35 100644 --- a/tests/python/direct/test_narrow_precision.py +++ b/tests/python/direct/test_narrow_precision.py @@ -295,13 +295,13 @@ def test_fp4_vectorization( def nvfuser_fusion_id0(fd: FusionDefinition) -> None: T0 = fd.from_pytorch(inputs[0]) - T1 = fd.from_pytorch(inputs[0]) + T1 = fd.from_pytorch(inputs[1]) T2 = fd.ops.cast(T0, DataType.Float) cast_T1 = fd.ops.cast(T1, DataType.Float) broadcast_T1 = fd.ops.broadcast(cast_T1, [False, True]) T3 = fd.ops.div(T2, broadcast_T1) T4 = fd.ops.cast(T3, DataType.Float4_e2m1fn) - T5 = fd.ops.reshape(T4, [m * k]) + T5 = fd.ops.reshape(T4, [32]) fd.add_output(T5) o, _ = nvfuser_direct_test.exec_nvfuser(nvfuser_fusion_id0, inputs) diff --git a/tests/python/direct_utils/utils.py b/tests/python/direct_utils/utils.py index c69c5c28ab6..56e4c9d5c90 100644 --- a/tests/python/direct_utils/utils.py +++ b/tests/python/direct_utils/utils.py @@ -54,9 +54,13 @@ def check_captured_python_definition(reference_outputs, fd, inputs, device=None) # torch.allclose does not work with fp8 datatype, so cast to fp64. # However, casting complex values to real discards the imaginary # part, so skip complex dtypes. - if not ref_out.dtype.is_complex: + if ref_out.dtype == torch.float4_e2m1fn_x2: + ref_out = ref_out.view(torch.int8) + elif not ref_out.dtype.is_complex: ref_out = ref_out.to(torch.float64) - if not cap_out.dtype.is_complex: + if cap_out.dtype == torch.float4_e2m1fn_x2: + cap_out = cap_out.view(torch.int8) + elif not cap_out.dtype.is_complex: cap_out = cap_out.to(torch.float64) if not torch.allclose(ref_out, cap_out, equal_nan=True): return False From dabde6b7a0b8f77eaf3ba742bdd94f41f87e11b1 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 23 Oct 2025 15:19:53 -0700 Subject: [PATCH 4/7] lintrunner --- csrc/scheduler/pointwise.cpp | 4 +++- tests/python/direct/test_narrow_precision.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index 98c26c52ed0..65e402455e9 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -487,7 +487,9 @@ std::unique_ptr getPointwiseHeuristics( } } - // If we have sub-byte data types, we wouldn't want to clamp vectorization factor to 1, otherwise we could end up with illegal array type with sub-byte length. + // If we have sub-byte data types, we wouldn't want to clamp vectorization + // factor to 1, otherwise we could end up with illegal array type with + // sub-byte length. params->vectorization_factor = std::min( has_sub_byte ? std::max(2l, max_vect_factor) : max_vect_factor, vectorize_helper::getVectorizationFactor( diff --git a/tests/python/direct/test_narrow_precision.py b/tests/python/direct/test_narrow_precision.py index d1f38a71f35..482065ddbda 100644 --- a/tests/python/direct/test_narrow_precision.py +++ b/tests/python/direct/test_narrow_precision.py @@ -276,6 +276,7 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None: assert torch.allclose(o_decomposed_ref, o[0], atol=1e-2, rtol=1e-2) + @pytest.mark.skipif( is_pre_blackwell(), reason="Only supported on blackwell and newer devices." ) @@ -287,7 +288,6 @@ def test_fp4_vectorization( nvfuser_direct_test, dtype, ): - inputs = [ torch.ones(4, 8, dtype=dtype, device="cuda"), torch.ones(4, dtype=dtype, device="cuda"), @@ -307,4 +307,4 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None: o, _ = nvfuser_direct_test.exec_nvfuser(nvfuser_fusion_id0, inputs) ref_o = to_fp4(inputs[0].to(torch.float) / inputs[1].unsqueeze(-1)).reshape(-1) - breakpoint() \ No newline at end of file + breakpoint() From ae68a4928667902bcee59a5f5ba49b9ea9b34b29 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 23 Oct 2025 15:22:06 -0700 Subject: [PATCH 5/7] removing breakpoint --- tests/python/direct/test_narrow_precision.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/python/direct/test_narrow_precision.py b/tests/python/direct/test_narrow_precision.py index 482065ddbda..d64a03ad8f2 100644 --- a/tests/python/direct/test_narrow_precision.py +++ b/tests/python/direct/test_narrow_precision.py @@ -306,5 +306,4 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None: o, _ = nvfuser_direct_test.exec_nvfuser(nvfuser_fusion_id0, inputs) - ref_o = to_fp4(inputs[0].to(torch.float) / inputs[1].unsqueeze(-1)).reshape(-1) - breakpoint() + ref_o = to_fp4(inputs[0].to(torch.float) / inputs[1].unsqueeze(-1)).reshape(-1) \ No newline at end of file From 2917c9d26d1ef52b798a5f205b3415d727c83193 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 23 Oct 2025 15:36:50 -0700 Subject: [PATCH 6/7] black --- tests/python/direct/test_narrow_precision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/direct/test_narrow_precision.py b/tests/python/direct/test_narrow_precision.py index d64a03ad8f2..555774b8efa 100644 --- a/tests/python/direct/test_narrow_precision.py +++ b/tests/python/direct/test_narrow_precision.py @@ -306,4 +306,4 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None: o, _ = nvfuser_direct_test.exec_nvfuser(nvfuser_fusion_id0, inputs) - ref_o = to_fp4(inputs[0].to(torch.float) / inputs[1].unsqueeze(-1)).reshape(-1) \ No newline at end of file + ref_o = to_fp4(inputs[0].to(torch.float) / inputs[1].unsqueeze(-1)).reshape(-1) From f25dc4e5f2deb64a20cad0cb51cf7935f5b30de6 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Sun, 26 Oct 2025 20:15:53 -0700 Subject: [PATCH 7/7] addressing review comments --- tests/python/direct_utils/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/python/direct_utils/utils.py b/tests/python/direct_utils/utils.py index 56e4c9d5c90..651591fcacd 100644 --- a/tests/python/direct_utils/utils.py +++ b/tests/python/direct_utils/utils.py @@ -54,6 +54,8 @@ def check_captured_python_definition(reference_outputs, fd, inputs, device=None) # torch.allclose does not work with fp8 datatype, so cast to fp64. # However, casting complex values to real discards the imaginary # part, so skip complex dtypes. + # Similarly, packed fp4 dtype cannot be compared neither, we view + # it as int8 and run comparison as-is. if ref_out.dtype == torch.float4_e2m1fn_x2: ref_out = ref_out.view(torch.int8) elif not ref_out.dtype.is_complex: