Hacky patch to support correct vectorization factor of nvfp4 for pointwise scheduler #5428
Changes from all commits: 2ee3fba, 618e0ac, ab8f027, dabde6b, ae68a49, 2917c9d, a8c9b43, f25dc4e, d88ee26
First changed file (the pointwise scheduler heuristics, getPointwiseHeuristics):

```diff
@@ -295,7 +295,10 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
   });

   int64_t max_dtype_size_bit_for_vectorization = 0;
+  // ugly WAR.
+  bool has_sub_byte = false;
   for (auto inp : vectorizable_inputs_outputs_entry.get()) {
+    has_sub_byte |= dataTypeSizeBit(inp->getDataType().value()) < 8;
     max_dtype_size_bit_for_vectorization = std::max(
         max_dtype_size_bit_for_vectorization,
         dataTypeSizeBit(inp->getDataType().value(), index_type));
```
```diff
@@ -484,8 +487,11 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
     }
   }

+  // If we have sub-byte data types, we wouldn't want to clamp vectorization
+  // factor to 1, otherwise we could end up with illegal array type with
+  // sub-byte length.
   params->vectorization_factor = std::min(
-      max_vect_factor,
+      has_sub_byte ? std::max(2l, max_vect_factor) : max_vect_factor,
       vectorize_helper::getVectorizationFactor(
           runtime_info,
           largest_out,
```

Review discussion on this hunk:

"Thanks - I need something like this as well. I thought I may need this to have a minimum of 4/8 (bf16). I guess these are all computed at compile time, so there's no information about alignment yet. Can increasing the vectorization width (or having any vectorization at all) lead to incorrect behavior?"

"This shouldn't have caused more issues than what would already have been there. This only modifies the heuristic's lower bound of the vectorization factor."

"Do you think we should assert on the params->vectorization_factor here in the case that we have a packed fp4 type and we aren't able to vectorize to 2 or more?"

"^ I would prefer not. Generally speaking, an earlier error seems like a good idea, but vectorization should not be the answer. I don't know how to fix that properly yet and I wanted it to work for now. That's why I was calling it a hacky patch. 😰"

"Sorry, I don't understand what you meant by 'the inline'."
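The lower-bound behavior being discussed can be summarized in a few lines. The following is an illustrative Python sketch, not nvFuser code; the function name, the simplified signature, and the bit widths are made up for the example:

```python
# Illustrative sketch of the sub-byte lower bound on the vectorization factor.
# Not nvFuser's API; dtype sizes are passed directly in bits for simplicity.
def clamp_vectorization_factor(input_dtype_bits: list[int], max_vect_factor: int) -> int:
    # An nvfp4 element is 4 bits wide, so any input narrower than a byte is "sub-byte".
    has_sub_byte = any(bits < 8 for bits in input_dtype_bits)
    # A factor of 1 on a 4-bit element would imply a 4-bit array/access type, which is
    # illegal; a factor of 2 packs two elements into one whole byte.
    return max(2, max_vect_factor) if has_sub_byte else max_vect_factor


print(clamp_vectorization_factor([4, 32], 1))   # 2: an nvfp4 input forces at least 2
print(clamp_vectorization_factor([16, 32], 1))  # 1: no sub-byte inputs, unchanged
print(clamp_vectorization_factor([4, 16], 8))   # 8: already large enough
```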
Second changed file (the Python nvfp4 tests):
```diff
@@ -19,6 +19,7 @@
     linear_to_swizzled_128_4,
     round_up,
     activation_scale_to_nvfp4,
+    to_fp4,
 )

 import pytest

@@ -274,3 +275,35 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
     )

     assert torch.allclose(o_decomposed_ref, o[0], atol=1e-2, rtol=1e-2)
+
+
+@pytest.mark.skipif(
+    is_pre_blackwell(), reason="Only supported on blackwell and newer devices."
+)
+@pytest.mark.skipif(
+    not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0"
+)
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float])
+def test_fp4_vectorization(
+    nvfuser_direct_test,
+    dtype,
+):
+    inputs = [
+        torch.ones(4, 8, dtype=dtype, device="cuda"),
+        torch.ones(4, dtype=dtype, device="cuda"),
+    ]
+
+    def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
+        T0 = fd.from_pytorch(inputs[0])
+        T1 = fd.from_pytorch(inputs[1])
+        T2 = fd.ops.cast(T0, DataType.Float)
+        cast_T1 = fd.ops.cast(T1, DataType.Float)
+        broadcast_T1 = fd.ops.broadcast(cast_T1, [False, True])
+        T3 = fd.ops.div(T2, broadcast_T1)
+        T4 = fd.ops.cast(T3, DataType.Float4_e2m1fn)
+        T5 = fd.ops.reshape(T4, [32])
+        fd.add_output(T5)
+
+    o, _ = nvfuser_direct_test.exec_nvfuser(nvfuser_fusion_id0, inputs)
+
+    ref_o = to_fp4(inputs[0].to(torch.float) / inputs[1].unsqueeze(-1)).reshape(-1)
```
Comment on lines +307 to +309:

"Can somebody teach me how validation is done here? IIUC, [...] Please don't use names like [...]"

"And also, this is for vectorization, but I'm not seeing where the vectorization is validated."
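The to_fp4 used in the reference above is the repo's test utility (imported in the first hunk of this file); its implementation is not shown in this diff. As a rough, hypothetical illustration of what an e2m1 (nvfp4) reference quantizer could look like (the name to_fp4_reference, the rounding behavior, and the nibble packing order are all assumptions, not the repo's code):

```python
import torch

# The eight non-negative magnitudes representable in e2m1 (no inf/nan).
_E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def to_fp4_reference(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for the test utility's to_fp4; returns packed bytes."""
    x = x.float().flatten()
    assert x.numel() % 2 == 0, "packed fp4 needs an even number of elements"
    sign = (x < 0).to(torch.uint8)
    mag = x.abs().clamp(max=6.0)
    # Nearest representable e2m1 magnitude; the code 0..7 indexes _E2M1_VALUES.
    code = (mag.unsqueeze(-1) - _E2M1_VALUES.to(x.device)).abs().argmin(dim=-1).to(torch.uint8)
    nibble = (sign << 3) | code
    # Pack two consecutive elements per byte; element 0 in the low nibble (assumed layout).
    packed = (nibble[1::2] << 4) | nibble[0::2]
    return packed  # raw uint8 bytes; may be viewed as a packed fp4 tensor where supported
```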
Third changed file (the Python test harness, check_captured_python_definition):
```diff
@@ -54,9 +54,15 @@ def check_captured_python_definition(reference_outputs, fd, inputs, device=None)
         # torch.allclose does not work with fp8 datatype, so cast to fp64.
         # However, casting complex values to real discards the imaginary
         # part, so skip complex dtypes.
-        if not ref_out.dtype.is_complex:
+        # Similarly, packed fp4 dtype cannot be compared neither, we view
+        # it as int8 and run comparison as-is.
+        if ref_out.dtype == torch.float4_e2m1fn_x2:
+            ref_out = ref_out.view(torch.int8)
+        elif not ref_out.dtype.is_complex:
             ref_out = ref_out.to(torch.float64)
-        if not cap_out.dtype.is_complex:
+        if cap_out.dtype == torch.float4_e2m1fn_x2:
+            cap_out = cap_out.view(torch.int8)
+        elif not cap_out.dtype.is_complex:
             cap_out = cap_out.to(torch.float64)
         if not torch.allclose(ref_out, cap_out, equal_nan=True):
             return False
```

Review comments on this hunk:

"Might be useful to add to the comment here about the packed fp4."

"@rdspring1 probably not the cleanest thing, just a naive python patch."
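The comparison path in this patch can be exercised on its own. A minimal sketch, assuming a PyTorch build that exposes torch.float4_e2m1fn_x2; the helper name fp4_aware_allclose is hypothetical. allclose is not defined for the packed fp4 dtype, but the underlying bytes can be compared after a bitwise view as int8:

```python
import torch


def fp4_aware_allclose(ref_out: torch.Tensor, cap_out: torch.Tensor) -> bool:
    """Compare two tensors, treating packed fp4 outputs as raw bytes."""
    if ref_out.dtype == torch.float4_e2m1fn_x2:
        ref_out = ref_out.view(torch.int8)  # bitwise view; no value conversion
    if cap_out.dtype == torch.float4_e2m1fn_x2:
        cap_out = cap_out.view(torch.int8)
    return torch.allclose(ref_out, cap_out, equal_nan=True)
```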
Review comment on the // ugly WAR. comment and the has_sub_byte variable introduced in the scheduler change:
Please give more details rather than just saying it's ugly. Why is it ugly? What should have been done instead?
Also, this variable is only used hundreds of lines below. For better code readability, it should be moved down there.