
Conversation

Collaborator

@protonu protonu commented Sep 29, 2025

The PR:

  1. Adds a new IR node for Block Quantization to NVFP4. It assumes a block size of 16.
  2. Plumbs the IR node to a device function. The function expects its inputs to be in registers; it writes the output block scales to global memory and the quantized values to local memory.
  3. Only Float is supported for now, but this can easily be extended to BF16.
  4. The device function expects each thread to operate on 4 elements of the input.
  5. The added unit test does a bitwise comparison against the outputs of the quantization kernel generated by the normalization scheduler.
  6. The new IR node has been scheduled to mimic the 1D and 2D pointwise schedulers (a scheduling-API sketch follows the transformations below):
For the 1D scheduling:

[T0, T1, T2]
merge all -> [TN]
[TN] -> [TN/4, 4(v)]              // split
[TN/4, 4(v)] -> [TN/4, 1, 4(v)]   // split
[TN/4, 1, 4(v)] -> [TN/4/128(Bx), 128(Tx), 1, 4(v)]

For the 2D scheduling:

(m, n, k) -> (m, n*k)
(m, n*k) -> (m, n*k/4, 4)
(m, n*k/4, 4) -> (m, n*k/128, 32, 4)
(m, n*k/128, 32, 4) -> (m, 1, n*k/128, 32, 4)
(m, 1, n*k/128, 32, 4) -> (m/4, 4, 1, n*k/128, 32, 4)
(m/4(bidy), 4(tidy), 1, n*k/128(bidx), 32(tidx), 4(v))
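
For illustration only, here is a minimal sketch (not the PR's scheduler code) of how the 1D schedule above could be written with nvFuser's TensorView scheduling API. The tensor handle and the use of ParallelType::Group on the innermost axis are assumptions based on the validation code quoted later in this page:

// quantized_output is a hypothetical handle to the quantized output TensorView.
TensorView* tv = quantized_output;
tv->merge(0);       // [T0, T1, T2] -> [T0*T1, T2]
tv->merge(0);       // -> [TN]
tv->split(0, 4);    // -> [TN/4, 4]
tv->split(0, 1);    // -> [TN/4, 1, 4]
tv->split(0, 128);  // -> [TN/4/128, 128, 1, 4]
tv->axis(0)->parallelize(ParallelType::BIDx);   // Bx
tv->axis(1)->parallelize(ParallelType::TIDx);   // Tx
tv->axis(3)->parallelize(ParallelType::Group);  // 4 elements handled per thread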


github-actions bot commented Sep 29, 2025

Review updated until commit c0cd7f9

Description

  • Add BlockQuantizationOp for NVFP4 quantization

  • Support float, bfloat16, and half data types

  • Validate scheduling constraints for block quantization

  • Integrate runtime kernel for block quantization


Changes walkthrough 📝

Relevant files

Enhancement (16 files)
codegen.cpp: Handle BlockQuantizationOp code generation (+65/-0)
trivial_broadcast.cpp: Handle broadcast domains in block quantization (+11/-0)
index.cpp: Lower BlockQuantizationOp indices (+42/-0)
utils.cpp: Add BlockQuantizationOp to TV ops (+1/-0)
validation.cpp: Validate BlockQuantizationOp constraints (+363/-3)
nodes.cpp: Implement BlockQuantizationOp logic (+48/-0)
kernel.cpp: Detect BlockQuantizationOp in kernel (+4/-0)
logical_domain_map.cpp: Handle non-mapping domains for block scales (+13/-0)
arith.cpp: Add blockQuantize operator (+81/-0)
compiled_kernel.cpp: Include block quantization kernel (+9/-2)
block_quantization_kernels.cu: Implement NVFP4 block quantization kernel (+144/-0)
trivial_broadcast.h: Declare BlockQuantizationOp handler (+2/-0)
index.h: Declare index lowering for BlockQuantizationOp (+1/-0)
dispatch.h: Register BlockQuantizationOp in dispatch (+1/-0)
internal_nodes.h: Define BlockQuantizationOp class (+56/-0)
kernel.h: Track BlockQuantizationOp in kernel summary (+2/-0)

Bug fix (3 files)
non_divisible_split.cpp: Skip predicate analysis for block scales (+7/-1)
sync_information.cpp: Skip sync checks for block scaling output (+15/-0)
utils.cpp: Exclude BlockQuantizationOp from uniform siblings (+1/-1)

Tests (1 file)
test_low_precision_recipe.cpp: Add block quantization tests (+250/-7)

Additional files (2 files)
CMakeLists.txt (+1/-0)
arith.h (+20/-0)

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Possible Issue

The validation logic for BlockQuantizationOp assumes that the Group ID must be derived from the innermost logical domain ID. However, the current implementation may not correctly handle cases where merges are involved in the transformation from logical to loop domains, potentially leading to incorrect validation or missed errors.

void handle(BlockQuantizationOp* bqop) final {
  auto inp_tv = bqop->input(0)->as<TensorView>();
  auto quantized_output = bqop->quantizedOutput()->as<TensorView>();
  auto block_scaling_factor = bqop->blockScales()->as<TensorView>();

  NVF_ERROR_EQ(
      inp_tv->getMemoryType(),
      MemoryType::Local,
      "Input must be a local memory tensor. Found: ",
      inp_tv->getMemoryType());

  NVF_ERROR_EQ(
      quantized_output->getMemoryType(),
      MemoryType::Local,
      "Quantized output must be a local memory tensor. Found: ",
      quantized_output->getMemoryType());

  NVF_ERROR_EQ(
      block_scaling_factor->getMemoryType(),
      MemoryType::Global,
      "Block scaling factor must be a global memory tensor. Found: ",
      block_scaling_factor->getMemoryType());

  // Outputs have the same allocation domain as the loop domain.
  // This has to be relaxed later for the scaling factors.
  NVF_ERROR(
      quantized_output->hasAllocation() == false,
      "Quantized output must not have an allocation domain.");
  NVF_ERROR(
      block_scaling_factor->hasAllocation() == false,
      "Block scaling factor must not have an allocation domain.");

  // Check that the output has either a vectorized ID or a grouped ID,
  // not both, and that its extent is either 4 (FP32) or 8 (BF16).
  IterDomain* grouped_id = nullptr;
  IterDomain* thread_x = nullptr;
  IterDomain* block_x = nullptr;
  IterDomain* thread_z = nullptr;
  IterDomain* block_z = nullptr;

  for (const auto& loop_id : block_scaling_factor->getLoopDomain()) {
    if (loop_id->getParallelType() == ParallelType::Group) {
      grouped_id = loop_id;
    }
    if (loop_id->getParallelType() == ParallelType::Serial ||
        loop_id->getParallelType() == ParallelType::Unswitch ||
        loop_id->getParallelType() == ParallelType::Unroll) {
      // Check that this ID has a constant extent and that it is 1
      NVF_ERROR(
          loop_id->extent()->isConstInt(),
          "Expected constant extent for Serial ID in BlockQuantizationOp");
      NVF_ERROR(
          loop_id->extent()->evaluate().as<int64_t>() == 1,
          "Expected extent of 1");
    }
  }

  auto parallel_domains_map =
      ir_utils::getParallelDomains(block_scaling_factor);

  if (parallel_domains_map.find(ParallelType::TIDx) !=
      parallel_domains_map.end()) {
    thread_x = parallel_domains_map.at(ParallelType::TIDx);
  }
  if (parallel_domains_map.find(ParallelType::BIDx) !=
      parallel_domains_map.end()) {
    block_x = parallel_domains_map.at(ParallelType::BIDx);
  }
  if (parallel_domains_map.find(ParallelType::TIDz) !=
      parallel_domains_map.end()) {
    thread_z = parallel_domains_map.at(ParallelType::TIDz);
  }
  if (parallel_domains_map.find(ParallelType::BIDz) !=
      parallel_domains_map.end()) {
    block_z = parallel_domains_map.at(ParallelType::BIDz);
  }

  NVF_ERROR(
      grouped_id != nullptr,
      "One of the output IDs must be grouped for "
      "BlockQuantizationOp: ",
      bqop->toString());

  NVF_ERROR(
      thread_x != nullptr && block_x != nullptr,
      "Need to have both TIDx and BIDx when using BlockQuantizationOp: ",
      bqop->toString());

  NVF_ERROR(
      !thread_z && !block_z,
      "Parallelization along z axis is not supported for "
      "BlockQuantizationOp: ",
      bqop->toString());

  auto inner_extent = grouped_id->extent()->evaluate().as<int64_t>();
  auto input_dtype = inp_tv->dtype();

  NVF_ERROR(
      (inner_extent == 4 && input_dtype == DataType::Float) ||
          (inner_extent == 8 &&
           (input_dtype == DataType::BFloat16 ||
            input_dtype == DataType::Half)),
      "The vectorized/grouped dimension must be  4 (FP32) or 8 "
      "(BF16). Found: ",
      inner_extent,
      ". Expr: ",
      bqop->toString());

  // Get the ID marked as Group
  IterDomain* new_grouped_id = nullptr;
  for (auto loop_id : quantized_output->getLoopDomain()) {
    if (loop_id->getParallelType() == ParallelType::Group) {
      new_grouped_id = loop_id;
    }
  }

  NVF_ERROR(
      new_grouped_id != nullptr,
      "Expected a valid loop grouped ID for BlockQuantizationOp: ",
      bqop->toString());

  auto last_split_seen =
      checkGroupIDDerivedFromLastLogicalIDs(new_grouped_id, quantized_output);

  // if last split seen is null there are two possibilities
  // 1) Group ID is directly from logical domain -> valid
  // 2) There was a merge right before Group ID
  IterDomain* restart_traversal_from = nullptr;

  if (last_split_seen == nullptr) {
    auto ids_in_logical = IterVisitor::getInputsTo({new_grouped_id});
    // Check all these ID have constant extents
    for (auto id : ids_in_logical) {
      auto iter_domain = id->as<IterDomain>();
      NVF_ERROR(
          iter_domain->extent()->isConstInt(),
          "Expected all IDs feeding directly into Group ID to have constant "
          "extents for BlockQuantizationOp: ",
          quantized_output->toString());
    }

    // Check that there are logical IDs left to derive thread IDs
    NVF_ERROR(
        ids_in_logical.size() < quantized_output->getLogicalDomain().size(),
        "There aren't enough logical IDs to derive thread Ids ",
        quantized_output->toString());

    restart_traversal_from =
        quantized_output->getLogicalDomain()
            [quantized_output->getLogicalDomain().size() -
             ids_in_logical.size() - 1];
  } else {
    // Go to the outer ID; we should have come up from the inner split.
    restart_traversal_from = last_split_seen->outer();
  }

  traverseFromSplitToThreadX(restart_traversal_from, quantized_output);
}
Performance Issue

The block quantization kernel uses division by a constant (6.0), which could be replaced with multiplication by the reciprocal for better performance. This optimization is explicitly mentioned in a TODO comment but is not implemented.

float scaled_max = block_max / 6.000000000e+00f;
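
A minimal sketch of the suggested rewrite, assuming 6.0 is the NVFP4 maximum magnitude used for scaling; multiplying by a reciprocal is not bit-identical to dividing, so the PR's bitwise comparison test may need the same change on the reference side:

constexpr float kRcp6 = 1.0f / 6.0f;   // reciprocal folded at compile time
float scaled_max = block_max * kRcp6;  // multiply instead of divide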
Possible Issue

The blockQuantize function creates a BlockQuantizationOp without passing the logical_index parameter, which is required by the constructor. This could lead to undefined behavior or incorrect code generation.

IrBuilder::create<BlockQuantizationOp>(block_scales, quantized_tensor, input);
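
If the constructor does require a logical index as described (an assumption; the parameter's position and type are not shown in this review), the call site would presumably need to pass it explicitly, along the lines of:

// Hypothetical sketch only; "logical_index" is a placeholder for whatever
// value the BlockQuantizationOp constructor actually expects.
IrBuilder::create<BlockQuantizationOp>(
    block_scales, quantized_tensor, input, logical_index);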

@protonu protonu changed the title from "This creates a new node for Block Quantization to NVFP4 and plumbs it to a device function." to "Create a new node for Block Quantization to NVFP4 and plumbs it to a device function." Sep 29, 2025
Collaborator

@jjsjann123 jjsjann123 left a comment


I didn't review all the relaxed checks in detail.

My main question is: how are we handling, or planning to handle, indexing in the runtime function? Are we just going to restrict the scheduler to ensure it complies with the indexing requirements of the runtime function? That feels too restrictive to me.


// This division should be replaced with a multiplication
// by a reciprocal for better performance.
float scaled_max = block_max / 6.000000000e+00f;
Collaborator


I'm uncertain how the math would work with the global scaling factor.

Collaborator Author


The math in the device function is not correct (it doesn't match the Python reference in test_narrow_precision.py).
In a branch I have support for the global scaling factor and a modified kernel with the new math.
I think I'll update this PR with the new math.

The only problem is that I'll need to update Xiang's old tests as well to reflect that.

Collaborator Author


I think I'll leave the math in the device function as-is for now and open a separate PR that updates the math in the device function and the older C++ tests.

Collaborator Author

protonu commented Sep 30, 2025

@jjsjann123

I didn't review all the relaxed checks in detail.

My main question is: how are we handling, or planning to handle, indexing in the runtime function? Are we just going to restrict the scheduler to ensure it complies with the indexing requirements of the runtime function? That feels too restrictive to me.

For now that's what I was doing.
I was hoping that if we support only one or two types of scheduling, we can get away with it.
It works for the simplest 1D scheduling, but you're right that as the scheduling gets more complex, it'll be hard to track in the device function. Let me think about this a bit.

@protonu protonu requested a review from jjsjann123 September 30, 2025 17:24
Collaborator Author

protonu commented Sep 30, 2025

@jjsjann123 I updated the device function (and the codegen).
The device function still writes to global memory but the output index computation is done by the codegen.

@protonu protonu requested a review from naoyam October 1, 2025 17:06
Collaborator

@naoyam naoyam left a comment


Make sure to revisit the PR and clean it up.

@protonu protonu requested a review from naoyam October 10, 2025 14:21
@protonu protonu marked this pull request as ready for review October 10, 2025 14:48
@protonu protonu changed the title from "Create a new node for Block Quantization to NVFP4 and plumbs it to a device function." to "Create a new node for Block Quantization to NVFP4 and plumb it to a device function." Oct 10, 2025
Collaborator Author

protonu commented Oct 10, 2025

!test

FusionGuard fg(fusion.get());
createNVFP4QunatizationFusion(fusion.get(), DataType::Float);

FusionExecutorCache fec(std::move(fusion));
Collaborator


Didn't you mention the quantization originally done by @zasdfgbnm has some bugs? Is that fixed?


// I'd like to check that the inner dimension of the input
// is divisible by 16.
void handle(BlockQuantizationOp* bqop) final {
Collaborator Author


@naoyam This is the validation function.

Collaborator


Do you have a list of things to validate?

Collaborator Author


I have updated the comment with a list of things I want to check.
I need to rewrite the validation code; the second half is not correct. I'll send another commit to fix that.

@protonu protonu requested a review from naoyam October 17, 2025 16:40
Collaborator Author

protonu commented Oct 17, 2025

!test

@nvMelissa
Collaborator

Target to complete is 10/31. cc: @protonu @naoyam
