From 8e3842feedeb54b3a2768c785cbc0dfee7b84525 Mon Sep 17 00:00:00 2001 From: protonu Date: Wed, 10 Sep 2025 17:33:08 -0700 Subject: [PATCH 01/79] creating a new node --- CMakeLists.txt | 1 + csrc/codegen.cpp | 55 +++++-- .../analysis/sync_information.cpp | 11 +- .../analysis/trivial_broadcast.cpp | 21 ++- csrc/device_lower/pass/allocation.cpp | 1 + csrc/device_lower/pass/index.cpp | 17 +++ csrc/device_lower/pass/index.h | 1 + csrc/device_lower/pass/predicate.cpp | 10 +- csrc/device_lower/utils.cpp | 1 + csrc/device_lower/validation.cpp | 13 +- csrc/dispatch.h | 1 + csrc/ir/internal_nodes.h | 39 +++++ csrc/ir/nodes.cpp | 32 ++++ csrc/ir/utils.cpp | 9 +- csrc/kernel.cpp | 4 + csrc/kernel.h | 2 + csrc/kernel_ir_dispatch.cpp | 3 + csrc/ops/arith.cpp | 64 ++++++++ csrc/ops/arith.h | 14 ++ csrc/runtime/compiled_kernel.cpp | 18 ++- csrc/scheduler/registry.cpp | 1 + runtime/block_quantization_kernels.cu | 108 ++++++++++++++ tests/cpp/test_low_precision_recipe.cpp | 140 +++++++++++++++++- 23 files changed, 528 insertions(+), 38 deletions(-) create mode 100644 runtime/block_quantization_kernels.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index 3ef2afc4d7b..f54d7c12316 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1343,6 +1343,7 @@ list(APPEND NVFUSER_RUNTIME_FILES ${NVFUSER_ROOT}/runtime/block_sync_default.cu ${NVFUSER_ROOT}/runtime/block_welford_outer.cu ${NVFUSER_ROOT}/runtime/block_layout.cu + ${NVFUSER_ROOT}/runtime/block_quantization_kernels.cu ${NVFUSER_ROOT}/runtime/broadcast.cu ${NVFUSER_ROOT}/runtime/casts.cu ${NVFUSER_ROOT}/runtime/cluster.cu diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index bf4ab5cbf42..1d77bb0a71e 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -471,9 +471,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { auto space_type = kernel_summary.largest_smem_data_type; indent() << "nvfuser_index_t block_size = " "blockDim.x*blockDim.y*blockDim.z;\n"; - indent() << space_type << " *shared_mem_var = " - << "static_cast<" << space_type << "*>(" - << "shared_mem);\n"; + indent() << space_type << " *shared_mem_var = " << "static_cast<" + << space_type << "*>(" << "shared_mem);\n"; indent() << space_type << " *shared_mem_avg = shared_mem_var + block_size;\n"; indent() << space_type @@ -1356,9 +1355,9 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { case BinaryOpType::Add: if (sop->in()->dtype() == DataType::Int) { // atomicAdd does not provide an overload for int64_t - code_ << "atomicAdd(" - << "reinterpret_cast(&" << dst << "), " - << "static_cast(" << src << "));\n"; + code_ << "atomicAdd(" << "reinterpret_cast(&" + << dst << "), " << "static_cast(" << src + << "));\n"; } else { code_ << "atomicAdd(" << "&" << dst << ", " << src << ");\n"; } @@ -1668,6 +1667,37 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { indent() << genCall("topk::blockTopK", template_args, func_args) << ";\n"; } + void handle(const BlockQuantizationOp* bqop) final { + // Get the vectorization size for items per thread + const auto input = bqop->in()->as(); + auto vectorized_input_to_reshape = input->view()->definition()->input(0); + int64_t vector_word_size = ir_utils::getVectorizeSize( + vectorized_input_to_reshape->as()); + NVF_ERROR( + vector_word_size == 4, + "Vectorization size should be 4 for " + "BlockQuantizationOp: ", + bqop->toString()); + ArgumentBuilder template_args; + template_args.arg(vector_word_size); // ITEMS_PER_THREAD + + // Function arguments + ArgumentBuilder func_args; + + // We pass the entire Tensors without any indices. 
+ // The device functions will write out values based on + // it's parallelization. + // First argument: input data array + // Second argument: quantized output + // Third argument: block scale output + func_args.arg(ir_utils::varName(bqop->input(0))); + func_args.arg(ir_utils::varName(bqop->quantizedOutput())); + func_args.arg(ir_utils::varName(bqop->blockScales())); // DataT + + indent() << genCall("bq::block_quantize_to_nvfp4", template_args, func_args) + << ";\n"; + } + void handle(const ScanOp* scan) final { NVF_ERROR(isAligned(), "Scan with divergent threads not supported"); @@ -1749,8 +1779,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { // This is slightly different from getReductionOp std::stringstream lambda; lambda << "[](const " << input->dtype() << "& a, const " << input->dtype() - << "& b) " - << "{ return " + << "& b) " << "{ return " << genBinaryOp(scan->opType(), input->dtype(), "a", "b") << "; }"; func_args.arg(lambda.str()); @@ -2089,8 +2118,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { indent() << "#pragma unroll\n"; indent() << "for (int i = 0; i < " << ldst->groupSize() << "; ++i) {\n"; indent() << kTab << genVariableName(out_ti->view()) << "[(" - << genInline(out_ti->index()) << ") + i]" - << " = " << gen(ldst->in()) << ";\n"; + << genInline(out_ti->index()) << ") + i]" << " = " + << gen(ldst->in()) << ";\n"; indent() << "}\n"; } @@ -2197,8 +2226,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { const bool has_grid_reduce = domain->hasGridReduction(); if (!has_block_reduce && !has_grid_reduce) { - indent() << "welfordCombine (" - << "\n"; + indent() << "welfordCombine (" << "\n"; indent() << kTab << gen(out_avg) << ",\n"; indent() << kTab << gen(out_var) << ",\n"; indent() << kTab << gen(out_N) << ",\n"; @@ -4132,8 +4160,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { // actual argument value like T0[i * 4 + j]. << (as_utility ? 
prefix + std::to_string(counter) : gen(register_)) - << "[" << i << "]" - << ")"; + << "[" << i << "]" << ")"; } } else { (*asm_target) << "\"" << constraint << "\"("; diff --git a/csrc/device_lower/analysis/sync_information.cpp b/csrc/device_lower/analysis/sync_information.cpp index a63312b1ab4..2914f2ba4ec 100644 --- a/csrc/device_lower/analysis/sync_information.cpp +++ b/csrc/device_lower/analysis/sync_information.cpp @@ -232,6 +232,10 @@ SyncMap::SyncMap(Fusion* fusion, bool error_on_failure) { continue; } + if (expr->isA() && producer_i == producer->nDims() - 2) { + continue; + } + producer_parallel_ids[getParallelTypeBitMapOffset(producer_ptype)] = producer_axis; } @@ -499,7 +503,10 @@ SyncMap::SyncMap(Fusion* fusion, bool error_on_failure) { if (error_on_failure) { if (raw_dims.hasBID()) { NVF_ERROR( - producer->getMemoryType() == MemoryType::Global, + producer->getMemoryType() == MemoryType::Global || + consumer->definition()->isA() || + // producer->definition()->isA() || + consumer->uses()[0]->isA(), "Inconsistent parallelization found between T", producer->name(), " (", @@ -516,6 +523,8 @@ SyncMap::SyncMap(Fusion* fusion, bool error_on_failure) { NVF_ERROR( ir_utils::isLdMatrixOp(producer->definition()) || ir_utils::isStMatrixOp(consumer->definition()) || + consumer->definition()->isA() || + producer->definition()->isA() || producer->getMemoryType() == MemoryType::Global || producer->getMemoryType() == MemoryType::Shared || producer->getMemoryType() == MemoryType::Tensor, diff --git a/csrc/device_lower/analysis/trivial_broadcast.cpp b/csrc/device_lower/analysis/trivial_broadcast.cpp index a70b2853527..0164df2f521 100644 --- a/csrc/device_lower/analysis/trivial_broadcast.cpp +++ b/csrc/device_lower/analysis/trivial_broadcast.cpp @@ -18,6 +18,14 @@ ConcretizedBroadcastDomains::ConcretizedBroadcastDomains(Fusion* fusion) { // Initialize the origin map with input broadcast domains auto inputs = fusion->inputsAndCreated(); + auto exprs_ = fusion->exprs(); + + auto bq_ops = ir_utils::filterByType(exprs_); + if (bq_ops.size() == 1) { + inputs.push_back( + static_cast(bq_ops.vector()[0]->blockScales())); + } + for (const auto fusion_input_tv : ir_utils::filterByType(inputs)) { for (auto logical_id : fusion_input_tv->getLogicalDomain()) { @@ -123,6 +131,10 @@ void ConcretizedBroadcastDomains::dispatch(Expr* expr) { for (auto consumer : ir_utils::filterByType(expr->outputs())) { auto p2c_map = PairwiseLogicalDomainMap(producer, consumer) .mapProducerToConsumer(&producer_broadcasts); + auto consumer_is_block_quantization_scales = + expr->isA() && + consumer == expr->as()->blockScales(); + for (const auto& kv : p2c_map) { auto p_id = kv.first; auto c_id = kv.second; @@ -132,7 +144,8 @@ void ConcretizedBroadcastDomains::dispatch(Expr* expr) { !c_id->isBroadcast() && !c_id->isReduction(); auto it = broadcast_origin_map_.find(p_id); NVF_ERROR( - it != broadcast_origin_map_.end(), + it != broadcast_origin_map_.end() && + !consumer_is_block_quantization_scales, "Broadcast origin info not found for producer broadcast domain: ", p_id->toString(), " of ", @@ -146,8 +159,10 @@ void ConcretizedBroadcastDomains::dispatch(Expr* expr) { } else { // Not concretized yet. Propagate forward the origin info. 
auto& consumer_origins = broadcast_origin_map_[c_id]; - for (auto origin : producer_origins) { - consumer_origins.insert(origin); + if (!consumer_is_block_quantization_scales) { + for (auto origin : producer_origins) { + consumer_origins.insert(origin); + } } consumer_origins.insert(c_id); } diff --git a/csrc/device_lower/pass/allocation.cpp b/csrc/device_lower/pass/allocation.cpp index b28bd3e3914..038873f73f5 100644 --- a/csrc/device_lower/pass/allocation.cpp +++ b/csrc/device_lower/pass/allocation.cpp @@ -251,6 +251,7 @@ class AllocationDomainSetup : private kir::IrVisitor { // aliasTensorProducer, in which case it will not be allocated. NVF_ERROR( producer_tv->isFusionInput() || + producer_tv->definition()->isA() || GpuLower::current()->getTensorProducerAlias(producer_tv) != nullptr, "Expected a fusion input or aliased tensor but found: ", diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index 8565227ff45..e3823839eeb 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -405,6 +405,23 @@ void IndexLowering::handle(const TopKOp* top) { GpuLower::current()->propagateExprInfo(top, back()); } +void IndexLowering::handle(const BlockQuantizationOp* bqop) { + const auto in = lowerSrcIndex(bqop->in(), bqop->quantizedOutput()); + + // For the two outputs, we don't really need indices. + const auto out_scales = IrBuilder::create( + static_cast(bqop->blockScales()), + IrBuilder::create(0L, DataType::Index)); + + const auto out_quantized = IrBuilder::create( + static_cast(bqop->quantizedOutput()), + IrBuilder::create(0L, DataType::Index)); + + pushBack( + IrBuilder::create(out_scales, out_quantized, in)); + GpuLower::current()->propagateExprInfo(bqop, back()); +} + void IndexLowering::handle(const SelectOp* sop) { auto lowered_index = lowerSrcIndex(sop->input(1), sop->output(0)); auto lowered_index_cast = lowered_index; diff --git a/csrc/device_lower/pass/index.h b/csrc/device_lower/pass/index.h index 070fa97a876..a7906069912 100644 --- a/csrc/device_lower/pass/index.h +++ b/csrc/device_lower/pass/index.h @@ -57,6 +57,7 @@ class IndexLowering : private OptOutConstDispatch { void handle(const ScatterOp*) final; void handle(const ArgsortOp*) final; void handle(const TopKOp*) final; + void handle(const BlockQuantizationOp*) final; void handle(const RNGOp*) final; void handle(const ReductionOp*) final; void handle(const GroupedReductionOp*) final; diff --git a/csrc/device_lower/pass/predicate.cpp b/csrc/device_lower/pass/predicate.cpp index 664a63f4f56..73667b55650 100644 --- a/csrc/device_lower/pass/predicate.cpp +++ b/csrc/device_lower/pass/predicate.cpp @@ -67,13 +67,17 @@ class ConditionalFromPredicateModifier : public kir::ExprMutator { auto vec_expr = ite->thenBody()[0]; NVF_ERROR( vec_expr->isA() || vec_expr->isA() || - vec_expr->isA() || vec_expr->isA(), + vec_expr->isA() || + vec_expr->isA() || + vec_expr->isA(), "Vectorize predicate exprs only supported on set operations."); NVF_ERROR( - ir_utils::isTvOp(vec_expr), + ir_utils::isTvOp(vec_expr) || + vec_expr->isA(), "Vectorize predicate exprs only supported on tensor view " "operations."); - if (!vec_expr->inputs()[0]->isConstScalar()) { + if (!vec_expr->inputs()[0]->isConstScalar() && + !vec_expr->isA()) { conditional = SimplifyingIrBuilder::logicalAndExpr( conditional, GpuLower::current()->info().threadPredicateMap().getPredicate( diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index 1b8105744f6..7c9e92bd4c7 100644 --- a/csrc/device_lower/utils.cpp +++ 
b/csrc/device_lower/utils.cpp @@ -126,6 +126,7 @@ bool isTvOp(const Expr* expr) { CatOp, ScanOp, PreprocessGroupedMatmulInputSf, + BlockQuantizationOp, kir::AllocTMem, kir::GridReduction, kir::GroupedGridReduction, diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index a7d1b036560..d1f2c2edff9 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -47,7 +47,10 @@ class ValidateSiblings : public IterVisitor { using IterVisitor::handle; void dispatch(Expr* expr) final { - if (!ir_utils::isTvOp(expr) || expr->outputs().size() < 2) { + // Skip BlockQuantization. + // It has sibling outputs which differ from each other + if (!ir_utils::isTvOp(expr) || expr->outputs().size() < 2 || + expr->isA()) { IterVisitor::dispatch(expr); return; } @@ -707,9 +710,12 @@ class VectorizeValidator : public OptInDispatch { } auto ldst = dynamic_cast(tv->definition()); + auto is_block_quantization_op = + dynamic_cast(tv->definition()); bool is_ldmatrix_trans = ldst != nullptr && mma_utils::isLdMatrixTranspose(ldst); - if (!is_ldmatrix_trans && name.compare("consumer") != 0) { + if (!is_ldmatrix_trans && name.compare("consumer") != 0 && + !is_block_quantization_op) { // ldmatrix.trans is a hardware transpose instruction that can do // "vectorized" read from discontiguous memory // We don't think allocation domain of consumer is used in allocation. We @@ -1016,7 +1022,8 @@ void validateAndCollectVectorizeInfo(Fusion* fusion) { (def->isA() && def->as()->serialGridReductionRequested()) || (def->isA() && - def->as()->getUnaryOpType() == UnaryOpType::Cast), + def->as()->getUnaryOpType() == UnaryOpType::Cast) || + def->isA(), "Vectorized accesses cannot be inline with computation: ", (def == nullptr ? tv->toString() : def->toString())); } diff --git a/csrc/dispatch.h b/csrc/dispatch.h index 50c9f68ac0e..e506708ec10 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -109,6 +109,7 @@ class Val; f(ScaledMmaOp); \ f(CutlassNvfp4GroupedMmaOp); \ f(PreprocessGroupedMatmulInputSf); \ + f(BlockQuantizationOp); \ f(TopKOp); \ f(ScanOp); \ f(Merge); \ diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h index 1b044f88a35..bdc9d6a23c8 100644 --- a/csrc/ir/internal_nodes.h +++ b/csrc/ir/internal_nodes.h @@ -3399,4 +3399,43 @@ class PreprocessGroupedMatmulInputSf : public Expr { } }; +class BlockQuantizationOp : public Expr { + public: + using Expr::Expr; + + BlockQuantizationOp( + IrBuilderPasskey, + Val* output_scales, + Val* output, + Val* input); + + NVFUSER_DECLARE_CLONE_AND_CREATE + + Val* blockScales() const { + return output(1); + } + + Val* quantizedOutput() const { + return output(0); + } + + Val* in() const { + return input(0); + } + + int64_t blockSize() const { + return 16; + } + + const char* getOpString() const override { + return "BlockQuantizationOp"; + } + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + std::vector evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const override; +}; + } // namespace nvfuser diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 87b4e77c232..57a280975ce 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -6298,4 +6298,36 @@ std::vector PreprocessGroupedMatmulInputSf::evaluate( NVFUSER_DEFINE_CLONE_AND_CREATE(PreprocessGroupedMatmulInputSf) +BlockQuantizationOp::BlockQuantizationOp( + IrBuilderPasskey passkey, + Val* output_scales, + Val* output, + Val* input) + : Expr(passkey) { + 
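+ // Note: the quantized tensor is registered before the scales, so output(0)
+ // is the quantized data and output(1) the block scales, matching the
+ // quantizedOutput() and blockScales() accessors declared in internal_nodes.h.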
addInput(input); + addOutput(output); + addOutput(output_scales); +} + +std::string BlockQuantizationOp::toString(int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << "(" << blockScales()->toString() << ", " + << quantizedOutput()->toString() + << ") = block_quantize(" << in()->toString() << ")\n"; + return ss.str(); +} + +std::string BlockQuantizationOp::toInlineString(int indent_size) const { + NVF_CHECK(false, "BlockQuantizationOp can not be printed inline"); +} + +std::vector BlockQuantizationOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + // This is a placeholder, currently we don't have a fallback kernel available + NVF_THROW("BlockQuantizationOp evaluation not yet implemented"); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(BlockQuantizationOp) + } // namespace nvfuser diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index cc34cba693e..ffec2e039bf 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -1339,7 +1339,7 @@ bool hasTrivialAllocationDomain(const TensorView* tv) { alloc | TensorDomain::kNoReductions | TensorDomain::kNoBroadcasts); } bool hasUniformSiblings(Expr* expr) { - return !expr->isOneOf(); + return !expr->isOneOf(); } bool isPartitionedLoop(const TensorView* tv, IterDomain* id) { @@ -1516,6 +1516,13 @@ kir::ForLoop* createRangeLoop(int64_t size) { } TensorView* getTvOutput(const Expr* expr) { + if (expr->isA()) { + // BlockQuantizationOp has multiple outputs + // but for now we only look at the quantized output + // which cleanly maps to the input. + return getTv(expr->as()->quantizedOutput()); + } + for (auto out : expr->outputs()) { if (auto tv = getTv(out)) { return tv; diff --git a/csrc/kernel.cpp b/csrc/kernel.cpp index ddcfb46ddbf..7bb5dbb9eb8 100644 --- a/csrc/kernel.cpp +++ b/csrc/kernel.cpp @@ -306,6 +306,10 @@ class KernelIrScanner : private IrVisitor { summary_.has_topk = true; } + void handle(BlockQuantizationOp* bqop) final { + summary_.has_block_quantize_op = true; + } + void handle(ScanOp* scan) final { summary_.has_scan = true; } diff --git a/csrc/kernel.h b/csrc/kernel.h index 7ceba7b669a..4d3534219cc 100644 --- a/csrc/kernel.h +++ b/csrc/kernel.h @@ -145,6 +145,8 @@ struct KernelSummary { //! Do we have any preprocess op? bool has_preprocess_grouped_matmul_input_sf = false; + bool has_block_quantize_op = false; + //! Do we have any topk op? 
bool has_topk = false; diff --git a/csrc/kernel_ir_dispatch.cpp b/csrc/kernel_ir_dispatch.cpp index 3761cc5faed..06759bbbf64 100644 --- a/csrc/kernel_ir_dispatch.cpp +++ b/csrc/kernel_ir_dispatch.cpp @@ -24,6 +24,9 @@ void IrVisitor::handle(ForLoop* fl) { scope_exprs_.push_back(fl); auto body_exprs = std::vector(fl->body().exprs()); for (auto expr : body_exprs) { + // if (expr->isA()) { + // continue; + // } dispatch(expr); } scope_exprs_.pop_back(); diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 67a3955108d..34c46b5a088 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -2634,4 +2634,68 @@ TensorView* prefixSum(TensorView* tv, int64_t dim) { /*init=*/tv->fusion()->zeroVal(tv->dtype())); } +BlockQuantizationResults block_quantize(TensorView* input) { + auto reshaped_input = reshape(input, [](auto& x) { x.split(-1, 16); }); + + auto inp_domain = + TensorDomain::noReductions(reshaped_input->getLogicalDomain()); + + // Validate input tensor is not zero-dimensional + NVF_CHECK( + !inp_domain.empty(), + "Block quantization does not support zero-dimensional tensors"); + + // Validate input data type - typically requires floating point input + NVF_CHECK( + isFloatingPointType(input->getDataType().value()), + "Block quantization expects floating point input but got ", + input->getDataType().value()); + + // Create output domain for quantized tensor (same shape as input) + std::vector quantized_out_domain; + quantized_out_domain.reserve(inp_domain.size()); + + for (auto inp_domain_ptr : inp_domain) { + quantized_out_domain.push_back(inp_domain_ptr->cloneWithoutRFactor()); + } + + // Create output domain for block scales + // Block scales typically have reduced dimensions based on block size + // For now, assuming block size of 16 and reducing the last dimension + std::vector scales_out_domain; + scales_out_domain.reserve(inp_domain.size()); + + for (size_t i = 0; i < inp_domain.size(); ++i) { + if (i == inp_domain.size() - 1) { + scales_out_domain.push_back( + IterDomainBuilder( + input->fusion()->zeroVal(), input->fusion()->oneVal()) + .iter_type(IterType::Broadcast) + .expanded_extent(IrBuilder::create(1, DataType::Index)) + .build()); + } else { + scales_out_domain.push_back(inp_domain[i]->cloneWithoutRFactor()); + } + } + + // Create output tensors + TensorView* quantized_tensor = IrBuilder::create( + IrBuilder::create( + quantized_out_domain, + TensorDomain::getContiguityFilledWith(quantized_out_domain, true)), + DataType::Float4_e2m1fn); // Quantized output using 32-bit integers + + TensorView* block_scales = IrBuilder::create( + IrBuilder::create( + scales_out_domain, + TensorDomain::getContiguityFilledWith(scales_out_domain, true)), + DataType::Float8_e4m3fn); // Scales maintain input data type + + // Create the block quantization operation + IrBuilder::create( + block_scales, quantized_tensor, reshaped_input); + + return BlockQuantizationResults(block_scales, quantized_tensor); +} + } // namespace nvfuser diff --git a/csrc/ops/arith.h b/csrc/ops/arith.h index 06d36e5591d..45c75b157ad 100644 --- a/csrc/ops/arith.h +++ b/csrc/ops/arith.h @@ -823,4 +823,18 @@ NVF_API inline TensorView* cumsum(TensorView* tv, int64_t dim) { return prefixSum(tv, dim); } +struct BlockQuantizationResults { + public: + TensorView* block_scales = nullptr; + TensorView* quantized_tensor = nullptr; + + explicit BlockQuantizationResults( + TensorView* in_block_scales, + TensorView* in_quantized_tensor) + : block_scales(in_block_scales), quantized_tensor(in_quantized_tensor) {} +}; + +//! 
Expose block size as a parameter. Currently only supports 16. +NVF_API BlockQuantizationResults block_quantize(TensorView* input); + } // namespace nvfuser diff --git a/csrc/runtime/compiled_kernel.cpp b/csrc/runtime/compiled_kernel.cpp index abfb7a72cd3..8f609ae24fe 100644 --- a/csrc/runtime/compiled_kernel.cpp +++ b/csrc/runtime/compiled_kernel.cpp @@ -59,6 +59,7 @@ #include #include #include +#include #include #include #include @@ -297,10 +298,9 @@ std::string disassembleBinary( // so I have to dump the stdin to a temp file and let nvdisasm read it. I am // hoping that nvdisasm will support reading from stdin one day. std::stringstream ss; - ss << "export PATH=$PATH:/usr/local/cuda/bin;" - << "TMPFILE=$(mktemp);" - << "cat>$TMPFILE;" - << "nvdisasm $TMPFILE " << nvdisasm_args << "; rm $TMPFILE"; + ss << "export PATH=$PATH:/usr/local/cuda/bin;" << "TMPFILE=$(mktemp);" + << "cat>$TMPFILE;" << "nvdisasm $TMPFILE " << nvdisasm_args + << "; rm $TMPFILE"; auto command = ss.str(); execl("/bin/bash", "bash", "-c", command.c_str(), NULL); @@ -1058,7 +1058,8 @@ std::string _getStructuredCode( bool has_argsort = false, bool has_topk = false, bool has_scan = false, - bool has_block_layout = false) { + bool has_block_layout = false, + bool has_block_quantize_op = false) { // generating cuda code; std::string code = ""; @@ -1098,6 +1099,10 @@ std::string _getStructuredCode( code += nvfuser_resources::block_layout_cu; } + if (has_block_quantize_op) { + code += nvfuser_resources::block_quantization_kernels_cu; + } + code += "\nnamespace " + CompiledKernel::kernelNamespace() + " {\n\n"; code += kernel_str; code += "\n} // namespace " + CompiledKernel::kernelNamespace() + "\n"; @@ -1445,7 +1450,8 @@ std::string CompiledKernel::getStructuredCode() const { kernel()->summary().has_argsort, kernel()->summary().has_topk, kernel()->summary().has_scan, - kernel()->summary().has_preprocess_grouped_matmul_input_sf); + kernel()->summary().has_preprocess_grouped_matmul_input_sf, + kernel()->summary().has_block_quantize_op); } std::string CompiledKernel::disassembledKernelSASS() const { diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp index 9f5e4460fa7..de767d13b15 100644 --- a/csrc/scheduler/registry.cpp +++ b/csrc/scheduler/registry.cpp @@ -60,6 +60,7 @@ bool checkCanSchedule(Fusion* fusion, SchedulerType scheduler_type) { // TODO: remove this once we have a scheduler for it PreprocessGroupedMatmulInputSf, TopKOp, + BlockQuantizationOp, ScanOp>(fusion)) { scheduler_debug_utils::canScheduleRejectReason( scheduler_type, "Has unsupported ops"); diff --git a/runtime/block_quantization_kernels.cu b/runtime/block_quantization_kernels.cu new file mode 100644 index 00000000000..83e2b067fbf --- /dev/null +++ b/runtime/block_quantization_kernels.cu @@ -0,0 +1,108 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +namespace nvf { +namespace bq { + +__device__ __inline__ void quadMaxReduction(float& local_max) { + // The mask 0xffffffff indicates all 32 threads in the warp are participating. + unsigned int mask = 0xffffffff; + + // --- Reduction Step 1 --- + // Exchange and compare with thread 2 lanes away within the quad. + // e.g., thread 0 exchanges with 2; thread 1 with 3. + // The XOR pattern naturally keeps the operation within each quad. 
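+ // Illustrative trace, assuming lanes 0..3 of a quad hold (a, b, c, d):
+ //   after the xor-2 exchange: lanes 0,2 hold max(a,c); lanes 1,3 hold max(b,d)
+ //   after the xor-1 exchange: all four lanes hold max(a, b, c, d)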
+ local_max = fmax(local_max, __shfl_xor_sync(mask, local_max, 2)); + + // --- Reduction Step 2 --- + // Exchange and compare with thread 1 lane away. + // e.g., thread 0 exchanges with 1; thread 2 with 3. + local_max = fmax(local_max, __shfl_xor_sync(mask, local_max, 1)); + + // At this point, all threads in a quad hold the maximum value for that quad. +} + +// TODO: Add a template parameter fnor input type. +// For now we just work on float. +// This also assumes a block of 16. That should be a +// template parameter. + +// This assumes that ITEMS_PER_THREAD is 4. +// This assumes for block quantization, the block size is 16. +// This works for float but will extended to work with bfloat. +template +__device__ void block_quantize_to_nvfp4( + Array& input, + Tensor<__e2m1, DIM, DIM>& output, + Tensor<__e4m3, DIM, DIM>& fp8_output) { + assert(blockDim.x % 4 == 0); + assert(blockDim.z == 1 && gridDim.z == 1); + static_assert( + ITEMS_PER_THREAD % 4 == 0, "ITEMS_PER_THREAD must be multiple of 4"); + + Array vec4; + vec4.set(0.0f); // Initialize to zero like nvfuser does + + for (auto i = 0; i < ITEMS_PER_THREAD; i++) { + vec4[i] = input[i]; + } + + float local_max = NEG_INFINITY; +#pragma unroll + for (int i = 0; i < 4; ++i) { + local_max = fmax(local_max, fabsf(vec4[i])); + } + + // Perform block(16 elements)-wide reduction (max) + // across 4- threads + float block_max = NEG_INFINITY; + quadMaxReduction(local_max); + block_max = local_max; + + // This division should be replaced with a multiplication + // by a reciprocal for better performance. + float scaled_max = block_max / 6.000000000e+00f; + float clamped_max = clamp( + scaled_max, 1.562500000e-02f, 4.480000000e+02f); // Clamp between 0 and 1 + + __e4m3 clamped_max_fp8 = __float2e4m3(clamped_max); + + float clamped_max_converted = __e4m32float(clamped_max_fp8); + + // Convert back from FP8 to float using __e4m32float + if (threadIdx.x % 4 == 0) // Only one thread per quad writes + { + int offset_per_cta = (blockDim.x / 4) * blockIdx.x; + int quad_id = threadIdx.x / 4; + fp8_output[offset_per_cta + quad_id] = clamped_max_fp8; + } + + Array clamped_vals; +#pragma unroll + for (int i = 0; i < 4; ++i) { + float scaled_val = vec4[i] / clamped_max_converted; + clamped_vals[i] = clamp(scaled_val, -6.000000000e+00f, 6.000000000e+00f); + } + + Array<__e2m1, 4, 1> fp4_vals; + *reinterpret_cast*>(&fp4_vals[0]) = + __float2e2m1(*reinterpret_cast*>(&clamped_vals[0])); + + Array<__e2m1, 4, 4> fp4_vals_aligned; +#pragma unroll + for (int i = 0; i < 4; ++i) { + fp4_vals_aligned[i] = fp4_vals[i]; + } + + int total_offset = (blockIdx.x * blockDim.x + threadIdx.x) * ITEMS_PER_THREAD; + loadLocalToGlobal<__e2m1, /*vec_size=*/4, /*is_volatile=*/false>( + &output[total_offset], &fp4_vals_aligned.array[0]); +} + +} // namespace bq +} // namespace nvf diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index 851a1c1551d..deefadf5720 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -10,6 +10,8 @@ #include #include +#include +#include #include #include @@ -107,13 +109,8 @@ constexpr double F8E4M3_MAX = 448.0; class NVFP4QuantizeTest : public BlackwellBase, public ::testing::WithParamInterface {}; - -TEST_P(NVFP4QuantizeTest, WithoutPerTensorAmax) { - auto data_hp_dtype = GetParam(); - - std::unique_ptr fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - +namespace { +void createNVFP4QunatizationFusion(Fusion* fusion, DataType data_hp_dtype) { auto 
tv_data_hp = makeContigTensor(2, data_hp_dtype); fusion->addInput(tv_data_hp); @@ -148,6 +145,16 @@ TEST_P(NVFP4QuantizeTest, WithoutPerTensorAmax) { fusion->addOutput(tv_block_scale_fp8); fusion->addOutput(tv_data_lp); +} +} // namespace + +TEST_P(NVFP4QuantizeTest, WithoutPerTensorAmax) { + auto data_hp_dtype = GetParam(); + + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + createNVFP4QunatizationFusion(fusion.get(), data_hp_dtype); FusionExecutorCache fec(std::move(fusion)); @@ -168,6 +175,125 @@ TEST_P(NVFP4QuantizeTest, WithoutPerTensorAmax) { HeuristicIs(SchedulerType::InnerPersistent))); } +class BQTest : public BlackwellBase {}; + +TEST_F(BQTest, ScheduleAsPointwise) { + // Basic test implementation + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + createNVFP4QunatizationFusion(fusion.get(), DataType::Float); + + FusionExecutorCache fec(std::move(fusion)); + + const int m = 1024; + const int n = 1024; + std::vector inputs; + inputs.push_back(at::randn({m, n}, at::device(at::kCUDA).dtype(at::kFloat))); + auto outputs_baseline = fec.runFusionWithInputs(inputs); + + // Print baseline outputs + auto baseline_block_scales = outputs_baseline[0].as(); + auto baseline_quantized_tensor = outputs_baseline[1].as(); + + // Move baseline tensors from GPU to CPU + auto baseline_block_scales_cpu = baseline_block_scales.cpu(); + auto baseline_quantized_tensor_cpu = baseline_quantized_tensor.cpu(); + + // Print first 32 bytes of baseline block_scales_output in hex format + const uint8_t* baseline_block_scales_data = + static_cast(baseline_block_scales_cpu.data_ptr()); + const uint8_t* baseline_quantized_data = + static_cast(baseline_quantized_tensor_cpu.data_ptr()); + + std::unique_ptr fusion_new_op = std::make_unique(); + FusionGuard fg2(fusion_new_op.get()); + + auto tv_data_hp = makeContigTensor(2, DataType::Float); + fusion_new_op->addInput(tv_data_hp); + + // t0 is 2D + auto t0 = set(tv_data_hp); + auto quantization_results = block_quantize(t0); + + // t1 and t2 are 2D. + fusion_new_op->addOutput(quantization_results.block_scales); + fusion_new_op->addOutput(quantization_results.quantized_tensor); + + t0->setMemoryType(MemoryType::Local); + + // This is the 3D input to the BQ Op. + auto view_out_tv = quantization_results.block_scales->definition() + ->input(0) + ->as(); + + for (auto t : + {tv_data_hp, + t0, + view_out_tv, + quantization_results.quantized_tensor, + quantization_results.block_scales}) { + // Merge all dims. + t->merge(-2); + if (t->getLoopDomain().size() >= 2) { + t->merge(-2); + } + + // split by 4. 
+ // I -> I/4, 4 + t->split(-1, 4); + // I//4, 4 -> I/4, 1, 4 + t->split(-2, 1); + // I//4, 1, 4 -> I/512, 128, 1, 4 + t->split(-3, 128); + + if (t != tv_data_hp) { + // Don't vectorize the outputs of reshape + if (t != view_out_tv && t != quantization_results.block_scales) { + t->axis(-1)->parallelize(ParallelType::Vectorize); + } + // I/512(BIDx), 128(TIDx), 1, 4(v) + t->axis(-3)->parallelize(ParallelType::TIDx); + t->axis(-4)->parallelize(ParallelType::BIDx); + } + } + + // Execute the fusion + KernelExecutor ke; + ke.compile(fusion_new_op.get(), inputs); + auto outputs_new_op = ke.run(inputs); + + // Verify we got the expected outputs + auto block_scales_output = outputs_new_op[0].as(); + auto quantized_tensor_output = outputs_new_op[1].as(); + + // Move tensors from GPU to CPU + auto block_scales_cpu = block_scales_output.cpu(); + auto quantized_tensor_cpu = quantized_tensor_output.cpu(); + + auto block_scales_bytes = (m * n) / block_size; + auto quantized_tensor_bytes = (m * n) / 2; + + const uint8_t* block_scales_data = + static_cast(block_scales_cpu.data_ptr()); + for (int i = 0; i < block_scales_bytes; ++i) { + EXPECT_EQ( + block_scales_data[i], + baseline_block_scales_data[i]); // Compare with baseline + } + + const uint8_t* quantized_data = + static_cast(quantized_tensor_cpu.data_ptr()); + for (int i = 0; i < quantized_tensor_bytes; ++i) { + EXPECT_EQ( + quantized_data[i], + baseline_quantized_data[i]); // Compare with baseline + } + + // Basic shape checks + EXPECT_EQ(block_scales_output.dim(), 3); + EXPECT_EQ(quantized_tensor_output.dim(), 3); +} + TEST_P(NVFP4QuantizeTest, SwizzledOuputAndWithoutPerTensorAmax) { auto data_hp_dtype = GetParam(); From 9c06080ef64d888e5e33701fb2151cf6e599f889 Mon Sep 17 00:00:00 2001 From: protonu Date: Mon, 29 Sep 2025 12:48:02 -0700 Subject: [PATCH 02/79] removing commented out code --- csrc/device_lower/analysis/sync_information.cpp | 1 - csrc/kernel_ir_dispatch.cpp | 3 --- 2 files changed, 4 deletions(-) diff --git a/csrc/device_lower/analysis/sync_information.cpp b/csrc/device_lower/analysis/sync_information.cpp index 2914f2ba4ec..988226ee90b 100644 --- a/csrc/device_lower/analysis/sync_information.cpp +++ b/csrc/device_lower/analysis/sync_information.cpp @@ -505,7 +505,6 @@ SyncMap::SyncMap(Fusion* fusion, bool error_on_failure) { NVF_ERROR( producer->getMemoryType() == MemoryType::Global || consumer->definition()->isA() || - // producer->definition()->isA() || consumer->uses()[0]->isA(), "Inconsistent parallelization found between T", producer->name(), diff --git a/csrc/kernel_ir_dispatch.cpp b/csrc/kernel_ir_dispatch.cpp index 06759bbbf64..3761cc5faed 100644 --- a/csrc/kernel_ir_dispatch.cpp +++ b/csrc/kernel_ir_dispatch.cpp @@ -24,9 +24,6 @@ void IrVisitor::handle(ForLoop* fl) { scope_exprs_.push_back(fl); auto body_exprs = std::vector(fl->body().exprs()); for (auto expr : body_exprs) { - // if (expr->isA()) { - // continue; - // } dispatch(expr); } scope_exprs_.pop_back(); From fd79ac1ef25808b41579890470f1295f66fad9f9 Mon Sep 17 00:00:00 2001 From: protonu Date: Tue, 30 Sep 2025 10:24:01 -0700 Subject: [PATCH 03/79] codegen the indices for the outputs --- csrc/codegen.cpp | 4 ++-- csrc/device_lower/pass/index.cpp | 10 ++-------- runtime/block_quantization_kernels.cu | 13 +++++-------- 3 files changed, 9 insertions(+), 18 deletions(-) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 1d77bb0a71e..48c51c4dcd0 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -1691,8 +1691,8 @@ class CudaKernelGenerator : private 
kir::ConstIrVisitor { // Second argument: quantized output // Third argument: block scale output func_args.arg(ir_utils::varName(bqop->input(0))); - func_args.arg(ir_utils::varName(bqop->quantizedOutput())); - func_args.arg(ir_utils::varName(bqop->blockScales())); // DataT + func_args.arg(genInline(bqop->quantizedOutput())); + func_args.arg(genInline(bqop->blockScales())); indent() << genCall("bq::block_quantize_to_nvfp4", template_args, func_args) << ";\n"; diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index e3823839eeb..0985b9ddaad 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -408,14 +408,8 @@ void IndexLowering::handle(const TopKOp* top) { void IndexLowering::handle(const BlockQuantizationOp* bqop) { const auto in = lowerSrcIndex(bqop->in(), bqop->quantizedOutput()); - // For the two outputs, we don't really need indices. - const auto out_scales = IrBuilder::create( - static_cast(bqop->blockScales()), - IrBuilder::create(0L, DataType::Index)); - - const auto out_quantized = IrBuilder::create( - static_cast(bqop->quantizedOutput()), - IrBuilder::create(0L, DataType::Index)); + const auto out_scales = lowerDstIndex(bqop->blockScales()); + const auto out_quantized = lowerDstIndex(bqop->quantizedOutput()); pushBack( IrBuilder::create(out_scales, out_quantized, in)); diff --git a/runtime/block_quantization_kernels.cu b/runtime/block_quantization_kernels.cu index 83e2b067fbf..001df88f34e 100644 --- a/runtime/block_quantization_kernels.cu +++ b/runtime/block_quantization_kernels.cu @@ -35,11 +35,11 @@ __device__ __inline__ void quadMaxReduction(float& local_max) { // This assumes that ITEMS_PER_THREAD is 4. // This assumes for block quantization, the block size is 16. // This works for float but will extended to work with bfloat. 
-template +template __device__ void block_quantize_to_nvfp4( Array& input, - Tensor<__e2m1, DIM, DIM>& output, - Tensor<__e4m3, DIM, DIM>& fp8_output) { + __e2m1& output, + __e4m3& fp8_output) { assert(blockDim.x % 4 == 0); assert(blockDim.z == 1 && gridDim.z == 1); static_assert( @@ -77,9 +77,7 @@ __device__ void block_quantize_to_nvfp4( // Convert back from FP8 to float using __e4m32float if (threadIdx.x % 4 == 0) // Only one thread per quad writes { - int offset_per_cta = (blockDim.x / 4) * blockIdx.x; - int quad_id = threadIdx.x / 4; - fp8_output[offset_per_cta + quad_id] = clamped_max_fp8; + fp8_output = clamped_max_fp8; // Broadcast to all threads } Array clamped_vals; @@ -99,9 +97,8 @@ __device__ void block_quantize_to_nvfp4( fp4_vals_aligned[i] = fp4_vals[i]; } - int total_offset = (blockIdx.x * blockDim.x + threadIdx.x) * ITEMS_PER_THREAD; loadLocalToGlobal<__e2m1, /*vec_size=*/4, /*is_volatile=*/false>( - &output[total_offset], &fp4_vals_aligned.array[0]); + &output, &fp4_vals_aligned.array[0]); } } // namespace bq From 60af605bbdd366c6d4e1cd0d369de6845ff34d47 Mon Sep 17 00:00:00 2001 From: protonu Date: Tue, 30 Sep 2025 12:29:25 -0700 Subject: [PATCH 04/79] adding a new test for 2D sched --- tests/cpp/test_low_precision_recipe.cpp | 117 ++++++++++++++++++++++++ 1 file changed, 117 insertions(+) diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index deefadf5720..c79420c84a7 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -294,6 +294,123 @@ TEST_F(BQTest, ScheduleAsPointwise) { EXPECT_EQ(quantized_tensor_output.dim(), 3); } +TEST_F(BQTest, ScheduleAsPointwise2D) { + // Basic test implementation + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + createNVFP4QunatizationFusion(fusion.get(), DataType::Float); + + FusionExecutorCache fec(std::move(fusion)); + + const int m = 1024; + const int n = 1024; + std::vector inputs; + inputs.push_back(at::randn({m, n}, at::device(at::kCUDA).dtype(at::kFloat))); + auto outputs_baseline = fec.runFusionWithInputs(inputs); + + // Print baseline outputs + auto baseline_block_scales = outputs_baseline[0].as(); + auto baseline_quantized_tensor = outputs_baseline[1].as(); + + // Move baseline tensors from GPU to CPU + auto baseline_block_scales_cpu = baseline_block_scales.cpu(); + auto baseline_quantized_tensor_cpu = baseline_quantized_tensor.cpu(); + + // Print first 32 bytes of baseline block_scales_output in hex format + const uint8_t* baseline_block_scales_data = + static_cast(baseline_block_scales_cpu.data_ptr()); + const uint8_t* baseline_quantized_data = + static_cast(baseline_quantized_tensor_cpu.data_ptr()); + + std::unique_ptr fusion_new_op = std::make_unique(); + FusionGuard fg2(fusion_new_op.get()); + + auto tv_data_hp = makeContigTensor(2, DataType::Float); + fusion_new_op->addInput(tv_data_hp); + + // t0 is 2D + auto t0 = set(tv_data_hp); + auto quantization_results = block_quantize(t0); + + // t1 and t2 are 2D. + fusion_new_op->addOutput(quantization_results.block_scales); + fusion_new_op->addOutput(quantization_results.quantized_tensor); + + t0->setMemoryType(MemoryType::Local); + + // This is the 3D input to the BQ Op. 
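+ // The quantize op consumes a 3D reshape of the 2D input
+ // ((m, n) -> (m, n/16, 16)) that is created internally when the op is built;
+ // view_out_tv below is that reshape output, so it can be scheduled together
+ // with the other tensors.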
+ auto view_out_tv = quantization_results.block_scales->definition() + ->input(0) + ->as(); + + for (auto t : {tv_data_hp, t0}) { + t->split(-1, block_size); + } + + for (auto t : + {tv_data_hp, + t0, + view_out_tv, + quantization_results.quantized_tensor, + quantization_results.block_scales}) { + t->merge(1, 2); + + t->split(-1, 4); // V + t->split(-2, 32); // BDx + + t->split(0, 1); + t->split(0, 4); // BDy + + if (t != tv_data_hp) { + // Don't vectorize the outputs of reshape + if (t != view_out_tv && t != quantization_results.block_scales) { + t->axis(-1)->parallelize(ParallelType::Vectorize); + } + // I/512(BIDx), 128(TIDx), 1, 4(v) + t->axis(-2)->parallelize(ParallelType::TIDx); + t->axis(-3)->parallelize(ParallelType::BIDx); + t->axis(-5)->parallelize(ParallelType::TIDy); + t->axis(-6)->parallelize(ParallelType::BIDy); + } + } + + // Execute the fusion + KernelExecutor ke; + ke.compile(fusion_new_op.get(), inputs); + auto outputs_new_op = ke.run(inputs); + + // Verify we got the expected outputs + auto block_scales_output = outputs_new_op[0].as(); + auto quantized_tensor_output = outputs_new_op[1].as(); + + // Move tensors from GPU to CPU + auto block_scales_cpu = block_scales_output.cpu(); + auto quantized_tensor_cpu = quantized_tensor_output.cpu(); + + auto block_scales_bytes = (m * n) / block_size; + auto quantized_tensor_bytes = (m * n) / 2; + + const uint8_t* block_scales_data = + static_cast(block_scales_cpu.data_ptr()); + for (int i = 0; i < block_scales_bytes; ++i) { + EXPECT_EQ( + block_scales_data[i], + baseline_block_scales_data[i]); // Compare with baseline + } + + const uint8_t* quantized_data = + static_cast(quantized_tensor_cpu.data_ptr()); + for (int i = 0; i < quantized_tensor_bytes; ++i) { + EXPECT_EQ( + quantized_data[i], + baseline_quantized_data[i]); // Compare with baseline + } + + // Basic shape checks + EXPECT_EQ(block_scales_output.dim(), 3); + EXPECT_EQ(quantized_tensor_output.dim(), 3); +} + TEST_P(NVFP4QuantizeTest, SwizzledOuputAndWithoutPerTensorAmax) { auto data_hp_dtype = GetParam(); From 73295a1a2ffb3ade181486f31ed5a4568bb2d165 Mon Sep 17 00:00:00 2001 From: protonu Date: Tue, 30 Sep 2025 17:16:34 -0700 Subject: [PATCH 05/79] write quantized output to regs --- csrc/codegen.cpp | 4 ++-- runtime/block_quantization_kernels.cu | 11 ++++------- tests/cpp/test_low_precision_recipe.cpp | 13 +++++++++---- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 48c51c4dcd0..91a074d049e 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -1686,12 +1686,12 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { // We pass the entire Tensors without any indices. // The device functions will write out values based on - // it's parallelization. + // its parallelization. 
// First argument: input data array // Second argument: quantized output // Third argument: block scale output func_args.arg(ir_utils::varName(bqop->input(0))); - func_args.arg(genInline(bqop->quantizedOutput())); + func_args.arg(ir_utils::varName(bqop->quantizedOutput())); func_args.arg(genInline(bqop->blockScales())); indent() << genCall("bq::block_quantize_to_nvfp4", template_args, func_args) diff --git a/runtime/block_quantization_kernels.cu b/runtime/block_quantization_kernels.cu index 001df88f34e..8787d78ed2c 100644 --- a/runtime/block_quantization_kernels.cu +++ b/runtime/block_quantization_kernels.cu @@ -27,7 +27,7 @@ __device__ __inline__ void quadMaxReduction(float& local_max) { // At this point, all threads in a quad hold the maximum value for that quad. } -// TODO: Add a template parameter fnor input type. +// TODO: Add a template parameter for input type. // For now we just work on float. // This also assumes a block of 16. That should be a // template parameter. @@ -38,7 +38,7 @@ __device__ __inline__ void quadMaxReduction(float& local_max) { template __device__ void block_quantize_to_nvfp4( Array& input, - __e2m1& output, + Array<__e2m1, ITEMS_PER_THREAD, ITEMS_PER_THREAD>& output, __e4m3& fp8_output) { assert(blockDim.x % 4 == 0); assert(blockDim.z == 1 && gridDim.z == 1); @@ -91,14 +91,11 @@ __device__ void block_quantize_to_nvfp4( *reinterpret_cast*>(&fp4_vals[0]) = __float2e2m1(*reinterpret_cast*>(&clamped_vals[0])); - Array<__e2m1, 4, 4> fp4_vals_aligned; + // Array<__e2m1, 4, 4> fp4_vals_aligned; #pragma unroll for (int i = 0; i < 4; ++i) { - fp4_vals_aligned[i] = fp4_vals[i]; + output[i] = fp4_vals[i]; } - - loadLocalToGlobal<__e2m1, /*vec_size=*/4, /*is_volatile=*/false>( - &output, &fp4_vals_aligned.array[0]); } } // namespace bq diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index c79420c84a7..593a8a22697 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -214,10 +214,11 @@ TEST_F(BQTest, ScheduleAsPointwise) { // t0 is 2D auto t0 = set(tv_data_hp); auto quantization_results = block_quantize(t0); + auto t_out = set(quantization_results.quantized_tensor); // t1 and t2 are 2D. fusion_new_op->addOutput(quantization_results.block_scales); - fusion_new_op->addOutput(quantization_results.quantized_tensor); + fusion_new_op->addOutput(t_out); t0->setMemoryType(MemoryType::Local); @@ -231,7 +232,8 @@ TEST_F(BQTest, ScheduleAsPointwise) { t0, view_out_tv, quantization_results.quantized_tensor, - quantization_results.block_scales}) { + quantization_results.block_scales, + t_out}) { // Merge all dims. t->merge(-2); if (t->getLoopDomain().size() >= 2) { @@ -331,10 +333,11 @@ TEST_F(BQTest, ScheduleAsPointwise2D) { // t0 is 2D auto t0 = set(tv_data_hp); auto quantization_results = block_quantize(t0); + auto t_out = set(quantization_results.quantized_tensor); // t1 and t2 are 2D. fusion_new_op->addOutput(quantization_results.block_scales); - fusion_new_op->addOutput(quantization_results.quantized_tensor); + fusion_new_op->addOutput(t_out); t0->setMemoryType(MemoryType::Local); @@ -343,6 +346,7 @@ TEST_F(BQTest, ScheduleAsPointwise2D) { ->input(0) ->as(); + // split the intput 2D tensor to make then 3D. 
for (auto t : {tv_data_hp, t0}) { t->split(-1, block_size); } @@ -352,7 +356,8 @@ TEST_F(BQTest, ScheduleAsPointwise2D) { t0, view_out_tv, quantization_results.quantized_tensor, - quantization_results.block_scales}) { + quantization_results.block_scales, + t_out}) { t->merge(1, 2); t->split(-1, 4); // V From f97264af06174976e055bff05d0d2083dca646a1 Mon Sep 17 00:00:00 2001 From: protonu Date: Tue, 30 Sep 2025 17:56:51 -0700 Subject: [PATCH 06/79] clean up --- csrc/device_lower/pass/allocation.cpp | 1 - csrc/device_lower/pass/predicate.cpp | 8 ++++---- csrc/device_lower/validation.cpp | 11 ++++++----- csrc/ir/utils.cpp | 7 ------- 4 files changed, 10 insertions(+), 17 deletions(-) diff --git a/csrc/device_lower/pass/allocation.cpp b/csrc/device_lower/pass/allocation.cpp index 038873f73f5..b28bd3e3914 100644 --- a/csrc/device_lower/pass/allocation.cpp +++ b/csrc/device_lower/pass/allocation.cpp @@ -251,7 +251,6 @@ class AllocationDomainSetup : private kir::IrVisitor { // aliasTensorProducer, in which case it will not be allocated. NVF_ERROR( producer_tv->isFusionInput() || - producer_tv->definition()->isA() || GpuLower::current()->getTensorProducerAlias(producer_tv) != nullptr, "Expected a fusion input or aliased tensor but found: ", diff --git a/csrc/device_lower/pass/predicate.cpp b/csrc/device_lower/pass/predicate.cpp index 73667b55650..d2cc8189c31 100644 --- a/csrc/device_lower/pass/predicate.cpp +++ b/csrc/device_lower/pass/predicate.cpp @@ -69,15 +69,15 @@ class ConditionalFromPredicateModifier : public kir::ExprMutator { vec_expr->isA() || vec_expr->isA() || vec_expr->isA() || vec_expr->isA() || + // To supress the throw. + // I think this is predicated on the vectorized dim. vec_expr->isA(), "Vectorize predicate exprs only supported on set operations."); NVF_ERROR( - ir_utils::isTvOp(vec_expr) || - vec_expr->isA(), + ir_utils::isTvOp(vec_expr), "Vectorize predicate exprs only supported on tensor view " "operations."); - if (!vec_expr->inputs()[0]->isConstScalar() && - !vec_expr->isA()) { + if (!vec_expr->inputs()[0]->isConstScalar()) { conditional = SimplifyingIrBuilder::logicalAndExpr( conditional, GpuLower::current()->info().threadPredicateMap().getPredicate( diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index d1f2c2edff9..f1674abff68 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -50,7 +50,7 @@ class ValidateSiblings : public IterVisitor { // Skip BlockQuantization. // It has sibling outputs which differ from each other if (!ir_utils::isTvOp(expr) || expr->outputs().size() < 2 || - expr->isA()) { + !ir_utils::hasUniformSiblings(expr)) { IterVisitor::dispatch(expr); return; } @@ -710,12 +710,10 @@ class VectorizeValidator : public OptInDispatch { } auto ldst = dynamic_cast(tv->definition()); - auto is_block_quantization_op = - dynamic_cast(tv->definition()); + bool is_ldmatrix_trans = ldst != nullptr && mma_utils::isLdMatrixTranspose(ldst); - if (!is_ldmatrix_trans && name.compare("consumer") != 0 && - !is_block_quantization_op) { + if (!is_ldmatrix_trans && name.compare("consumer") != 0) { // ldmatrix.trans is a hardware transpose instruction that can do // "vectorized" read from discontiguous memory // We don't think allocation domain of consumer is used in allocation. We @@ -1023,6 +1021,9 @@ void validateAndCollectVectorizeInfo(Fusion* fusion) { def->as()->serialGridReductionRequested()) || (def->isA() && def->as()->getUnaryOpType() == UnaryOpType::Cast) || + // This throws without this check. 
+ // Maybe I shouldn't vectorize the outputs of the + // BlockQuantizationOp def->isA(), "Vectorized accesses cannot be inline with computation: ", (def == nullptr ? tv->toString() : def->toString())); diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index ffec2e039bf..ec254316382 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -1516,13 +1516,6 @@ kir::ForLoop* createRangeLoop(int64_t size) { } TensorView* getTvOutput(const Expr* expr) { - if (expr->isA()) { - // BlockQuantizationOp has multiple outputs - // but for now we only look at the quantized output - // which cleanly maps to the input. - return getTv(expr->as()->quantizedOutput()); - } - for (auto out : expr->outputs()) { if (auto tv = getTv(out)) { return tv; From a587641a062eb25b5ce54f44d8e7524668abac25 Mon Sep 17 00:00:00 2001 From: protonu Date: Tue, 30 Sep 2025 18:39:00 -0700 Subject: [PATCH 07/79] clean up --- csrc/device_lower/analysis/sync_information.cpp | 9 +++------ csrc/ops/arith.cpp | 3 +-- csrc/ops/arith.h | 2 +- tests/cpp/test_low_precision_recipe.cpp | 4 ++-- 4 files changed, 7 insertions(+), 11 deletions(-) diff --git a/csrc/device_lower/analysis/sync_information.cpp b/csrc/device_lower/analysis/sync_information.cpp index 988226ee90b..e5316a704b4 100644 --- a/csrc/device_lower/analysis/sync_information.cpp +++ b/csrc/device_lower/analysis/sync_information.cpp @@ -232,10 +232,6 @@ SyncMap::SyncMap(Fusion* fusion, bool error_on_failure) { continue; } - if (expr->isA() && producer_i == producer->nDims() - 2) { - continue; - } - producer_parallel_ids[getParallelTypeBitMapOffset(producer_ptype)] = producer_axis; } @@ -503,6 +499,9 @@ SyncMap::SyncMap(Fusion* fusion, bool error_on_failure) { if (error_on_failure) { if (raw_dims.hasBID()) { NVF_ERROR( + // We need to allow the output of block quantization and the + // outputs of the reshape preceding the block + // quantization to be in local memory. producer->getMemoryType() == MemoryType::Global || consumer->definition()->isA() || consumer->uses()[0]->isA(), @@ -522,8 +521,6 @@ SyncMap::SyncMap(Fusion* fusion, bool error_on_failure) { NVF_ERROR( ir_utils::isLdMatrixOp(producer->definition()) || ir_utils::isStMatrixOp(consumer->definition()) || - consumer->definition()->isA() || - producer->definition()->isA() || producer->getMemoryType() == MemoryType::Global || producer->getMemoryType() == MemoryType::Shared || producer->getMemoryType() == MemoryType::Tensor, diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 34c46b5a088..59bee9befa0 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -2634,7 +2634,7 @@ TensorView* prefixSum(TensorView* tv, int64_t dim) { /*init=*/tv->fusion()->zeroVal(tv->dtype())); } -BlockQuantizationResults block_quantize(TensorView* input) { +BlockQuantizationResults blockQuantize(TensorView* input) { auto reshaped_input = reshape(input, [](auto& x) { x.split(-1, 16); }); auto inp_domain = @@ -2671,7 +2671,6 @@ BlockQuantizationResults block_quantize(TensorView* input) { IterDomainBuilder( input->fusion()->zeroVal(), input->fusion()->oneVal()) .iter_type(IterType::Broadcast) - .expanded_extent(IrBuilder::create(1, DataType::Index)) .build()); } else { scales_out_domain.push_back(inp_domain[i]->cloneWithoutRFactor()); diff --git a/csrc/ops/arith.h b/csrc/ops/arith.h index 45c75b157ad..a74d14a26fe 100644 --- a/csrc/ops/arith.h +++ b/csrc/ops/arith.h @@ -835,6 +835,6 @@ struct BlockQuantizationResults { }; //! Expose block size as a parameter. Currently only supports 16. 
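+//! Rough usage sketch (see tests/cpp/test_low_precision_recipe.cpp):
+//!   auto results = blockQuantize(tv);  // tv: floating point TensorView*
+//!   results.quantized_tensor           // Float4_e2m1fn quantized data
+//!   results.block_scales               // Float8_e4m3fn per-block scales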
-NVF_API BlockQuantizationResults block_quantize(TensorView* input); +NVF_API BlockQuantizationResults blockQuantize(TensorView* input); } // namespace nvfuser diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index 593a8a22697..4fbcb520213 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -213,7 +213,7 @@ TEST_F(BQTest, ScheduleAsPointwise) { // t0 is 2D auto t0 = set(tv_data_hp); - auto quantization_results = block_quantize(t0); + auto quantization_results = blockQuantize(t0); auto t_out = set(quantization_results.quantized_tensor); // t1 and t2 are 2D. @@ -332,7 +332,7 @@ TEST_F(BQTest, ScheduleAsPointwise2D) { // t0 is 2D auto t0 = set(tv_data_hp); - auto quantization_results = block_quantize(t0); + auto quantization_results = blockQuantize(t0); auto t_out = set(quantization_results.quantized_tensor); // t1 and t2 are 2D. From a44a6fcfe007055dd78d7c7b7fff0117904607fa Mon Sep 17 00:00:00 2001 From: protonu Date: Wed, 1 Oct 2025 09:51:11 -0700 Subject: [PATCH 08/79] clean up trivial broadcast --- .../analysis/trivial_broadcast.cpp | 28 +++++++++---------- .../device_lower/analysis/trivial_broadcast.h | 2 ++ 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/csrc/device_lower/analysis/trivial_broadcast.cpp b/csrc/device_lower/analysis/trivial_broadcast.cpp index 0164df2f521..7afb62ad2c7 100644 --- a/csrc/device_lower/analysis/trivial_broadcast.cpp +++ b/csrc/device_lower/analysis/trivial_broadcast.cpp @@ -20,12 +20,6 @@ ConcretizedBroadcastDomains::ConcretizedBroadcastDomains(Fusion* fusion) { auto inputs = fusion->inputsAndCreated(); auto exprs_ = fusion->exprs(); - auto bq_ops = ir_utils::filterByType(exprs_); - if (bq_ops.size() == 1) { - inputs.push_back( - static_cast(bq_ops.vector()[0]->blockScales())); - } - for (const auto fusion_input_tv : ir_utils::filterByType(inputs)) { for (auto logical_id : fusion_input_tv->getLogicalDomain()) { @@ -110,6 +104,16 @@ void ConcretizedBroadcastDomains::handle(TopKOp* top) { } } +// BlockQuantizationOp introduces broadcast domains in the block scales output +void ConcretizedBroadcastDomains::handle(BlockQuantizationOp* bq) { + auto out = bq->blockScales()->as(); + auto bcast_id = out->getLogicalDomain().back(); + if (bcast_id->isBroadcast()) { + broadcast_origin_map_.emplace( + bcast_id, std::unordered_set({bcast_id})); + } +} + void ConcretizedBroadcastDomains::dispatch(Expr* expr) { IterVisitor::dispatch(expr); @@ -131,9 +135,6 @@ void ConcretizedBroadcastDomains::dispatch(Expr* expr) { for (auto consumer : ir_utils::filterByType(expr->outputs())) { auto p2c_map = PairwiseLogicalDomainMap(producer, consumer) .mapProducerToConsumer(&producer_broadcasts); - auto consumer_is_block_quantization_scales = - expr->isA() && - consumer == expr->as()->blockScales(); for (const auto& kv : p2c_map) { auto p_id = kv.first; @@ -144,8 +145,7 @@ void ConcretizedBroadcastDomains::dispatch(Expr* expr) { !c_id->isBroadcast() && !c_id->isReduction(); auto it = broadcast_origin_map_.find(p_id); NVF_ERROR( - it != broadcast_origin_map_.end() && - !consumer_is_block_quantization_scales, + it != broadcast_origin_map_.end(), "Broadcast origin info not found for producer broadcast domain: ", p_id->toString(), " of ", @@ -159,10 +159,8 @@ void ConcretizedBroadcastDomains::dispatch(Expr* expr) { } else { // Not concretized yet. Propagate forward the origin info. 
auto& consumer_origins = broadcast_origin_map_[c_id]; - if (!consumer_is_block_quantization_scales) { - for (auto origin : producer_origins) { - consumer_origins.insert(origin); - } + for (auto origin : producer_origins) { + consumer_origins.insert(origin); } consumer_origins.insert(c_id); } diff --git a/csrc/device_lower/analysis/trivial_broadcast.h b/csrc/device_lower/analysis/trivial_broadcast.h index 4015737484c..b002d73e976 100644 --- a/csrc/device_lower/analysis/trivial_broadcast.h +++ b/csrc/device_lower/analysis/trivial_broadcast.h @@ -51,6 +51,8 @@ class NVF_API ConcretizedBroadcastDomains : private IterVisitor { void handle(TopKOp* top) final; + void handle(BlockQuantizationOp* bq) final; + void dispatch(Expr* expr) final; void markAsConcretized( From 727b663328497eeccdf8799eb5a413bc745b1297 Mon Sep 17 00:00:00 2001 From: protonu Date: Wed, 1 Oct 2025 10:04:02 -0700 Subject: [PATCH 09/79] clean up the tests --- tests/cpp/test_low_precision_recipe.cpp | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index 4fbcb520213..e38e2ef179e 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -216,7 +216,6 @@ TEST_F(BQTest, ScheduleAsPointwise) { auto quantization_results = blockQuantize(t0); auto t_out = set(quantization_results.quantized_tensor); - // t1 and t2 are 2D. fusion_new_op->addOutput(quantization_results.block_scales); fusion_new_op->addOutput(t_out); @@ -278,17 +277,13 @@ TEST_F(BQTest, ScheduleAsPointwise) { const uint8_t* block_scales_data = static_cast(block_scales_cpu.data_ptr()); for (int i = 0; i < block_scales_bytes; ++i) { - EXPECT_EQ( - block_scales_data[i], - baseline_block_scales_data[i]); // Compare with baseline + EXPECT_EQ(block_scales_data[i], baseline_block_scales_data[i]); } const uint8_t* quantized_data = static_cast(quantized_tensor_cpu.data_ptr()); for (int i = 0; i < quantized_tensor_bytes; ++i) { - EXPECT_EQ( - quantized_data[i], - baseline_quantized_data[i]); // Compare with baseline + EXPECT_EQ(quantized_data[i], baseline_quantized_data[i]); } // Basic shape checks @@ -318,7 +313,6 @@ TEST_F(BQTest, ScheduleAsPointwise2D) { auto baseline_block_scales_cpu = baseline_block_scales.cpu(); auto baseline_quantized_tensor_cpu = baseline_quantized_tensor.cpu(); - // Print first 32 bytes of baseline block_scales_output in hex format const uint8_t* baseline_block_scales_data = static_cast(baseline_block_scales_cpu.data_ptr()); const uint8_t* baseline_quantized_data = @@ -335,7 +329,7 @@ TEST_F(BQTest, ScheduleAsPointwise2D) { auto quantization_results = blockQuantize(t0); auto t_out = set(quantization_results.quantized_tensor); - // t1 and t2 are 2D. + // outputs are 3D fusion_new_op->addOutput(quantization_results.block_scales); fusion_new_op->addOutput(t_out); @@ -346,7 +340,8 @@ TEST_F(BQTest, ScheduleAsPointwise2D) { ->input(0) ->as(); - // split the intput 2D tensor to make then 3D. + // split the input 2D tensor to make then 3D. 
+ // (i0, i1) -> (i0, i1//block_size, block_size) for (auto t : {tv_data_hp, t0}) { t->split(-1, block_size); } @@ -358,20 +353,25 @@ TEST_F(BQTest, ScheduleAsPointwise2D) { quantization_results.quantized_tensor, quantization_results.block_scales, t_out}) { + // (m, n, k) -> (m, n*k) t->merge(1, 2); + // (m, n*k) -> (m, n*k/4, 4) + // (m, n*k/4, 4) -> (m, n*k/128, 32, 4) t->split(-1, 4); // V t->split(-2, 32); // BDx + // (m, n*k/128, 32, 4) -> (m, 1, n*k/128, 32, 4) + // (m, 1, n*k/128, 32, 4) -> (m/4, 4, 1, n*k/128, 32, 4) t->split(0, 1); - t->split(0, 4); // BDy + t->split(0, 4); + // (m/4(bidy), 4(tidy), 1, n*k/128(bidx), 32(tidx), 49(v)) if (t != tv_data_hp) { // Don't vectorize the outputs of reshape if (t != view_out_tv && t != quantization_results.block_scales) { t->axis(-1)->parallelize(ParallelType::Vectorize); } - // I/512(BIDx), 128(TIDx), 1, 4(v) t->axis(-2)->parallelize(ParallelType::TIDx); t->axis(-3)->parallelize(ParallelType::BIDx); t->axis(-5)->parallelize(ParallelType::TIDy); From e8b0adeb3752ed2d910e587f352285fe5a0cf3ba Mon Sep 17 00:00:00 2001 From: protonu Date: Wed, 1 Oct 2025 10:08:56 -0700 Subject: [PATCH 10/79] minor cleanup --- csrc/device_lower/analysis/trivial_broadcast.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/csrc/device_lower/analysis/trivial_broadcast.cpp b/csrc/device_lower/analysis/trivial_broadcast.cpp index 7afb62ad2c7..6cb2055eb21 100644 --- a/csrc/device_lower/analysis/trivial_broadcast.cpp +++ b/csrc/device_lower/analysis/trivial_broadcast.cpp @@ -18,8 +18,6 @@ ConcretizedBroadcastDomains::ConcretizedBroadcastDomains(Fusion* fusion) { // Initialize the origin map with input broadcast domains auto inputs = fusion->inputsAndCreated(); - auto exprs_ = fusion->exprs(); - for (const auto fusion_input_tv : ir_utils::filterByType(inputs)) { for (auto logical_id : fusion_input_tv->getLogicalDomain()) { From 55bb2c376cd1d033097ba59ec498c1c7fc58b612 Mon Sep 17 00:00:00 2001 From: protonu Date: Tue, 7 Oct 2025 08:59:43 -0700 Subject: [PATCH 11/79] address reviewer comments --- csrc/codegen.cpp | 106 +++++++++++++++--------- csrc/ir/internal_nodes.h | 20 ++++- csrc/ir/nodes.cpp | 16 ++-- csrc/ops/arith.cpp | 54 +++++++++--- csrc/ops/arith.h | 15 ++-- csrc/scheduler/registry.cpp | 1 - tests/cpp/test_low_precision_recipe.cpp | 2 - 7 files changed, 147 insertions(+), 67 deletions(-) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 91a074d049e..b78cc22bd39 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -471,8 +471,9 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { auto space_type = kernel_summary.largest_smem_data_type; indent() << "nvfuser_index_t block_size = " "blockDim.x*blockDim.y*blockDim.z;\n"; - indent() << space_type << " *shared_mem_var = " << "static_cast<" - << space_type << "*>(" << "shared_mem);\n"; + indent() << space_type << " *shared_mem_var = " + << "static_cast<" << space_type << "*>(" + << "shared_mem);\n"; indent() << space_type << " *shared_mem_avg = shared_mem_var + block_size;\n"; indent() << space_type @@ -1355,9 +1356,9 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { case BinaryOpType::Add: if (sop->in()->dtype() == DataType::Int) { // atomicAdd does not provide an overload for int64_t - code_ << "atomicAdd(" << "reinterpret_cast(&" - << dst << "), " << "static_cast(" << src - << "));\n"; + code_ << "atomicAdd(" + << "reinterpret_cast(&" << dst << "), " + << "static_cast(" << src << "));\n"; } else { code_ << "atomicAdd(" << "&" << dst << ", " << src << ");\n"; } @@ -1667,37 +1668,6 @@ 
class CudaKernelGenerator : private kir::ConstIrVisitor { indent() << genCall("topk::blockTopK", template_args, func_args) << ";\n"; } - void handle(const BlockQuantizationOp* bqop) final { - // Get the vectorization size for items per thread - const auto input = bqop->in()->as(); - auto vectorized_input_to_reshape = input->view()->definition()->input(0); - int64_t vector_word_size = ir_utils::getVectorizeSize( - vectorized_input_to_reshape->as()); - NVF_ERROR( - vector_word_size == 4, - "Vectorization size should be 4 for " - "BlockQuantizationOp: ", - bqop->toString()); - ArgumentBuilder template_args; - template_args.arg(vector_word_size); // ITEMS_PER_THREAD - - // Function arguments - ArgumentBuilder func_args; - - // We pass the entire Tensors without any indices. - // The device functions will write out values based on - // its parallelization. - // First argument: input data array - // Second argument: quantized output - // Third argument: block scale output - func_args.arg(ir_utils::varName(bqop->input(0))); - func_args.arg(ir_utils::varName(bqop->quantizedOutput())); - func_args.arg(genInline(bqop->blockScales())); - - indent() << genCall("bq::block_quantize_to_nvfp4", template_args, func_args) - << ";\n"; - } - void handle(const ScanOp* scan) final { NVF_ERROR(isAligned(), "Scan with divergent threads not supported"); @@ -1779,7 +1749,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { // This is slightly different from getReductionOp std::stringstream lambda; lambda << "[](const " << input->dtype() << "& a, const " << input->dtype() - << "& b) " << "{ return " + << "& b) " + << "{ return " << genBinaryOp(scan->opType(), input->dtype(), "a", "b") << "; }"; func_args.arg(lambda.str()); @@ -1789,6 +1760,57 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { indent() << genCall("scan::blockScan", template_args, func_args) << ";\n"; } + // Special handling of BlockQuantizationOp to call the runtime function. + // TODO: add support for global scaling factor + // TODO: make sure we can handle BF16 input + void handle(const BlockQuantizationOp* bqop) final { + // Block quantization takes in a reshaped input + // so a (m, k) input becomes (m, k/16, 16) + // The 16 is the block size and is "reduced". + // For FP32 input each thread locally reduces 4 elements and a quad of + // thread reduces 16. For BF16 input each thread locally reduces 8 elements + // and a pair of thread reduces 16. The device function for block + // quantization. Due to these assumptions we have to have the inner + // dimension of the bqop input (and outputs) vectorized by 4 or 8 for the + // current device function to work. 
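A small host-side sketch of the element-to-thread mapping the comment above assumes; the lane ranges and the writer lane are illustrative, the actual indexing lives in the runtime kernel later in this series:

#include <cstdio>

// Each 16-element quantization block is covered by consecutive threads along
// x: 4 FP32 threads handling 4 elements each, or 2 BF16 threads handling 8
// elements each. The first lane of each group writes the block's e4m3 scale.
int main() {
  for (int items_per_thread : {4, 8}) {
    const int threads_per_scale = 16 / items_per_thread;
    std::printf("ITEMS_PER_THREAD=%d -> %d threads per block scale\n",
                items_per_thread, threads_per_scale);
    for (int lane = 0; lane < threads_per_scale; ++lane) {
      std::printf("  lane %d covers elements [%d, %d)%s\n",
                  lane,
                  lane * items_per_thread,
                  (lane + 1) * items_per_thread,
                  lane == 0 ? "  <- writes the scale" : "");
    }
  }
  return 0;
}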
+ auto output = bqop->quantizedOutput()->as()->view(); + int64_t vector_word_size = ir_utils::getVectorizeSize(output); + + auto input_dtype = + bqop->in()->as()->view()->getDataType(); + + if (input_dtype == DataType::BFloat16) { + NVF_ERROR( + vector_word_size == 8, + "Vectorization size should be 8 for " + "BlockQuantizationOp: ", + bqop->toString()); + + } else { + NVF_ERROR( + vector_word_size == 4, + "Vectorization size should be 4 for " + "BlockQuantizationOp: ", + bqop->toString()); + } + + ArgumentBuilder template_args; + template_args.arg(vector_word_size); // ITEMS_PER_THREAD + + // Function arguments + ArgumentBuilder func_args; + + // First argument: input data array + // Second argument: quantized output + // Third argument: block scale output + func_args.arg(genInline(bqop->input(0)->as()->view())); + func_args.arg(genInline(output)); + func_args.arg(genInline(bqop->blockScales())); + + indent() << genCall("bq::block_quantize_to_nvfp4", template_args, func_args) + << ";\n"; + } + std::string genReductionOp(BinaryOpType op_type, DataType data_type) { std::stringstream lambda; lambda << "[](" << data_type << " &a, " << data_type << " b) " @@ -2118,8 +2140,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { indent() << "#pragma unroll\n"; indent() << "for (int i = 0; i < " << ldst->groupSize() << "; ++i) {\n"; indent() << kTab << genVariableName(out_ti->view()) << "[(" - << genInline(out_ti->index()) << ") + i]" << " = " - << gen(ldst->in()) << ";\n"; + << genInline(out_ti->index()) << ") + i]" + << " = " << gen(ldst->in()) << ";\n"; indent() << "}\n"; } @@ -2226,7 +2248,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { const bool has_grid_reduce = domain->hasGridReduction(); if (!has_block_reduce && !has_grid_reduce) { - indent() << "welfordCombine (" << "\n"; + indent() << "welfordCombine (" + << "\n"; indent() << kTab << gen(out_avg) << ",\n"; indent() << kTab << gen(out_var) << ",\n"; indent() << kTab << gen(out_N) << ",\n"; @@ -4160,7 +4183,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { // actual argument value like T0[i * 4 + j]. << (as_utility ? 
prefix + std::to_string(counter) : gen(register_)) - << "[" << i << "]" << ")"; + << "[" << i << "]" + << ")"; } } else { (*asm_target) << "\"" << constraint << "\"("; diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h index bdc9d6a23c8..2a5a7e4da52 100644 --- a/csrc/ir/internal_nodes.h +++ b/csrc/ir/internal_nodes.h @@ -3407,7 +3407,9 @@ class BlockQuantizationOp : public Expr { IrBuilderPasskey, Val* output_scales, Val* output, - Val* input); + Val* input, + Val* global_scale = nullptr, + int64_t block_size = 16); NVFUSER_DECLARE_CLONE_AND_CREATE @@ -3424,7 +3426,21 @@ class BlockQuantizationOp : public Expr { } int64_t blockSize() const { - return 16; + return attribute(0); + } + + bool hasGlobalScale() const { + if (inputs().size() > 1) { + return true; + } + return false; + } + + Val* globalScale() const { + if (hasGlobalScale()) { + return input(1); + } + return nullptr; } const char* getOpString() const override { diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 57a280975ce..7641db28788 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -6302,18 +6302,24 @@ BlockQuantizationOp::BlockQuantizationOp( IrBuilderPasskey passkey, Val* output_scales, Val* output, - Val* input) + Val* input, + Val* global_scale, + int64_t block_size) : Expr(passkey) { - addInput(input); addOutput(output); addOutput(output_scales); + addInput(input); + if (global_scale) { + addInput(global_scale); + } + addDataAttribute(block_size); } std::string BlockQuantizationOp::toString(int indent_size) const { std::stringstream ss; - indent(ss, indent_size) << "(" << blockScales()->toString() << ", " - << quantizedOutput()->toString() - << ") = block_quantize(" << in()->toString() << ")\n"; + indent(ss, indent_size) << "(" << blockScales()->toString() << ",\n " + << quantizedOutput()->toString() << ")\n" + << " = block_quantize(" << in()->toString() << ")\n"; return ss.str(); } diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 59bee9befa0..a74d8992c46 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -2634,7 +2634,45 @@ TensorView* prefixSum(TensorView* tv, int64_t dim) { /*init=*/tv->fusion()->zeroVal(tv->dtype())); } -BlockQuantizationResults blockQuantize(TensorView* input) { +// API for block quantization to nvFP4. +// We take FP32 or BF16 input and produce two outputs +// nvFP4(x2) outputs and FP8 block scales. +// The input is first reshape to have an inner 16 dimension. +// This 16 should be configurable but fixed for now. +// So, if the input is of shape (m, k), it is first reshaped +// to (m, k/16, 16). The quantized output has the shape (m, k/16, 16) +// and block scales has the shape (m, k/16, b(1)), where the inner dimension +// had been "reduced" (max). Please note there is no actual reduction and this +// node should be handled by the pointwise scheduler. +// Currently this node get lowered to a runtime function, which expects the +// input in registers and write out quantized values or registers and block +// scales to global memory. 
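A rough CPU sketch of the per-block arithmetic this comment describes, assuming the usual nvfp4 recipe (scale = amax / 6.0, with 6.0 the largest e2m1 magnitude); the runtime kernel later in this series additionally round-trips the scale through e4m3 before dividing, and uses the hardware cast to e2m1, both of which are simplified here to a nearest-value snap:

#include <algorithm>
#include <array>
#include <cmath>
#include <cstdio>

// Snap a value to the closest representable e2m1 magnitude, keeping the sign.
float nearest_e2m1(float x) {
  static const std::array<float, 8> levels = {
      0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
  float best = levels[0];
  for (float v : levels) {
    if (std::fabs(std::fabs(x) - v) < std::fabs(std::fabs(x) - best)) {
      best = v;
    }
  }
  return std::copysign(best, x);
}

int main() {
  // One 16-element quantization block (non-zero so the scale is well defined).
  std::array<float, 16> block{};
  for (int i = 0; i < 16; ++i) {
    block[i] = 0.3f * static_cast<float>(i - 8);
  }

  float amax = 0.0f;
  for (float v : block) {
    amax = std::fmax(amax, std::fabs(v));
  }
  float scale = amax / 6.0f;  // stored as the block's e4m3 scaling factor

  for (float v : block) {
    float scaled = std::clamp(v / scale, -6.0f, 6.0f);
    std::printf("%7.3f -> %4.1f (block scale %.4f)\n",
                v, nearest_e2m1(scaled), scale);
  }
  return 0;
}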
+BlockQuantizationResults blockQuantize( + TensorView* input, + int64_t block_size, + DataType out_dtype) { + NVF_CHECK( + block_size == 16, + "Currently only block size of 16 is supported, got ", + block_size); + + NVF_CHECK( + out_dtype == DataType::Float4_e2m1fn_x2, + "Currently only output data type of Float4_e2m1fn_x2 is supported"); + + // Validate input data type + // WE'll only support FP32 or BF16 + // TODO: BF16 + NVF_CHECK( + isFloatingPointType(input->getDataType().value()), + "Block quantization expects floating point input but got ", + input->getDataType().value()); + + // We reshape the input to the keep the block size as the inner dimension + // that will be "reduced". For example, if our input in [m, k] and the block + // size is 16. Then we need to compute the max of the inner-most 16 elements. + // Thus, we first reshape the input to [m, k/16, 16] and then compute the max + // over the inner-most 16 elements. auto reshaped_input = reshape(input, [](auto& x) { x.split(-1, 16); }); auto inp_domain = @@ -2645,12 +2683,6 @@ BlockQuantizationResults blockQuantize(TensorView* input) { !inp_domain.empty(), "Block quantization does not support zero-dimensional tensors"); - // Validate input data type - typically requires floating point input - NVF_CHECK( - isFloatingPointType(input->getDataType().value()), - "Block quantization expects floating point input but got ", - input->getDataType().value()); - // Create output domain for quantized tensor (same shape as input) std::vector quantized_out_domain; quantized_out_domain.reserve(inp_domain.size()); @@ -2660,8 +2692,10 @@ BlockQuantizationResults blockQuantize(TensorView* input) { } // Create output domain for block scales - // Block scales typically have reduced dimensions based on block size - // For now, assuming block size of 16 and reducing the last dimension + // If the input after reshape is [m, k/16, 16] then + // block scales will be [m, k/16, b(1)] + // We keep the inner-dimension as b(1) to make scheduling easier. + // Both the ouputs of quantization can be scheduled in similar fashion. std::vector scales_out_domain; scales_out_domain.reserve(inp_domain.size()); @@ -2694,7 +2728,7 @@ BlockQuantizationResults blockQuantize(TensorView* input) { IrBuilder::create( block_scales, quantized_tensor, reshaped_input); - return BlockQuantizationResults(block_scales, quantized_tensor); + return BlockQuantizationResults(quantized_tensor, block_scales); } } // namespace nvfuser diff --git a/csrc/ops/arith.h b/csrc/ops/arith.h index a74d14a26fe..6d27674d199 100644 --- a/csrc/ops/arith.h +++ b/csrc/ops/arith.h @@ -825,16 +825,19 @@ NVF_API inline TensorView* cumsum(TensorView* tv, int64_t dim) { struct BlockQuantizationResults { public: - TensorView* block_scales = nullptr; TensorView* quantized_tensor = nullptr; + TensorView* block_scales = nullptr; explicit BlockQuantizationResults( - TensorView* in_block_scales, - TensorView* in_quantized_tensor) - : block_scales(in_block_scales), quantized_tensor(in_quantized_tensor) {} + TensorView* in_quantized_tensor, + TensorView* in_block_scales) + : quantized_tensor(in_quantized_tensor), block_scales(in_block_scales) {} }; -//! Expose block size as a parameter. Currently only supports 16. -NVF_API BlockQuantizationResults blockQuantize(TensorView* input); +//! 
TODO: Expose global scaling factor +NVF_API BlockQuantizationResults blockQuantize( + TensorView* input, + int64_t block_size = 16, + DataType out_dtype = DataType::Float4_e2m1fn_x2); } // namespace nvfuser diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp index de767d13b15..9f5e4460fa7 100644 --- a/csrc/scheduler/registry.cpp +++ b/csrc/scheduler/registry.cpp @@ -60,7 +60,6 @@ bool checkCanSchedule(Fusion* fusion, SchedulerType scheduler_type) { // TODO: remove this once we have a scheduler for it PreprocessGroupedMatmulInputSf, TopKOp, - BlockQuantizationOp, ScanOp>(fusion)) { scheduler_debug_utils::canScheduleRejectReason( scheduler_type, "Has unsupported ops"); diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index e38e2ef179e..4f3925dc89e 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -219,8 +219,6 @@ TEST_F(BQTest, ScheduleAsPointwise) { fusion_new_op->addOutput(quantization_results.block_scales); fusion_new_op->addOutput(t_out); - t0->setMemoryType(MemoryType::Local); - // This is the 3D input to the BQ Op. auto view_out_tv = quantization_results.block_scales->definition() ->input(0) From 69c315e08d77fdb2c92a0329b3aea2142bfdb4e6 Mon Sep 17 00:00:00 2001 From: protonu Date: Tue, 7 Oct 2025 16:52:49 -0700 Subject: [PATCH 12/79] reviewer comment --- csrc/codegen.cpp | 3 +- .../analysis/non_divisible_split.cpp | 4 +- csrc/device_lower/pass/index.cpp | 14 ++-- csrc/logical_domain_map.cpp | 11 +++ csrc/ops/arith.cpp | 22 ++++-- runtime/block_quantization_kernels.cu | 69 +++++++++++++------ tests/cpp/test_low_precision_recipe.cpp | 49 +++---------- 7 files changed, 100 insertions(+), 72 deletions(-) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index b78cc22bd39..cd2c1320db1 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -1805,7 +1805,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { // Third argument: block scale output func_args.arg(genInline(bqop->input(0)->as()->view())); func_args.arg(genInline(output)); - func_args.arg(genInline(bqop->blockScales())); + func_args.arg( + genInline(bqop->blockScales()->as()->view())); indent() << genCall("bq::block_quantize_to_nvfp4", template_args, func_args) << ";\n"; diff --git a/csrc/device_lower/analysis/non_divisible_split.cpp b/csrc/device_lower/analysis/non_divisible_split.cpp index fcf0144fda5..5ca07d0f57e 100644 --- a/csrc/device_lower/analysis/non_divisible_split.cpp +++ b/csrc/device_lower/analysis/non_divisible_split.cpp @@ -212,7 +212,9 @@ NonDivisiblePredicateInfo::NonDivisiblePredicateInfo(Fusion* fusion) { } auto def = tv->definition(); - if (def == nullptr) { + // A block quantization is plumbed down to a runtime function. + // The runtime function handles predication so skip this. 
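For reference, a minimal fusion definition against the blockQuantize() declaration above, mirroring the call sites in tests/cpp/test_low_precision_recipe.cpp; it assumes the same headers and helpers as that test file (makeContigTensor, set, FusionGuard), so it is a sketch that only builds inside this repository rather than a standalone program:

// Sketch of the frontend usage; names and helpers follow the tests.
void defineBlockQuantizeFusion(Fusion* fusion) {
  FusionGuard fg(fusion);

  // [m, k] high-precision input; FP32 here, BF16 is the other supported type.
  TensorView* tv_in = makeContigTensor(2, DataType::Float);
  fusion->addInput(tv_in);

  // Defaults: block_size = 16, out_dtype = Float4_e2m1fn_x2.
  BlockQuantizationResults results = blockQuantize(set(tv_in));

  // Two outputs: the nvFP4 data and one e4m3 scaling factor per block.
  fusion->addOutput(set(results.quantized_tensor));
  fusion->addOutput(results.block_scales);
}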
+ if (def == nullptr || def->isA()) { continue; } diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index 0985b9ddaad..3353fb445c2 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -406,10 +406,16 @@ void IndexLowering::handle(const TopKOp* top) { } void IndexLowering::handle(const BlockQuantizationOp* bqop) { - const auto in = lowerSrcIndex(bqop->in(), bqop->quantizedOutput()); - - const auto out_scales = lowerDstIndex(bqop->blockScales()); - const auto out_quantized = lowerDstIndex(bqop->quantizedOutput()); + // const auto in = lowerSrcIndex(bqop->in(), bqop->quantizedOutput()); + const auto in = IrBuilder::create( + bqop->in()->as(), bqop->fusion()->zeroVal()); + + const auto out_scales = IrBuilder::create( + bqop->blockScales()->as(), + bqop->fusion()->zeroVal()); // lowerDstIndex(bqop->blockScales()); + const auto out_quantized = IrBuilder::create( + bqop->quantizedOutput()->as(), + bqop->fusion()->zeroVal()); // lowerDstIndex(bqop->quantizedOutput()); pushBack( IrBuilder::create(out_scales, out_quantized, in)); diff --git a/csrc/logical_domain_map.cpp b/csrc/logical_domain_map.cpp index 7850d214fa1..d1e45f8d42b 100644 --- a/csrc/logical_domain_map.cpp +++ b/csrc/logical_domain_map.cpp @@ -154,6 +154,17 @@ std::pair, bool> getNonMappingDomainInfo( } has_consumer_id = true; } + } else if ( + auto bqop = + dynamic_cast(consumer_tv->definition())) { + if (producer_tv == bqop->in()) { + auto producer_logical = + TensorDomain::noReductions(producer_tv->getLogicalDomain()); + auto last_logical_dim = producer_tv->getLogicalDomain().size() - 1; + non_mapping_ids.insert(producer_logical.at(last_logical_dim)); + // We are mapping everything but the last ID. + has_consumer_id = true; + } } return std::make_pair(non_mapping_ids, has_consumer_id); diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index a74d8992c46..e3a929de0d7 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -2673,10 +2673,11 @@ BlockQuantizationResults blockQuantize( // size is 16. Then we need to compute the max of the inner-most 16 elements. // Thus, we first reshape the input to [m, k/16, 16] and then compute the max // over the inner-most 16 elements. 
- auto reshaped_input = reshape(input, [](auto& x) { x.split(-1, 16); }); + // auto reshaped_input = reshape(input, [](auto& x) { x.split(-1, 16); }); auto inp_domain = - TensorDomain::noReductions(reshaped_input->getLogicalDomain()); + // TensorDomain::noReductions(reshaped_input->getLogicalDomain()); + TensorDomain::noReductions(input->getLogicalDomain()); // Validate input tensor is not zero-dimensional NVF_CHECK( @@ -2701,11 +2702,21 @@ BlockQuantizationResults blockQuantize( for (size_t i = 0; i < inp_domain.size(); ++i) { if (i == inp_domain.size() - 1) { + // scales_out_domain.push_back( + // IterDomainBuilder( + // input->fusion()->zeroVal(), input->fusion()->oneVal()) + // .iter_type(IterType::Broadcast) + // .build()); + + // Close inp_domain[i] and divide by 16 to create new iter domain scales_out_domain.push_back( IterDomainBuilder( - input->fusion()->zeroVal(), input->fusion()->oneVal()) - .iter_type(IterType::Broadcast) + inp_domain[i]->start(), + SimplifyingIrBuilder::divExpr( + inp_domain[i]->extent(), + IrBuilder::create(16L, DataType::Index))) .build()); + } else { scales_out_domain.push_back(inp_domain[i]->cloneWithoutRFactor()); } @@ -2725,8 +2736,7 @@ BlockQuantizationResults blockQuantize( DataType::Float8_e4m3fn); // Scales maintain input data type // Create the block quantization operation - IrBuilder::create( - block_scales, quantized_tensor, reshaped_input); + IrBuilder::create(block_scales, quantized_tensor, input); return BlockQuantizationResults(quantized_tensor, block_scales); } diff --git a/runtime/block_quantization_kernels.cu b/runtime/block_quantization_kernels.cu index 8787d78ed2c..e3de211fb45 100644 --- a/runtime/block_quantization_kernels.cu +++ b/runtime/block_quantization_kernels.cu @@ -9,7 +9,8 @@ namespace nvf { namespace bq { -__device__ __inline__ void quadMaxReduction(float& local_max) { +template +__device__ __inline__ void localMaxReduction(float& local_max) { // The mask 0xffffffff indicates all 32 threads in the warp are participating. unsigned int mask = 0xffffffff; @@ -17,7 +18,9 @@ __device__ __inline__ void quadMaxReduction(float& local_max) { // Exchange and compare with thread 2 lanes away within the quad. // e.g., thread 0 exchanges with 2; thread 1 with 3. // The XOR pattern naturally keeps the operation within each quad. - local_max = fmax(local_max, __shfl_xor_sync(mask, local_max, 2)); + if (std::is_same::value) { + local_max = fmax(local_max, __shfl_xor_sync(mask, local_max, 2)); + } // --- Reduction Step 2 --- // Exchange and compare with thread 1 lane away. @@ -35,33 +38,47 @@ __device__ __inline__ void quadMaxReduction(float& local_max) { // This assumes that ITEMS_PER_THREAD is 4. // This assumes for block quantization, the block size is 16. // This works for float but will extended to work with bfloat. 
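The shuffle-based reduction in localMaxReduction above can be modeled on the CPU to see why every lane of a quad ends up holding the quad's maximum; a minimal sketch of the XOR butterfly for the FP32 path (offsets 2 then 1; the BF16 path only uses offset 1 across a pair of lanes):

#include <array>
#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
  // Per-lane |x| maxima for one quad (4 lanes x 4 elements = one 16-wide block).
  std::array<float, 4> lane_max = {0.5f, 3.0f, 2.25f, 1.0f};

  // Each step mirrors __shfl_xor_sync: lane i reads lane (i ^ offset).
  for (int offset : {2, 1}) {
    std::array<float, 4> exchanged = lane_max;
    for (int lane = 0; lane < 4; ++lane) {
      exchanged[lane] = std::fmax(lane_max[lane], lane_max[lane ^ offset]);
    }
    lane_max = exchanged;
  }

  for (int lane = 0; lane < 4; ++lane) {
    assert(lane_max[lane] == 3.0f);  // every lane now holds the quad max
    std::printf("lane %d -> %.2f\n", lane, lane_max[lane]);
  }
  return 0;
}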
-template +template < + int ITEMS_PER_THREAD, + typename T, + int ALIGNMENT_1, + int ALIGNMENT_2, + int BLOCK_SCALE_DIM, + int BLOCK_SCALE_ALLOC> __device__ void block_quantize_to_nvfp4( - Array& input, - Array<__e2m1, ITEMS_PER_THREAD, ITEMS_PER_THREAD>& output, - __e4m3& fp8_output) { + Array& input, + Array<__e2m1, ITEMS_PER_THREAD, ALIGNMENT_2>& output, + Tensor<__e4m3, BLOCK_SCALE_DIM, BLOCK_SCALE_ALLOC>& fp8_output) { assert(blockDim.x % 4 == 0); assert(blockDim.z == 1 && gridDim.z == 1); static_assert( ITEMS_PER_THREAD % 4 == 0, "ITEMS_PER_THREAD must be multiple of 4"); - Array vec4; - vec4.set(0.0f); // Initialize to zero like nvfuser does + Array vec_in; + vec_in.set(0.0f); // Initialize to zero like nvfuser does for (auto i = 0; i < ITEMS_PER_THREAD; i++) { - vec4[i] = input[i]; + if constexpr (std::is_same::value) { + vec_in[i] = input[i]; + } else if constexpr (std::is_same::value) { + vec_in[i] = __bfloat2float(input[i]); + } else { + static_assert( + std::is_same::value || std::is_same::value, + "Unsupported type"); + } } float local_max = NEG_INFINITY; #pragma unroll - for (int i = 0; i < 4; ++i) { - local_max = fmax(local_max, fabsf(vec4[i])); + for (int i = 0; i < ITEMS_PER_THREAD; ++i) { + local_max = fmax(local_max, fabsf(vec_in[i])); } // Perform block(16 elements)-wide reduction (max) // across 4- threads float block_max = NEG_INFINITY; - quadMaxReduction(local_max); + localMaxReduction(local_max); block_max = local_max; // This division should be replaced with a multiplication @@ -74,26 +91,34 @@ __device__ void block_quantize_to_nvfp4( float clamped_max_converted = __e4m32float(clamped_max_fp8); + int offset_y_blocks = blockIdx.y * blockDim.y * blockDim.x * gridDim.x; + int offset_dim_y = threadIdx.y * blockDim.x * gridDim.x; + int offset_into_block = blockIdx.x * blockDim.x + threadIdx.x; + + int offset = (offset_y_blocks + offset_dim_y + offset_into_block) / 4; + // Convert back from FP8 to float using __e4m32float - if (threadIdx.x % 4 == 0) // Only one thread per quad writes + if (threadIdx.x % ITEMS_PER_THREAD == 0) // Only one thread per quad writes { - fp8_output = clamped_max_fp8; // Broadcast to all threads + fp8_output[offset] = clamped_max_fp8; // Broadcast to all threads } - Array clamped_vals; + Array clamped_vals; #pragma unroll - for (int i = 0; i < 4; ++i) { - float scaled_val = vec4[i] / clamped_max_converted; + for (int i = 0; i < ITEMS_PER_THREAD; ++i) { + float scaled_val = vec_in[i] / clamped_max_converted; clamped_vals[i] = clamp(scaled_val, -6.000000000e+00f, 6.000000000e+00f); } - Array<__e2m1, 4, 1> fp4_vals; - *reinterpret_cast*>(&fp4_vals[0]) = - __float2e2m1(*reinterpret_cast*>(&clamped_vals[0])); + Array<__e2m1, ITEMS_PER_THREAD, 1> fp4_vals; + *reinterpret_cast*>( + &fp4_vals[0]) = + __float2e2m1( + *reinterpret_cast*>( + &clamped_vals[0])); - // Array<__e2m1, 4, 4> fp4_vals_aligned; #pragma unroll - for (int i = 0; i < 4; ++i) { + for (int i = 0; i < ITEMS_PER_THREAD; ++i) { output[i] = fp4_vals[i]; } } diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index 4f3925dc89e..4f371c26181 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -219,15 +219,9 @@ TEST_F(BQTest, ScheduleAsPointwise) { fusion_new_op->addOutput(quantization_results.block_scales); fusion_new_op->addOutput(t_out); - // This is the 3D input to the BQ Op. 
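The offset arithmetic added to block_quantize_to_nvfp4 above flattens (blockIdx.y, threadIdx.y, blockIdx.x, threadIdx.x) into one block-scale slot per 16-element block; a small CPU check of that identity under an assumed launch shape:

#include <cassert>
#include <cstdio>

int main() {
  const int kThreadsPerScale = 4;  // 16 / ITEMS_PER_THREAD for FP32 input
  const int gridDimX = 2, gridDimY = 3, blockDimX = 32, blockDimY = 4;

  for (int bidy = 0; bidy < gridDimY; ++bidy) {
    for (int tidy = 0; tidy < blockDimY; ++tidy) {
      for (int bidx = 0; bidx < gridDimX; ++bidx) {
        for (int tidx = 0; tidx < blockDimX; ++tidx) {
          // Same three partial offsets as the kernel.
          int offset_y_blocks = bidy * blockDimY * blockDimX * gridDimX;
          int offset_dim_y = tidy * blockDimX * gridDimX;
          int offset_into_block = bidx * blockDimX + tidx;
          int offset = (offset_y_blocks + offset_dim_y + offset_into_block) /
              kThreadsPerScale;

          // The partial offsets sum to the x-major linear thread index, so
          // consecutive quads land in consecutive fp8_output slots.
          int linear_tid =
              ((bidy * blockDimY + tidy) * gridDimX + bidx) * blockDimX + tidx;
          assert(offset == linear_tid / kThreadsPerScale);
        }
      }
    }
  }
  std::printf("offset == linear_tid / %d for every thread\n", kThreadsPerScale);
  return 0;
}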
- auto view_out_tv = quantization_results.block_scales->definition() - ->input(0) - ->as(); - for (auto t : {tv_data_hp, t0, - view_out_tv, quantization_results.quantized_tensor, quantization_results.block_scales, t_out}) { @@ -246,11 +240,7 @@ TEST_F(BQTest, ScheduleAsPointwise) { t->split(-3, 128); if (t != tv_data_hp) { - // Don't vectorize the outputs of reshape - if (t != view_out_tv && t != quantization_results.block_scales) { - t->axis(-1)->parallelize(ParallelType::Vectorize); - } - // I/512(BIDx), 128(TIDx), 1, 4(v) + t->axis(-1)->parallelize(ParallelType::Vectorize); t->axis(-3)->parallelize(ParallelType::TIDx); t->axis(-4)->parallelize(ParallelType::BIDx); } @@ -285,8 +275,8 @@ TEST_F(BQTest, ScheduleAsPointwise) { } // Basic shape checks - EXPECT_EQ(block_scales_output.dim(), 3); - EXPECT_EQ(quantized_tensor_output.dim(), 3); + EXPECT_EQ(block_scales_output.dim(), 2); + EXPECT_EQ(quantized_tensor_output.dim(), 2); } TEST_F(BQTest, ScheduleAsPointwise2D) { @@ -333,43 +323,26 @@ TEST_F(BQTest, ScheduleAsPointwise2D) { t0->setMemoryType(MemoryType::Local); - // This is the 3D input to the BQ Op. - auto view_out_tv = quantization_results.block_scales->definition() - ->input(0) - ->as(); - - // split the input 2D tensor to make then 3D. - // (i0, i1) -> (i0, i1//block_size, block_size) - for (auto t : {tv_data_hp, t0}) { - t->split(-1, block_size); - } - for (auto t : {tv_data_hp, t0, - view_out_tv, quantization_results.quantized_tensor, quantization_results.block_scales, t_out}) { - // (m, n, k) -> (m, n*k) - t->merge(1, 2); - - // (m, n*k) -> (m, n*k/4, 4) - // (m, n*k/4, 4) -> (m, n*k/128, 32, 4) + // (m, n) -> (m, n/4, 4) + // (m, n/4, 4) -> (m, n/128, 32, 4) t->split(-1, 4); // V t->split(-2, 32); // BDx - // (m, n*k/128, 32, 4) -> (m, 1, n*k/128, 32, 4) - // (m, 1, n*k/128, 32, 4) -> (m/4, 4, 1, n*k/128, 32, 4) + // (m, n/128, 32, 4) -> (m, 1, n/128, 32, 4) + // (m, 1, n/128, 32, 4) -> (m/4, 4, 1, n/128, 32, 4) t->split(0, 1); t->split(0, 4); - // (m/4(bidy), 4(tidy), 1, n*k/128(bidx), 32(tidx), 49(v)) + // (m/4(bidy), 4(tidy), 1, n*k/128(bidx), 32(tidx), 4(v)) if (t != tv_data_hp) { // Don't vectorize the outputs of reshape - if (t != view_out_tv && t != quantization_results.block_scales) { - t->axis(-1)->parallelize(ParallelType::Vectorize); - } + t->axis(-1)->parallelize(ParallelType::Vectorize); t->axis(-2)->parallelize(ParallelType::TIDx); t->axis(-3)->parallelize(ParallelType::BIDx); t->axis(-5)->parallelize(ParallelType::TIDy); @@ -410,8 +383,8 @@ TEST_F(BQTest, ScheduleAsPointwise2D) { } // Basic shape checks - EXPECT_EQ(block_scales_output.dim(), 3); - EXPECT_EQ(quantized_tensor_output.dim(), 3); + EXPECT_EQ(block_scales_output.dim(), 2); + EXPECT_EQ(quantized_tensor_output.dim(), 2); } TEST_P(NVFP4QuantizeTest, SwizzledOuputAndWithoutPerTensorAmax) { From ab5e60e4d9e41f5fbd75a19c8129b4c1b9c8b3a0 Mon Sep 17 00:00:00 2001 From: protonu Date: Wed, 8 Oct 2025 07:29:24 -0700 Subject: [PATCH 13/79] move comments around --- csrc/ops/arith.cpp | 16 ---------------- csrc/ops/arith.h | 3 +++ 2 files changed, 3 insertions(+), 16 deletions(-) diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index e3a929de0d7..8e9abe086a0 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -2634,16 +2634,6 @@ TensorView* prefixSum(TensorView* tv, int64_t dim) { /*init=*/tv->fusion()->zeroVal(tv->dtype())); } -// API for block quantization to nvFP4. -// We take FP32 or BF16 input and produce two outputs -// nvFP4(x2) outputs and FP8 block scales. 
-// The input is first reshape to have an inner 16 dimension. -// This 16 should be configurable but fixed for now. -// So, if the input is of shape (m, k), it is first reshaped -// to (m, k/16, 16). The quantized output has the shape (m, k/16, 16) -// and block scales has the shape (m, k/16, b(1)), where the inner dimension -// had been "reduced" (max). Please note there is no actual reduction and this -// node should be handled by the pointwise scheduler. // Currently this node get lowered to a runtime function, which expects the // input in registers and write out quantized values or registers and block // scales to global memory. @@ -2702,12 +2692,6 @@ BlockQuantizationResults blockQuantize( for (size_t i = 0; i < inp_domain.size(); ++i) { if (i == inp_domain.size() - 1) { - // scales_out_domain.push_back( - // IterDomainBuilder( - // input->fusion()->zeroVal(), input->fusion()->oneVal()) - // .iter_type(IterType::Broadcast) - // .build()); - // Close inp_domain[i] and divide by 16 to create new iter domain scales_out_domain.push_back( IterDomainBuilder( diff --git a/csrc/ops/arith.h b/csrc/ops/arith.h index 6d27674d199..6819908fb7f 100644 --- a/csrc/ops/arith.h +++ b/csrc/ops/arith.h @@ -835,6 +835,9 @@ struct BlockQuantizationResults { }; //! TODO: Expose global scaling factor +// API for block quantization to nvFP4. +// We take FP32 or BF16 input and produce two outputs +// nvFP4(x2) outputs and FP8 block scales. NVF_API BlockQuantizationResults blockQuantize( TensorView* input, int64_t block_size = 16, From 8eedb764d570ad68853a860b5b7d498ff296aa03 Mon Sep 17 00:00:00 2001 From: protonu Date: Wed, 8 Oct 2025 08:48:39 -0700 Subject: [PATCH 14/79] clean up --- csrc/device_lower/pass/index.cpp | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index 3353fb445c2..639bba3fa1b 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -406,16 +406,13 @@ void IndexLowering::handle(const TopKOp* top) { } void IndexLowering::handle(const BlockQuantizationOp* bqop) { - // const auto in = lowerSrcIndex(bqop->in(), bqop->quantizedOutput()); const auto in = IrBuilder::create( bqop->in()->as(), bqop->fusion()->zeroVal()); const auto out_scales = IrBuilder::create( - bqop->blockScales()->as(), - bqop->fusion()->zeroVal()); // lowerDstIndex(bqop->blockScales()); + bqop->blockScales()->as(), bqop->fusion()->zeroVal()); const auto out_quantized = IrBuilder::create( - bqop->quantizedOutput()->as(), - bqop->fusion()->zeroVal()); // lowerDstIndex(bqop->quantizedOutput()); + bqop->quantizedOutput()->as(), bqop->fusion()->zeroVal()); pushBack( IrBuilder::create(out_scales, out_quantized, in)); From a7b6d58b526b7b6989ef255b26d9d0db3dc5abb2 Mon Sep 17 00:00:00 2001 From: protonu Date: Wed, 8 Oct 2025 08:54:09 -0700 Subject: [PATCH 15/79] edit comments --- csrc/codegen.cpp | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index cd2c1320db1..27d37e0fb68 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -1762,17 +1762,11 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { // Special handling of BlockQuantizationOp to call the runtime function. 
// TODO: add support for global scaling factor - // TODO: make sure we can handle BF16 input void handle(const BlockQuantizationOp* bqop) final { - // Block quantization takes in a reshaped input - // so a (m, k) input becomes (m, k/16, 16) - // The 16 is the block size and is "reduced". - // For FP32 input each thread locally reduces 4 elements and a quad of - // thread reduces 16. For BF16 input each thread locally reduces 8 elements - // and a pair of thread reduces 16. The device function for block - // quantization. Due to these assumptions we have to have the inner - // dimension of the bqop input (and outputs) vectorized by 4 or 8 for the - // current device function to work. + // This operator is plumbed down to a runtime function call. + // One of the assumptions is that the device runtime expects + // 4 consecutive inputs (8 for FB16) per thread. We achieve this by having + // the input tv scheduler to have the inner dimension vectorized by 4/8. auto output = bqop->quantizedOutput()->as()->view(); int64_t vector_word_size = ir_utils::getVectorizeSize(output); From 75699a16613449c46a5aa6e8f6d34572d8f52230 Mon Sep 17 00:00:00 2001 From: protonu Date: Wed, 8 Oct 2025 09:16:59 -0700 Subject: [PATCH 16/79] remove setting parallel type for BIDx TIDx --- csrc/device_lower/analysis/sync_information.cpp | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/csrc/device_lower/analysis/sync_information.cpp b/csrc/device_lower/analysis/sync_information.cpp index e5316a704b4..0a5563e47a6 100644 --- a/csrc/device_lower/analysis/sync_information.cpp +++ b/csrc/device_lower/analysis/sync_information.cpp @@ -292,6 +292,14 @@ SyncMap::SyncMap(Fusion* fusion, bool error_on_failure) { continue; } + if ((parallel_type == ParallelType::BIDx || + parallel_type == ParallelType::TIDx) && + consumer->definition() != nullptr && + consumer->definition()->isA()) { + // Skip BIDx and TIDx check for BlockQuantizationOp consumer + continue; + } + // In the case when the parallel id's are mapped by ca map, // will additionally need to consider if the producer is // a redundant write. The raw dim can be skipped only if @@ -499,12 +507,10 @@ SyncMap::SyncMap(Fusion* fusion, bool error_on_failure) { if (error_on_failure) { if (raw_dims.hasBID()) { NVF_ERROR( - // We need to allow the output of block quantization and the - // outputs of the reshape preceding the block - // quantization to be in local memory. 
- producer->getMemoryType() == MemoryType::Global || + producer->getMemoryType() == MemoryType::Global /*|| consumer->definition()->isA() || - consumer->uses()[0]->isA(), + consumer->uses()[0]->isA()*/ + , "Inconsistent parallelization found between T", producer->name(), " (", From c7f1d8df3b5a3c5f1a20ee143818e7bb8f4c82d8 Mon Sep 17 00:00:00 2001 From: protonu Date: Wed, 8 Oct 2025 09:18:14 -0700 Subject: [PATCH 17/79] remove setting parallel type for BIDx TIDx - cleanup --- csrc/device_lower/analysis/sync_information.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/csrc/device_lower/analysis/sync_information.cpp b/csrc/device_lower/analysis/sync_information.cpp index 0a5563e47a6..969f15084a1 100644 --- a/csrc/device_lower/analysis/sync_information.cpp +++ b/csrc/device_lower/analysis/sync_information.cpp @@ -507,10 +507,7 @@ SyncMap::SyncMap(Fusion* fusion, bool error_on_failure) { if (error_on_failure) { if (raw_dims.hasBID()) { NVF_ERROR( - producer->getMemoryType() == MemoryType::Global /*|| - consumer->definition()->isA() || - consumer->uses()[0]->isA()*/ - , + producer->getMemoryType() == MemoryType::Global, "Inconsistent parallelization found between T", producer->name(), " (", From 44740dc16e8df37bb691b52991df161faa662bd7 Mon Sep 17 00:00:00 2001 From: protonu Date: Wed, 8 Oct 2025 13:54:04 -0700 Subject: [PATCH 18/79] adding support for parallel type group --- csrc/codegen.cpp | 15 ++++++++++++++- csrc/device_lower/validation.cpp | 2 +- tests/cpp/test_low_precision_recipe.cpp | 7 ++++++- 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 27d37e0fb68..b52e65b9020 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -1768,7 +1768,20 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { // 4 consecutive inputs (8 for FB16) per thread. We achieve this by having // the input tv scheduler to have the inner dimension vectorized by 4/8. auto output = bqop->quantizedOutput()->as()->view(); - int64_t vector_word_size = ir_utils::getVectorizeSize(output); + int64_t vector_word_size = 1; + + // Get the loop domain of the TensorView output and check for group/vector + // parallel types. This assumes that both parallel types aren't present. 
+ const auto& loop_domain = output->getLoopDomain(); + for (auto* domain : loop_domain) { + auto parallel_type = domain->getParallelType(); + if (parallel_type == ParallelType::Group || + parallel_type == ParallelType::Vectorize) { + if (domain->extent()->isConstInt()) { + vector_word_size = domain->extent()->evaluate().as(); + } + } + } auto input_dtype = bqop->in()->as()->view()->getDataType(); diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index f1674abff68..2911be41b95 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -1400,7 +1400,7 @@ void validateAndConvertIterDomainGrouping(Fusion* fusion) { NVF_CHECK( def->isA() || def->isA() || def->isA() || def->isA() || - def->isA(), + def->isA() || def->isA(), "Invalid use of ParallelType::Group: ", def->toString()); diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index 4f371c26181..2b6738fad73 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -240,7 +240,12 @@ TEST_F(BQTest, ScheduleAsPointwise) { t->split(-3, 128); if (t != tv_data_hp) { - t->axis(-1)->parallelize(ParallelType::Vectorize); + if (t == quantization_results.block_scales || + t == quantization_results.quantized_tensor) { + t->axis(-1)->parallelize(ParallelType::Group); + } else { + t->axis(-1)->parallelize(ParallelType::Vectorize); + } t->axis(-3)->parallelize(ParallelType::TIDx); t->axis(-4)->parallelize(ParallelType::BIDx); } From 91304350f83a4f01e4971c03d29169ca389c030e Mon Sep 17 00:00:00 2001 From: protonu Date: Fri, 10 Oct 2025 06:46:48 -0700 Subject: [PATCH 19/79] address reviewer comments --- .../analysis/sync_information.cpp | 5 ++-- csrc/device_lower/validation.cpp | 2 -- csrc/ir/nodes.cpp | 8 ++++++ csrc/ir/utils.cpp | 6 +++++ csrc/ir/utils.h | 3 +++ csrc/logical_domain_map.cpp | 11 ++++---- csrc/ops/arith.cpp | 27 ++++++------------- tests/cpp/test_low_precision_recipe.cpp | 15 ++++------- 8 files changed, 39 insertions(+), 38 deletions(-) diff --git a/csrc/device_lower/analysis/sync_information.cpp b/csrc/device_lower/analysis/sync_information.cpp index 969f15084a1..b8bc47c4fce 100644 --- a/csrc/device_lower/analysis/sync_information.cpp +++ b/csrc/device_lower/analysis/sync_information.cpp @@ -294,9 +294,10 @@ SyncMap::SyncMap(Fusion* fusion, bool error_on_failure) { if ((parallel_type == ParallelType::BIDx || parallel_type == ParallelType::TIDx) && - consumer->definition() != nullptr && - consumer->definition()->isA()) { + ir_utils::isBlockScalingFactor(consumer)) { // Skip BIDx and TIDx check for BlockQuantizationOp consumer + producer->as()->printTransforms(); + consumer->as()->printTransforms(); continue; } diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index 2911be41b95..d66493bbcf8 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -47,8 +47,6 @@ class ValidateSiblings : public IterVisitor { using IterVisitor::handle; void dispatch(Expr* expr) final { - // Skip BlockQuantization. 
- // It has sibling outputs which differ from each other if (!ir_utils::isTvOp(expr) || expr->outputs().size() < 2 || !ir_utils::hasUniformSiblings(expr)) { IterVisitor::dispatch(expr); diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 7641db28788..a012bb296b8 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -6298,6 +6298,14 @@ std::vector PreprocessGroupedMatmulInputSf::evaluate( NVFUSER_DEFINE_CLONE_AND_CREATE(PreprocessGroupedMatmulInputSf) +// Details: +// Currently output_scales is the first input in the constructor even though +// it's the second output. This is because if it's the second output then we hit +// a bug in indexing. The stack trace can be seen here: +// https://gist.github.com/protonu/dc35024c1291625b2b7ce87baa39e2ae +// This happens when creating UnswitchPredicate, probably in the call to +// TensorIndexer::getPredicates. The incorrect predicate_domains for the tv +// in the call to getPredicateDomains. BlockQuantizationOp::BlockQuantizationOp( IrBuilderPasskey passkey, Val* output_scales, diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index ec254316382..0fd273ac81a 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -1759,4 +1759,10 @@ bool isParallelizedBy(const std::vector& ids, ParallelType pt) { ids, [&](IterDomain* id) { return id->getParallelType() == pt; }); } +bool isBlockScalingFactor(const TensorView* tv) { + return tv->definition() != nullptr && + tv->definition()->isA() && + tv == tv->definition()->as()->blockScales(); +} + } // namespace nvfuser::ir_utils diff --git a/csrc/ir/utils.h b/csrc/ir/utils.h index 0c436e7b9c7..986613bd2c8 100644 --- a/csrc/ir/utils.h +++ b/csrc/ir/utils.h @@ -865,4 +865,7 @@ std::vector propagateScatterAllocationDomain( bool isParallelizedBy(const std::vector& ids, ParallelType pt); +// Check if tv is the block scales output of a BlockQuantization Op. +bool isBlockScalingFactor(const TensorView* tv); + } // namespace nvfuser::ir_utils diff --git a/csrc/logical_domain_map.cpp b/csrc/logical_domain_map.cpp index d1e45f8d42b..9fcb9ac4fe0 100644 --- a/csrc/logical_domain_map.cpp +++ b/csrc/logical_domain_map.cpp @@ -154,13 +154,14 @@ std::pair, bool> getNonMappingDomainInfo( } has_consumer_id = true; } - } else if ( - auto bqop = - dynamic_cast(consumer_tv->definition())) { - if (producer_tv == bqop->in()) { + } else if (dynamic_cast(consumer_tv->definition())) { + // We don't map the inner-most dimension of the block scaling factors + // as it's extent is reduced by a factor of the block size + // for example [i0, i1] => [i0, i1/16] where 16 is the block size. + if (ir_utils::isBlockScalingFactor(consumer_tv)) { auto producer_logical = TensorDomain::noReductions(producer_tv->getLogicalDomain()); - auto last_logical_dim = producer_tv->getLogicalDomain().size() - 1; + auto last_logical_dim = producer_logical.size() - 1; non_mapping_ids.insert(producer_logical.at(last_logical_dim)); // We are mapping everything but the last ID. has_consumer_id = true; diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 8e9abe086a0..9526bda0f75 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -2651,23 +2651,15 @@ BlockQuantizationResults blockQuantize( "Currently only output data type of Float4_e2m1fn_x2 is supported"); // Validate input data type - // WE'll only support FP32 or BF16 - // TODO: BF16 + // We'll only support FP32 or BF16 + // We should check if the inputs are FP or BF16. 
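Putting the pieces of this hunk together (the block-scales domain built a few lines below divides the innermost extent by block_size), the shape contract for a 2D input works out as in this small sketch, assuming k is divisible by the block size:

#include <cstdio>

int main() {
  const long m = 1024, k = 2048, block_size = 16;
  std::printf("input            : [%ld, %ld] (Float or BFloat16)\n", m, k);
  std::printf("quantized output : [%ld, %ld] (Float4_e2m1fn_x2)\n", m, k);
  std::printf("block scales     : [%ld, %ld] (Float8_e4m3fn)\n",
              m, k / block_size);
  return 0;
}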
NVF_CHECK( - isFloatingPointType(input->getDataType().value()), + input->getDataType().value() == DataType::Float || + input->getDataType().value() == DataType::BFloat16, "Block quantization expects floating point input but got ", input->getDataType().value()); - // We reshape the input to the keep the block size as the inner dimension - // that will be "reduced". For example, if our input in [m, k] and the block - // size is 16. Then we need to compute the max of the inner-most 16 elements. - // Thus, we first reshape the input to [m, k/16, 16] and then compute the max - // over the inner-most 16 elements. - // auto reshaped_input = reshape(input, [](auto& x) { x.split(-1, 16); }); - - auto inp_domain = - // TensorDomain::noReductions(reshaped_input->getLogicalDomain()); - TensorDomain::noReductions(input->getLogicalDomain()); + auto inp_domain = TensorDomain::noReductions(input->getLogicalDomain()); // Validate input tensor is not zero-dimensional NVF_CHECK( @@ -2683,22 +2675,19 @@ BlockQuantizationResults blockQuantize( } // Create output domain for block scales - // If the input after reshape is [m, k/16, 16] then - // block scales will be [m, k/16, b(1)] - // We keep the inner-dimension as b(1) to make scheduling easier. - // Both the ouputs of quantization can be scheduled in similar fashion. + // We'll clone the outer domains but divide the + // extent of the inner domain by 16. (block_size). std::vector scales_out_domain; scales_out_domain.reserve(inp_domain.size()); for (size_t i = 0; i < inp_domain.size(); ++i) { if (i == inp_domain.size() - 1) { - // Close inp_domain[i] and divide by 16 to create new iter domain scales_out_domain.push_back( IterDomainBuilder( inp_domain[i]->start(), SimplifyingIrBuilder::divExpr( inp_domain[i]->extent(), - IrBuilder::create(16L, DataType::Index))) + IrBuilder::create(block_size, DataType::Index))) .build()); } else { diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index 2b6738fad73..a7bc010778a 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -10,8 +10,6 @@ #include #include -#include -#include #include #include @@ -175,10 +173,10 @@ TEST_P(NVFP4QuantizeTest, WithoutPerTensorAmax) { HeuristicIs(SchedulerType::InnerPersistent))); } -class BQTest : public BlackwellBase {}; +class BlockQuantizationTest : public BlackwellBase {}; -TEST_F(BQTest, ScheduleAsPointwise) { - // Basic test implementation +TEST_F(BlockQuantizationTest, ScheduleAsPointwise) { + // Baseline implementation std::unique_ptr fusion = std::make_unique(); FusionGuard fg(fusion.get()); createNVFP4QunatizationFusion(fusion.get(), DataType::Float); @@ -191,7 +189,6 @@ TEST_F(BQTest, ScheduleAsPointwise) { inputs.push_back(at::randn({m, n}, at::device(at::kCUDA).dtype(at::kFloat))); auto outputs_baseline = fec.runFusionWithInputs(inputs); - // Print baseline outputs auto baseline_block_scales = outputs_baseline[0].as(); auto baseline_quantized_tensor = outputs_baseline[1].as(); @@ -199,7 +196,6 @@ TEST_F(BQTest, ScheduleAsPointwise) { auto baseline_block_scales_cpu = baseline_block_scales.cpu(); auto baseline_quantized_tensor_cpu = baseline_quantized_tensor.cpu(); - // Print first 32 bytes of baseline block_scales_output in hex format const uint8_t* baseline_block_scales_data = static_cast(baseline_block_scales_cpu.data_ptr()); const uint8_t* baseline_quantized_data = @@ -211,7 +207,6 @@ TEST_F(BQTest, ScheduleAsPointwise) { auto tv_data_hp = makeContigTensor(2, DataType::Float); 
fusion_new_op->addInput(tv_data_hp); - // t0 is 2D auto t0 = set(tv_data_hp); auto quantization_results = blockQuantize(t0); auto t_out = set(quantization_results.quantized_tensor); @@ -284,8 +279,8 @@ TEST_F(BQTest, ScheduleAsPointwise) { EXPECT_EQ(quantized_tensor_output.dim(), 2); } -TEST_F(BQTest, ScheduleAsPointwise2D) { - // Basic test implementation +TEST_F(BlockQuantizationTest, ScheduleAsPointwise2D) { + // Baseline implementation std::unique_ptr fusion = std::make_unique(); FusionGuard fg(fusion.get()); createNVFP4QunatizationFusion(fusion.get(), DataType::Float); From 434a1b170d1e6898e95b22a6e3daa2ac21057f4c Mon Sep 17 00:00:00 2001 From: protonu Date: Fri, 10 Oct 2025 07:12:58 -0700 Subject: [PATCH 20/79] adding a comment --- csrc/device_lower/analysis/sync_information.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/csrc/device_lower/analysis/sync_information.cpp b/csrc/device_lower/analysis/sync_information.cpp index b8bc47c4fce..804ca2498f6 100644 --- a/csrc/device_lower/analysis/sync_information.cpp +++ b/csrc/device_lower/analysis/sync_information.cpp @@ -292,12 +292,14 @@ SyncMap::SyncMap(Fusion* fusion, bool error_on_failure) { continue; } + // Skip BIDx and TIDx check for block scaling factor output of + // BlockQuantizationOp. The inner-most dimension of this output + // does not map to any producer ID and is used to generate BIDx and + // TIDx. Since this Op is codegen'd to a runtime fuction, any + // sync/predication is handled there. if ((parallel_type == ParallelType::BIDx || parallel_type == ParallelType::TIDx) && ir_utils::isBlockScalingFactor(consumer)) { - // Skip BIDx and TIDx check for BlockQuantizationOp consumer - producer->as()->printTransforms(); - consumer->as()->printTransforms(); continue; } From ce0b820766b623384d8180f5b7f743b3f615c9ce Mon Sep 17 00:00:00 2001 From: protonu Date: Fri, 10 Oct 2025 07:20:15 -0700 Subject: [PATCH 21/79] removing a comment --- csrc/device_lower/validation.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index d66493bbcf8..340be59336c 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -1019,9 +1019,6 @@ void validateAndCollectVectorizeInfo(Fusion* fusion) { def->as()->serialGridReductionRequested()) || (def->isA() && def->as()->getUnaryOpType() == UnaryOpType::Cast) || - // This throws without this check. - // Maybe I shouldn't vectorize the outputs of the - // BlockQuantizationOp def->isA(), "Vectorized accesses cannot be inline with computation: ", (def == nullptr ? tv->toString() : def->toString())); From fb4ad217117d43901edf5512489f3f5072da0518 Mon Sep 17 00:00:00 2001 From: protonu Date: Fri, 10 Oct 2025 10:05:37 -0700 Subject: [PATCH 22/79] modifying a check --- csrc/device_lower/analysis/non_divisible_split.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/device_lower/analysis/non_divisible_split.cpp b/csrc/device_lower/analysis/non_divisible_split.cpp index 5ca07d0f57e..7ea1ca60451 100644 --- a/csrc/device_lower/analysis/non_divisible_split.cpp +++ b/csrc/device_lower/analysis/non_divisible_split.cpp @@ -214,7 +214,7 @@ NonDivisiblePredicateInfo::NonDivisiblePredicateInfo(Fusion* fusion) { auto def = tv->definition(); // A block quantization is plumbed down to a runtime function. // The runtime function handles predication so skip this. 
- if (def == nullptr || def->isA()) { + if (def == nullptr || ir_utils::isBlockScalingFactor(tv)) { continue; } From beb2f06575f8d27ca9f4a604485a72140e45cc91 Mon Sep 17 00:00:00 2001 From: protonu Date: Fri, 10 Oct 2025 10:33:41 -0700 Subject: [PATCH 23/79] merge --- csrc/device_lower/analysis/non_divisible_split.cpp | 4 +++- csrc/runtime/compiled_kernel.cpp | 5 +---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/csrc/device_lower/analysis/non_divisible_split.cpp b/csrc/device_lower/analysis/non_divisible_split.cpp index 7ea1ca60451..50aa9188211 100644 --- a/csrc/device_lower/analysis/non_divisible_split.cpp +++ b/csrc/device_lower/analysis/non_divisible_split.cpp @@ -213,7 +213,9 @@ NonDivisiblePredicateInfo::NonDivisiblePredicateInfo(Fusion* fusion) { auto def = tv->definition(); // A block quantization is plumbed down to a runtime function. - // The runtime function handles predication so skip this. + // The runtime function handles predication so skip this for the + // block scales. That's because the inner dimension of that tv is not + // mapped to any ID of the input or sibling output. if (def == nullptr || ir_utils::isBlockScalingFactor(tv)) { continue; } diff --git a/csrc/runtime/compiled_kernel.cpp b/csrc/runtime/compiled_kernel.cpp index d66c9fbd19f..0fa2719886d 100644 --- a/csrc/runtime/compiled_kernel.cpp +++ b/csrc/runtime/compiled_kernel.cpp @@ -1457,11 +1457,8 @@ std::string CompiledKernel::getStructuredCode() const { kernel()->summary().has_topk, kernel()->summary().has_scan, kernel()->summary().has_preprocess_grouped_matmul_input_sf, -<<<<<<< HEAD + kernel()->summary().has_cluster_reduction, kernel()->summary().has_block_quantize_op); -======= - kernel()->summary().has_cluster_reduction); ->>>>>>> main } std::string CompiledKernel::disassembledKernelSASS() const { From b04d9741e4c3827c493e7b1e7e9232f8fb9e74dc Mon Sep 17 00:00:00 2001 From: protonu Date: Fri, 17 Oct 2025 07:02:57 -0700 Subject: [PATCH 24/79] add validation, update test --- csrc/device_lower/validation.cpp | 220 ++++++++++++++++++++++++ runtime/block_quantization_kernels.cu | 19 +- tests/cpp/test_low_precision_recipe.cpp | 3 + 3 files changed, 235 insertions(+), 7 deletions(-) diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index c1a833905d1..1c292d5f1ac 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -411,6 +411,226 @@ class ExprValidator : public OptOutDispatch { (*it)->toString()); } } + + // Given a set of loop domain IterDomains, find their logical domain origins + std::vector findLogicalDomainOrigins( + const std::vector& loop_domain_ids, + const TensorView* tv) { + // Get the logical domain to use as the target/boundary + const auto& logical_domain = tv->getLogicalDomain(); + + // Use IterVisitor to find inputs to the loop domain IDs, + // bounded by the logical domain + std::vector inputs_as_vals = IterVisitor::getInputsTo( + {loop_domain_ids.begin(), loop_domain_ids.end()}, + {logical_domain.begin(), logical_domain.end()}); + + // Convert back to IterDomains + std::vector logical_origins; + for (auto val : inputs_as_vals) { + logical_origins.push_back(val->as()); + } + + return logical_origins; + } + + // I'd like to check that the inner dimension of the input + // is divisble by 16. 
+ void handle(BlockQuantizationOp* bqop) final { + auto inp_tv = bqop->input(0)->as(); + auto quantized_output = bqop->quantizedOutput()->as(); + auto block_scaling_factor = bqop->blockScales()->as(); + + NVF_ERROR_EQ( + inp_tv->getMemoryType(), + MemoryType::Local, + "Input must be a local memory tensor. Found: ", + inp_tv->getMemoryType()); + + NVF_ERROR_EQ( + quantized_output->getMemoryType(), + MemoryType::Local, + "Quantized output must be a local memory tensor. Found: ", + quantized_output->getMemoryType()); + + NVF_ERROR_EQ( + block_scaling_factor->getMemoryType(), + MemoryType::Global, + "Block scaling factor must be a global memory tensor. Found: ", + block_scaling_factor->getMemoryType()); + + // outputs have the same allocation domain + // as the loop domain. This has to be later + // relaxed for the scaling factors. + NVF_ERROR( + quantized_output->hasAllocation() == false, + "Quantized output must not have an allocation domain."); + NVF_ERROR( + block_scaling_factor->hasAllocation() == false, + "Block scaling factor must not have an allocation domain."); + + // Check that it either had vectorized ID or grouped ID + // not both and the extent is either 4(FP32) or 8(BF16) + IterDomain* grouped_or_vector_id = nullptr; + IterDomain* thread_x = nullptr; + IterDomain* block_x = nullptr; + IterDomain* thread_y = nullptr; + IterDomain* block_y = nullptr; + IterDomain* thread_z = nullptr; + IterDomain* block_z = nullptr; + + for (const auto& loop_id : block_scaling_factor->getLoopDomain()) { + if (loop_id->getParallelType() == ParallelType::Group || + loop_id->getParallelType() == ParallelType::Vectorize) { + NVF_ERROR( + grouped_or_vector_id == nullptr, + "Multiple IDs found to be grouped/vectorized"); + grouped_or_vector_id = loop_id; + } + } + + auto parallel_domains_map = + ir_utils::getParallelDomains(block_scaling_factor); + if (parallel_domains_map.find(ParallelType::TIDx) != + parallel_domains_map.end()) { + thread_x = parallel_domains_map.at(ParallelType::TIDx); + } + if (parallel_domains_map.find(ParallelType::BIDx) != + parallel_domains_map.end()) { + block_x = parallel_domains_map.at(ParallelType::BIDx); + } + if (parallel_domains_map.find(ParallelType::TIDy) != + parallel_domains_map.end()) { + thread_y = parallel_domains_map.at(ParallelType::TIDy); + } + if (parallel_domains_map.find(ParallelType::BIDy) != + parallel_domains_map.end()) { + block_y = parallel_domains_map.at(ParallelType::BIDy); + } + if (parallel_domains_map.find(ParallelType::TIDz) != + parallel_domains_map.end()) { + thread_z = parallel_domains_map.at(ParallelType::TIDz); + } + if (parallel_domains_map.find(ParallelType::BIDz) != + parallel_domains_map.end()) { + block_z = parallel_domains_map.at(ParallelType::BIDz); + } + + NVF_ERROR( + grouped_or_vector_id != nullptr, + "One of the output IDs must be grouped or vectorized for " + "BlockQuantizationOp: ", + bqop->toString()); + + NVF_ERROR( + thread_x != nullptr && block_x != nullptr, + "Need to have both TIDx and BIDx when using BlockQuantizationOp: ", + bqop->toString()); + + NVF_ERROR( + !thread_z && !block_z, + "Parallelization along z axis is not supported for " + "BlockQuantizationOp: ", + bqop->toString()); + + bool is_2d_scheduled = + (thread_y != nullptr || block_y != nullptr) ? 
true : false; + + auto inner_extent = + grouped_or_vector_id->extent()->evaluate().as(); + auto input_dtype = inp_tv->dtype(); + + NVF_ERROR( + (inner_extent == 4 && input_dtype == DataType::Float) || + (inner_extent == 8 && + (input_dtype == DataType::BFloat16 || + input_dtype == DataType::Half)), + "The vectorized/grouped dimension must be 4 (FP32) or 8 " + "(BF16). Found: ", + inner_extent, + ". Expr: ", + bqop->toString()); + + // Find the logical domain IDs that correspond to these loop IDs. + // Then we check that the logical domain IDs are the inner-most + // IDs. + auto input_logical_domains_ids = findLogicalDomainOrigins( + {grouped_or_vector_id, thread_x, block_x}, block_scaling_factor); + + // Get the size of input logical domains + size_t num_input_logical_domains = input_logical_domains_ids.size(); + + // Get the same number of elements from the innermost logical domain + const auto& logical_domain = block_scaling_factor->getLogicalDomain(); + std::vector innermost_logical_domains; + + // Extract from the rightmost (innermost) positions + for (int64_t i = logical_domain.size() - 1; + i >= 0 && innermost_logical_domains.size() < num_input_logical_domains; + i--) { + auto logical_id = logical_domain[i]; + if (!logical_id->isReduction() && !logical_id->isBroadcast()) { + innermost_logical_domains.insert( + innermost_logical_domains.begin(), logical_id); + } + } + + // Validate that input_logical_domains_ids and innermost_logical_domains + // contain the same IterDomains + std::unordered_set input_logical_set( + input_logical_domains_ids.begin(), input_logical_domains_ids.end()); + std::unordered_set innermost_logical_set( + innermost_logical_domains.begin(), innermost_logical_domains.end()); + + NVF_ERROR( + input_logical_set == innermost_logical_set, + "Input logical domain IDs do not match the innermost logical domains " + "for BlockQuantizationOp: ", + bqop->toString(), + ". Expected innermost domains: ", + toDelimitedString(innermost_logical_domains), + ". Found input logical domains: ", + toDelimitedString(input_logical_domains_ids)); + + // If it's 2D scheduled, the we get the IDs from the logical domain + // that correspond to blockIdx.y and threadIdx.y. We make sure the + // IDs from the logical domain don't share any ID with those from the + // thread/block for x-dimension was derived. + if (is_2d_scheduled) { + std::vector input_logical_domains_ids_2d = {}; + for (auto id : {thread_y, block_y}) { + if (id) { + input_logical_domains_ids_2d.push_back(id); + } + } + + auto input_logical_domains_ids_y = findLogicalDomainOrigins( + input_logical_domains_ids_2d, block_scaling_factor); + + // Validate that input_logical_domains_ids and input_logical_domains_ids_y + // don't have any elements in common + std::unordered_set input_logical_set_x( + input_logical_domains_ids.begin(), input_logical_domains_ids.end()); + std::unordered_set input_logical_set_y( + input_logical_domains_ids_y.begin(), + input_logical_domains_ids_y.end()); + + for (const auto& id : input_logical_set_x) { + NVF_ERROR( + input_logical_set_y.find(id) == input_logical_set_y.end(), + "Input logical domain IDs for X and Y dimensions have overlapping " + "elements " + "for BlockQuantizationOp: ", + bqop->toString(), + ". Overlapping IterDomain: ", + id->toString(), + ". X logical domains: ", + toDelimitedString(input_logical_domains_ids), + ". 
Y logical domains: ", + toDelimitedString(input_logical_domains_ids_y)); + } + } + } }; } // namespace diff --git a/runtime/block_quantization_kernels.cu b/runtime/block_quantization_kernels.cu index e3de211fb45..92c811c9364 100644 --- a/runtime/block_quantization_kernels.cu +++ b/runtime/block_quantization_kernels.cu @@ -49,12 +49,20 @@ __device__ void block_quantize_to_nvfp4( Array& input, Array<__e2m1, ITEMS_PER_THREAD, ALIGNMENT_2>& output, Tensor<__e4m3, BLOCK_SCALE_DIM, BLOCK_SCALE_ALLOC>& fp8_output) { - assert(blockDim.x % 4 == 0); + if constexpr (std::is_same::value) { + assert(blockDim.x % 4 == 0); + } else if constexpr (std::is_same::value) { + assert(blockDim.x % 2 == 0); + } assert(blockDim.z == 1 && gridDim.z == 1); static_assert( - ITEMS_PER_THREAD % 4 == 0, "ITEMS_PER_THREAD must be multiple of 4"); + (std::is_same::value && ITEMS_PER_THREAD == 4) || + (std::is_same::value && ITEMS_PER_THREAD == 8), + "ITEMS_PER_THREAD must be 4 for float type or 8 for __bfloat type"); + + int THREADS_PER_SCALING_FACTOR = 16 / ITEMS_PER_THREAD; - Array vec_in; + Array vec_in; vec_in.set(0.0f); // Initialize to zero like nvfuser does for (auto i = 0; i < ITEMS_PER_THREAD; i++) { @@ -95,11 +103,8 @@ __device__ void block_quantize_to_nvfp4( int offset_dim_y = threadIdx.y * blockDim.x * gridDim.x; int offset_into_block = blockIdx.x * blockDim.x + threadIdx.x; - int offset = (offset_y_blocks + offset_dim_y + offset_into_block) / 4; - // Convert back from FP8 to float using __e4m32float - if (threadIdx.x % ITEMS_PER_THREAD == 0) // Only one thread per quad writes - { + if (threadIdx.x % THREADS_PER_SCALING_FACTOR == 0) { fp8_output[offset] = clamped_max_fp8; // Broadcast to all threads } diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index a7bc010778a..0005d8d906b 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -115,6 +115,9 @@ void createNVFP4QunatizationFusion(Fusion* fusion, DataType data_hp_dtype) { auto tv_data_hp_reshaped = reshape(tv_data_hp, [](auto& x) { x.split(-1, block_size); }); + // cast it to FP32 + tv_data_hp_reshaped = castOp(DataType::Float, tv_data_hp_reshaped); + auto tv_data_hp_abs = abs(tv_data_hp_reshaped); auto tv_data_hp_amax = max(tv_data_hp_abs, {-1}); // These scales are currently in fp32, we are going to `quantize` them to e4m3 From 7c79c32725c7b2ac6b0206ef33616f6f253993c0 Mon Sep 17 00:00:00 2001 From: protonu Date: Fri, 17 Oct 2025 08:23:35 -0700 Subject: [PATCH 25/79] clean up from merge --- runtime/block_quantization_kernels.cu | 3 +++ 1 file changed, 3 insertions(+) diff --git a/runtime/block_quantization_kernels.cu b/runtime/block_quantization_kernels.cu index 92c811c9364..1486d294732 100644 --- a/runtime/block_quantization_kernels.cu +++ b/runtime/block_quantization_kernels.cu @@ -103,6 +103,9 @@ __device__ void block_quantize_to_nvfp4( int offset_dim_y = threadIdx.y * blockDim.x * gridDim.x; int offset_into_block = blockIdx.x * blockDim.x + threadIdx.x; + int offset = (offset_y_blocks + offset_dim_y + offset_into_block) / + THREADS_PER_SCALING_FACTOR; + // Convert back from FP8 to float using __e4m32float if (threadIdx.x % THREADS_PER_SCALING_FACTOR == 0) { fp8_output[offset] = clamped_max_fp8; // Broadcast to all threads From 8564431164716d0b7d1ed2bd3093b10cb5a1ac85 Mon Sep 17 00:00:00 2001 From: protonu Date: Fri, 17 Oct 2025 08:53:28 -0700 Subject: [PATCH 26/79] remove a utility function --- csrc/device_lower/analysis/non_divisible_split.cpp | 4 +++- 
csrc/device_lower/analysis/sync_information.cpp | 6 +++++- csrc/ir/utils.cpp | 6 ------ csrc/ir/utils.h | 3 --- csrc/logical_domain_map.cpp | 3 ++- csrc/ops/arith.cpp | 3 ++- 6 files changed, 12 insertions(+), 13 deletions(-) diff --git a/csrc/device_lower/analysis/non_divisible_split.cpp b/csrc/device_lower/analysis/non_divisible_split.cpp index 50aa9188211..0effe1618e1 100644 --- a/csrc/device_lower/analysis/non_divisible_split.cpp +++ b/csrc/device_lower/analysis/non_divisible_split.cpp @@ -216,7 +216,9 @@ NonDivisiblePredicateInfo::NonDivisiblePredicateInfo(Fusion* fusion) { // The runtime function handles predication so skip this for the // block scales. That's because the inner dimension of that tv is not // mapped to any ID of the input or sibling output. - if (def == nullptr || ir_utils::isBlockScalingFactor(tv)) { + if (def == nullptr || + (tv->definition()->isA() && + tv == tv->definition()->as()->blockScales())) { continue; } diff --git a/csrc/device_lower/analysis/sync_information.cpp b/csrc/device_lower/analysis/sync_information.cpp index 804ca2498f6..178408860e2 100644 --- a/csrc/device_lower/analysis/sync_information.cpp +++ b/csrc/device_lower/analysis/sync_information.cpp @@ -299,7 +299,11 @@ SyncMap::SyncMap(Fusion* fusion, bool error_on_failure) { // sync/predication is handled there. if ((parallel_type == ParallelType::BIDx || parallel_type == ParallelType::TIDx) && - ir_utils::isBlockScalingFactor(consumer)) { + (consumer->definition()->isA() && + consumer == + consumer->definition() + ->as() + ->blockScales())) { continue; } diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index f4089a7c9ee..5b312f96700 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -1747,10 +1747,4 @@ bool isParallelizedBy(const std::vector& ids, ParallelType pt) { ids, [&](IterDomain* id) { return id->getParallelType() == pt; }); } -bool isBlockScalingFactor(const TensorView* tv) { - return tv->definition() != nullptr && - tv->definition()->isA() && - tv == tv->definition()->as()->blockScales(); -} - } // namespace nvfuser::ir_utils diff --git a/csrc/ir/utils.h b/csrc/ir/utils.h index 42ba9201b60..13af07ff0dc 100644 --- a/csrc/ir/utils.h +++ b/csrc/ir/utils.h @@ -861,7 +861,4 @@ std::vector propagateScatterAllocationDomain( bool isParallelizedBy(const std::vector& ids, ParallelType pt); -// Check if tv is the block scales output of a BlockQuantization Op. -bool isBlockScalingFactor(const TensorView* tv); - } // namespace nvfuser::ir_utils diff --git a/csrc/logical_domain_map.cpp b/csrc/logical_domain_map.cpp index 9fcb9ac4fe0..b6c894f3b5f 100644 --- a/csrc/logical_domain_map.cpp +++ b/csrc/logical_domain_map.cpp @@ -158,7 +158,8 @@ std::pair, bool> getNonMappingDomainInfo( // We don't map the inner-most dimension of the block scaling factors // as it's extent is reduced by a factor of the block size // for example [i0, i1] => [i0, i1/16] where 16 is the block size. - if (ir_utils::isBlockScalingFactor(consumer_tv)) { + if (consumer_tv == + consumer_tv->definition()->as()->blockScales()) { auto producer_logical = TensorDomain::noReductions(producer_tv->getLogicalDomain()); auto last_logical_dim = producer_logical.size() - 1; diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 1f8fe7796e1..90b1e73fd34 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -2641,7 +2641,8 @@ BlockQuantizationResults blockQuantize( // We should check if the inputs are FP or BF16. 
NVF_CHECK( input->getDataType().value() == DataType::Float || - input->getDataType().value() == DataType::BFloat16, + input->getDataType().value() == DataType::BFloat16 || + input->getDataType().value() == DataType::Half, "Block quantization expects floating point input but got ", input->getDataType().value()); From 774e27d18432f28a73e42a665ee5fa2e1955b54c Mon Sep 17 00:00:00 2001 From: protonu Date: Fri, 17 Oct 2025 09:32:01 -0700 Subject: [PATCH 27/79] support half and bfloat in tests --- csrc/codegen.cpp | 2 +- runtime/block_quantization_kernels.cu | 25 ++++++++------ tests/cpp/test_low_precision_recipe.cpp | 43 +++++++++++++++++-------- 3 files changed, 47 insertions(+), 23 deletions(-) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 167a07f029b..0693b5a3f18 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -1837,7 +1837,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { auto input_dtype = bqop->in()->as()->view()->getDataType(); - if (input_dtype == DataType::BFloat16) { + if (input_dtype == DataType::BFloat16 || input_dtype == DataType::Half) { NVF_ERROR( vector_word_size == 8, "Vectorization size should be 8 for " diff --git a/runtime/block_quantization_kernels.cu b/runtime/block_quantization_kernels.cu index 1486d294732..f67ae04bf44 100644 --- a/runtime/block_quantization_kernels.cu +++ b/runtime/block_quantization_kernels.cu @@ -49,16 +49,25 @@ __device__ void block_quantize_to_nvfp4( Array& input, Array<__e2m1, ITEMS_PER_THREAD, ALIGNMENT_2>& output, Tensor<__e4m3, BLOCK_SCALE_DIM, BLOCK_SCALE_ALLOC>& fp8_output) { - if constexpr (std::is_same::value) { + constexpr bool is_half_or_bfloat = + std::is_same::value || std::is_same::value; + constexpr bool is_float = std::is_same::value; + static_assert( + is_float || is_half_or_bfloat, + "Input type must be float, __half or __bfloat"); + + if constexpr (is_float) { assert(blockDim.x % 4 == 0); - } else if constexpr (std::is_same::value) { + } else if constexpr (is_half_or_bfloat) { assert(blockDim.x % 2 == 0); } assert(blockDim.z == 1 && gridDim.z == 1); + static_assert( - (std::is_same::value && ITEMS_PER_THREAD == 4) || - (std::is_same::value && ITEMS_PER_THREAD == 8), - "ITEMS_PER_THREAD must be 4 for float type or 8 for __bfloat type"); + (is_float && ITEMS_PER_THREAD == 4) || + (is_half_or_bfloat && ITEMS_PER_THREAD == 8), + "ITEMS_PER_THREAD must be 4 for float type or 8 for __bfloat or __half " + "type"); int THREADS_PER_SCALING_FACTOR = 16 / ITEMS_PER_THREAD; @@ -70,10 +79,8 @@ __device__ void block_quantize_to_nvfp4( vec_in[i] = input[i]; } else if constexpr (std::is_same::value) { vec_in[i] = __bfloat2float(input[i]); - } else { - static_assert( - std::is_same::value || std::is_same::value, - "Unsupported type"); + } else if constexpr (std::is_same::value) { + vec_in[i] = __half2float(input[i]); } } diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index 0005d8d906b..34a601b6298 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -176,20 +176,24 @@ TEST_P(NVFP4QuantizeTest, WithoutPerTensorAmax) { HeuristicIs(SchedulerType::InnerPersistent))); } -class BlockQuantizationTest : public BlackwellBase {}; +class BlockQuantizationTest : public BlackwellBase, + public ::testing::WithParamInterface {}; + +TEST_P(BlockQuantizationTest, ScheduleAsPointwise) { + auto data_hp_dtype = GetParam(); -TEST_F(BlockQuantizationTest, ScheduleAsPointwise) { // Baseline implementation std::unique_ptr fusion = 
std::make_unique(); FusionGuard fg(fusion.get()); - createNVFP4QunatizationFusion(fusion.get(), DataType::Float); + createNVFP4QunatizationFusion(fusion.get(), data_hp_dtype); FusionExecutorCache fec(std::move(fusion)); const int m = 1024; const int n = 1024; std::vector inputs; - inputs.push_back(at::randn({m, n}, at::device(at::kCUDA).dtype(at::kFloat))); + inputs.push_back(at::randn({m, n}, at::device(at::kCUDA).dtype(at::kFloat)) + .to(data_type_to_aten(data_hp_dtype))); auto outputs_baseline = fec.runFusionWithInputs(inputs); auto baseline_block_scales = outputs_baseline[0].as(); @@ -207,7 +211,7 @@ TEST_F(BlockQuantizationTest, ScheduleAsPointwise) { std::unique_ptr fusion_new_op = std::make_unique(); FusionGuard fg2(fusion_new_op.get()); - auto tv_data_hp = makeContigTensor(2, DataType::Float); + auto tv_data_hp = makeContigTensor(2, data_hp_dtype); fusion_new_op->addInput(tv_data_hp); auto t0 = set(tv_data_hp); @@ -217,6 +221,8 @@ TEST_F(BlockQuantizationTest, ScheduleAsPointwise) { fusion_new_op->addOutput(quantization_results.block_scales); fusion_new_op->addOutput(t_out); + auto vectorization_factor = data_hp_dtype == DataType::Float ? 4 : 8; + for (auto t : {tv_data_hp, t0, @@ -229,9 +235,9 @@ TEST_F(BlockQuantizationTest, ScheduleAsPointwise) { t->merge(-2); } - // split by 4. + // split by 4 (or 8). // I -> I/4, 4 - t->split(-1, 4); + t->split(-1, vectorization_factor); // I//4, 4 -> I/4, 1, 4 t->split(-2, 1); // I//4, 1, 4 -> I/512, 128, 1, 4 @@ -282,18 +288,21 @@ TEST_F(BlockQuantizationTest, ScheduleAsPointwise) { EXPECT_EQ(quantized_tensor_output.dim(), 2); } -TEST_F(BlockQuantizationTest, ScheduleAsPointwise2D) { +TEST_P(BlockQuantizationTest, ScheduleAsPointwise2D) { + auto data_hp_dtype = GetParam(); + // Baseline implementation std::unique_ptr fusion = std::make_unique(); FusionGuard fg(fusion.get()); - createNVFP4QunatizationFusion(fusion.get(), DataType::Float); + createNVFP4QunatizationFusion(fusion.get(), data_hp_dtype); FusionExecutorCache fec(std::move(fusion)); const int m = 1024; const int n = 1024; std::vector inputs; - inputs.push_back(at::randn({m, n}, at::device(at::kCUDA).dtype(at::kFloat))); + inputs.push_back(at::randn({m, n}, at::device(at::kCUDA).dtype(at::kFloat)) + .to(data_type_to_aten(data_hp_dtype))); auto outputs_baseline = fec.runFusionWithInputs(inputs); // Print baseline outputs @@ -312,7 +321,7 @@ TEST_F(BlockQuantizationTest, ScheduleAsPointwise2D) { std::unique_ptr fusion_new_op = std::make_unique(); FusionGuard fg2(fusion_new_op.get()); - auto tv_data_hp = makeContigTensor(2, DataType::Float); + auto tv_data_hp = makeContigTensor(2, data_hp_dtype); fusion_new_op->addInput(tv_data_hp); // t0 is 2D @@ -326,15 +335,17 @@ TEST_F(BlockQuantizationTest, ScheduleAsPointwise2D) { t0->setMemoryType(MemoryType::Local); + auto vectorization_factor = data_hp_dtype == DataType::Float ? 
4 : 8; + for (auto t : {tv_data_hp, t0, quantization_results.quantized_tensor, quantization_results.block_scales, t_out}) { - // (m, n) -> (m, n/4, 4) + // (m, n) -> (m, n/4, 4) (or (m, n/8, 8) if bfloat16) // (m, n/4, 4) -> (m, n/128, 32, 4) - t->split(-1, 4); // V + t->split(-1, vectorization_factor); // V t->split(-2, 32); // BDx // (m, n/128, 32, 4) -> (m, 1, n/128, 32, 4) @@ -538,4 +549,10 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DataType::BFloat16, DataType::Float), testing::PrintToStringParamName()); +INSTANTIATE_TEST_SUITE_P( + , + BlockQuantizationTest, + ::testing::Values(DataType::BFloat16, DataType::Float, DataType::Half), + testing::PrintToStringParamName()); + } // namespace nvfuser From d43096ed476428d9d2ac18d2e897c6c68f7d8f80 Mon Sep 17 00:00:00 2001 From: protonu Date: Fri, 17 Oct 2025 11:24:49 -0700 Subject: [PATCH 28/79] removing vectorize --- csrc/codegen.cpp | 17 ++++++++--------- csrc/device_lower/pass/predicate.cpp | 6 +----- csrc/device_lower/validation.cpp | 19 +++++++++---------- tests/cpp/test_low_precision_recipe.cpp | 8 ++++++-- 4 files changed, 24 insertions(+), 26 deletions(-) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 0693b5a3f18..7fd1ac2cb54 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -1819,17 +1819,16 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { // 4 consecutive inputs (8 for FB16) per thread. We achieve this by having // the input tv scheduler to have the inner dimension vectorized by 4/8. auto output = bqop->quantizedOutput()->as()->view(); - int64_t vector_word_size = 1; + int64_t group_size = 1; // Get the loop domain of the TensorView output and check for group/vector // parallel types. This assumes that both parallel types aren't present. const auto& loop_domain = output->getLoopDomain(); for (auto* domain : loop_domain) { auto parallel_type = domain->getParallelType(); - if (parallel_type == ParallelType::Group || - parallel_type == ParallelType::Vectorize) { + if (parallel_type == ParallelType::Group) { if (domain->extent()->isConstInt()) { - vector_word_size = domain->extent()->evaluate().as(); + group_size = domain->extent()->evaluate().as(); } } } @@ -1839,21 +1838,21 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { if (input_dtype == DataType::BFloat16 || input_dtype == DataType::Half) { NVF_ERROR( - vector_word_size == 8, - "Vectorization size should be 8 for " + group_size == 8, + "Group size should be 8 for " "BlockQuantizationOp: ", bqop->toString()); } else { NVF_ERROR( - vector_word_size == 4, - "Vectorization size should be 4 for " + group_size == 4, + "Group size should be 4 for " "BlockQuantizationOp: ", bqop->toString()); } ArgumentBuilder template_args; - template_args.arg(vector_word_size); // ITEMS_PER_THREAD + template_args.arg(group_size); // ITEMS_PER_THREAD // Function arguments ArgumentBuilder func_args; diff --git a/csrc/device_lower/pass/predicate.cpp b/csrc/device_lower/pass/predicate.cpp index d2cc8189c31..664a63f4f56 100644 --- a/csrc/device_lower/pass/predicate.cpp +++ b/csrc/device_lower/pass/predicate.cpp @@ -67,11 +67,7 @@ class ConditionalFromPredicateModifier : public kir::ExprMutator { auto vec_expr = ite->thenBody()[0]; NVF_ERROR( vec_expr->isA() || vec_expr->isA() || - vec_expr->isA() || - vec_expr->isA() || - // To supress the throw. - // I think this is predicated on the vectorized dim. 
- vec_expr->isA(), + vec_expr->isA() || vec_expr->isA(), "Vectorize predicate exprs only supported on set operations."); NVF_ERROR( ir_utils::isTvOp(vec_expr), diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index 1c292d5f1ac..312fe57ed78 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -471,7 +471,7 @@ class ExprValidator : public OptOutDispatch { // Check that it either had vectorized ID or grouped ID // not both and the extent is either 4(FP32) or 8(BF16) - IterDomain* grouped_or_vector_id = nullptr; + IterDomain* grouped_id = nullptr; IterDomain* thread_x = nullptr; IterDomain* block_x = nullptr; IterDomain* thread_y = nullptr; @@ -480,17 +480,17 @@ class ExprValidator : public OptOutDispatch { IterDomain* block_z = nullptr; for (const auto& loop_id : block_scaling_factor->getLoopDomain()) { - if (loop_id->getParallelType() == ParallelType::Group || - loop_id->getParallelType() == ParallelType::Vectorize) { + if (loop_id->getParallelType() == ParallelType::Group) { NVF_ERROR( - grouped_or_vector_id == nullptr, + grouped_id == nullptr, "Multiple IDs found to be grouped/vectorized"); - grouped_or_vector_id = loop_id; + grouped_id = loop_id; } } auto parallel_domains_map = ir_utils::getParallelDomains(block_scaling_factor); + if (parallel_domains_map.find(ParallelType::TIDx) != parallel_domains_map.end()) { thread_x = parallel_domains_map.at(ParallelType::TIDx); @@ -517,8 +517,8 @@ class ExprValidator : public OptOutDispatch { } NVF_ERROR( - grouped_or_vector_id != nullptr, - "One of the output IDs must be grouped or vectorized for " + grouped_id != nullptr, + "One of the output IDs must be grouped for " "BlockQuantizationOp: ", bqop->toString()); @@ -536,8 +536,7 @@ class ExprValidator : public OptOutDispatch { bool is_2d_scheduled = (thread_y != nullptr || block_y != nullptr) ? true : false; - auto inner_extent = - grouped_or_vector_id->extent()->evaluate().as(); + auto inner_extent = grouped_id->extent()->evaluate().as(); auto input_dtype = inp_tv->dtype(); NVF_ERROR( @@ -555,7 +554,7 @@ class ExprValidator : public OptOutDispatch { // Then we check that the logical domain IDs are the inner-most // IDs. 
auto input_logical_domains_ids = findLogicalDomainOrigins( - {grouped_or_vector_id, thread_x, block_x}, block_scaling_factor); + {grouped_id, thread_x, block_x}, block_scaling_factor); // Get the size of input logical domains size_t num_input_logical_domains = input_logical_domains_ids.size(); diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index 34a601b6298..260a8e0e3f4 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -355,8 +355,12 @@ TEST_P(BlockQuantizationTest, ScheduleAsPointwise2D) { // (m/4(bidy), 4(tidy), 1, n*k/128(bidx), 32(tidx), 4(v)) if (t != tv_data_hp) { - // Don't vectorize the outputs of reshape - t->axis(-1)->parallelize(ParallelType::Vectorize); + if (t == quantization_results.block_scales || + t == quantization_results.quantized_tensor) { + t->axis(-1)->parallelize(ParallelType::Group); + } else { + t->axis(-1)->parallelize(ParallelType::Vectorize); + } t->axis(-2)->parallelize(ParallelType::TIDx); t->axis(-3)->parallelize(ParallelType::BIDx); t->axis(-5)->parallelize(ParallelType::TIDy); From 44d192c5bef6cf1e1c0fdf667534617f0de29363 Mon Sep 17 00:00:00 2001 From: protonu Date: Fri, 17 Oct 2025 15:04:41 -0700 Subject: [PATCH 29/79] updating comment for validation fn --- csrc/device_lower/validation.cpp | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index 4373f01247a..c676f4027e7 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -436,6 +436,22 @@ class ExprValidator : public OptOutDispatch { // I'd like to check that the inner dimension of the input // is divisble by 16. + // Basic tests: + // Input is in local memory. + // Block scaling factor is in global memory and + // quantized output is in local memory. + // Any loop ID that is not TID(x/y). BID(x/y) or Group + // has an extent of 1. + // The Group ID has an extent of 4/8 depending on the data type. + // There are TIDz/BIDz IDs. + // TODO: Express the following as validation checks. + // For 1D scheduling (BIDx/TIDx/Group only): + // (1)Chek that the these loop IDs cover the entire logical domain + // (2) Group ID is "innermost" next TIDx ID and then BIDx ID. + // The above is because this op is implemented by a device function + // Which access the block scales output memory using the index + // (blockDix.x * blockDim.x + threadIdx.x) /4 (for FP32, group is 4). + // We have to do the same for 2D scheduling. void handle(BlockQuantizationOp* bqop) final { auto inp_tv = bqop->input(0)->as(); auto quantized_output = bqop->quantizedOutput()->as(); @@ -481,11 +497,19 @@ class ExprValidator : public OptOutDispatch { for (const auto& loop_id : block_scaling_factor->getLoopDomain()) { if (loop_id->getParallelType() == ParallelType::Group) { - NVF_ERROR( - grouped_id == nullptr, - "Multiple IDs found to be grouped/vectorized"); grouped_id = loop_id; } + if (loop_id->getParallelType() == ParallelType::Serial || + loop_id->getParallelType() == ParallelType::Unswitch || + loop_id->getParallelType() == ParallelType::Unroll) { + // Check this is ID has a constant extent and is 1 + NVF_ERROR( + loop_id->extent()->isConstInt(), + "Expected constant extent for Serial ID in BlockQuantizationOp"); + NVF_ERROR( + loop_id->extent()->evaluate().as() == 1, + "Expected extent of 1"); + } } auto parallel_domains_map = @@ -550,6 +574,7 @@ class ExprValidator : public OptOutDispatch { ". 
Expr: ", bqop->toString()); + // Please temporarily ignore from here below. This needs to be updated. // Find the logical domain IDs that correspond to these loop IDs. // Then we check that the logical domain IDs are the inner-most // IDs. From 11fff3d3f7547cda3355d03474111dd177633f5b Mon Sep 17 00:00:00 2001 From: protonu Date: Tue, 21 Oct 2025 17:59:13 -0700 Subject: [PATCH 30/79] working index --- csrc/codegen.cpp | 9 +- csrc/device_lower/pass/index.cpp | 32 ++- csrc/device_lower/validation.cpp | 333 +++++++++++++++++--------- csrc/ir/internal_nodes.h | 3 +- csrc/ir/nodes.cpp | 2 + runtime/block_quantization_kernels.cu | 10 +- 6 files changed, 271 insertions(+), 118 deletions(-) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 7fd1ac2cb54..835f024d0cb 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -1817,7 +1817,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { // This operator is plumbed down to a runtime function call. // One of the assumptions is that the device runtime expects // 4 consecutive inputs (8 for FB16) per thread. We achieve this by having - // the input tv scheduler to have the inner dimension vectorized by 4/8. + // the input tv scheduler to have the inner dimension grouped by 4/8. auto output = bqop->quantizedOutput()->as()->view(); int64_t group_size = 1; @@ -1865,6 +1865,13 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { func_args.arg( genInline(bqop->blockScales()->as()->view())); + // Fourth argument: This holds the linearized index that will be used to + // write out the block scaling factors in the runtime function. + func_args.arg(genInline(bqop->attributeVal(0))); + + // Fifth argument: extent of the inner-most dimension + func_args.arg(genInline(output->getLoopDomain().back()->extent())); + indent() << genCall("bq::block_quantize_to_nvfp4", template_args, func_args) << ";\n"; } diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index 04be41024ef..2d02ee62e81 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -415,8 +415,36 @@ void IndexLowering::handle(const BlockQuantizationOp* bqop) { const auto out_quantized = IrBuilder::create( bqop->quantizedOutput()->as(), bqop->fusion()->zeroVal()); - pushBack( - IrBuilder::create(out_scales, out_quantized, in)); + // The BlockQuantizationOp funnels down to a runtime function. + // We pass the index for the block scaling factors output. We compute + // the index bases on the logical indices of the quantized output tensor. + // Then inside the runtime function, we divide this linearized index by 16 + // (the block size) to get the index for the scaling factors. + // We get the linearized index as follows: + // We get the logical indices for the quantized output. + // We then multiply and accumulate them using the logical extents of the + // quantized output tensor to get the linearized index. 
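  // Worked example (hypothetical shape): for a quantized output of logical
  // shape [M, N, K] with per-dimension logical indices {i0, i1, i2}, the
  // accumulation below yields
  //   idx = i2 + i1 * K + i0 * (N * K)
  // and the runtime function later uses idx / 16 to select the scaling
  // factor shared by each group of 16 contiguous elements.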
+ std::vector logical_index = Index::getConsumerPerDimLogicalIndex( + bqop->quantizedOutput()->as(), for_loops_, getRotatedLoop()); + + auto loop_domain = + bqop->quantizedOutput()->as()->getLogicalDomain(); + + int64_t dim_count = logical_index.size(); + + // logical_index[2] * 1 + logical_index[1] * extent[2] + logical_index[0] * + // extent[1] * extent[2] + auto idx = logical_index[dim_count - 1]; + for (auto i = dim_count - 2; i >= 0; i--) { + auto stride = IrBuilder::create(1, DataType::Index); + for (auto j = i + 1; j < dim_count; j++) { + stride = IrBuilder::mulExpr(stride, loop_domain[j]->extent()); + } + idx = IrBuilder::addExpr(IrBuilder::mulExpr(logical_index[i], stride), idx); + } + + pushBack(IrBuilder::create( + out_scales, out_quantized, in, idx)); GpuLower::current()->propagateExprInfo(bqop, back()); } diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index c676f4027e7..6d670e03b33 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -412,30 +413,186 @@ class ExprValidator : public OptOutDispatch { } } - // Given a set of loop domain IterDomains, find their logical domain origins - std::vector findLogicalDomainOrigins( - const std::vector& loop_domain_ids, - const TensorView* tv) { - // Get the logical domain to use as the target/boundary - const auto& logical_domain = tv->getLogicalDomain(); + // A merge is acceptable its inputs can be traced back to IDs in the logical + // domain, and that these IDs are contiguous in logical domain. + bool isAcceptableMerge(Merge* merge, TensorView* quantized_output) { + const auto logical_domain = quantized_output->getLogicalDomain(); + auto ids_in_logical_domain = IterVisitor::getInputsTo({merge->out()}); + + // Check if all elements in ids_in_logical_domain are in logical_domain + // and if they are contiguously located + bool all_in_logical_domain = true; + std::vector logical_positions; + + for (auto id : ids_in_logical_domain) { + auto iter_domain = id->as(); + auto it = + std::find(logical_domain.begin(), logical_domain.end(), iter_domain); + if (it != logical_domain.end()) { + // Found in logical domain, record its position + logical_positions.push_back(std::distance(logical_domain.begin(), it)); + } else { + all_in_logical_domain = false; + break; + } + } - // Use IterVisitor to find inputs to the loop domain IDs, - // bounded by the logical domain - std::vector inputs_as_vals = IterVisitor::getInputsTo( - {loop_domain_ids.begin(), loop_domain_ids.end()}, - {logical_domain.begin(), logical_domain.end()}); + bool are_contiguous = true; + if (all_in_logical_domain && logical_positions.size() > 1) { + // Sort positions to check contiguity + std::sort(logical_positions.begin(), logical_positions.end()); + for (size_t i = 1; i < logical_positions.size(); ++i) { + if (logical_positions[i] != logical_positions[i - 1] + 1) { + are_contiguous = false; + break; + } + } + } + + return all_in_logical_domain && are_contiguous; + } - // Convert back to IterDomains - std::vector logical_origins; - for (auto val : inputs_as_vals) { - logical_origins.push_back(val->as()); + // M K + // │ │ + // ▼ ▼ + // ┌────────────┐ + // │ merge │ + // └─────┬──────┘ + // │ + // ▼ + // M*K + // ┌──────────┐ + // │ split ┼──┐ + // └─┬────────┘ │ + // ▼ ▼ + // (M*K)/4 4(G) + // ┌────────┐ + // │ split ┼────┐ + // └─┬──────┘ │ + // ▼ ▼ + // (M*K)/4 1(U) + // ┌─────────┐ + // │ split │ + // ┌─┼ ┼───┐ + // │ └─────────┘ │ + // ▼ ▼ 
+ // (M*K)/4/128 128(Tx) + // With the above example, we start from G and go up to the ID K. + // We traverse up the split only if we are coming from the inner output and we + // traverse up the merge by going to the inner input. + // While we traverse up we also store the very last split(Sp) we have seen. + // Next we want to verify if Tx follows after G. We start from the last + // split(Sp) and execute a DFS travering the inner-split first then + // outer-split. The first terminating ID we should reach should be Tx. If we + // reach a different terminating ID, then it should have an extent of 1 or + // else this is not valid. + // Details: + // TODO: relax the restriction on merges. + // We only support merge where both the inputs can be traced back to + // logical IDs that are contiguous. + Split* checkGroupIDDerivedFromLastLogicalIDs( + IterDomain* group_id, + TensorView* quantized_output) { + NVF_ERROR( + group_id != nullptr, + "Expected a valid loop grouped ID for BlockQuantizationOp: ", + quantized_output->toString()); + + auto id_val = group_id; + Split* last_split_seen = nullptr; + while (id_val->definition() != nullptr) { + auto def = id_val->definition(); + if (auto merge = dynamic_cast(def)) { + NVF_ERROR( + isAcceptableMerge(merge, quantized_output), + "Invalid merge found while tracing back the grouped ID for " + "BlockQuantizationOp. All inputs to merge must be from logical " + "domain or be outputs of other merges", + quantized_output->toString()); + id_val = merge->inner(); + } else if (auto split = dynamic_cast(def)) { + NVF_ERROR( + id_val == split->inner(), + "The grouped ID must correspond to the innermost of all splits " + "from " + "logical domains to loop domains for BlockQuantizationOp: " + "quantized output ", + quantized_output->toString()); + last_split_seen = split; + id_val = split->in(); + } else { + NVF_ERROR( + false, + "Unexpected definition found while tracing back the grouped ID for " + "BlockQuantizationOp: ", + quantized_output->toString()); + } } - return logical_origins; + NVF_ERROR( + id_val->definition() == nullptr && + id_val == quantized_output->getLogicalDomain().back(), + "The grouped ID must be the innermost logical domain ID for " + "BlockQuantizationOp: ", + quantized_output->toString()); + + return last_split_seen; + } + + void traverseFromSplitToThreadX( + IterDomain* start_traversal_id, + TensorView* quantized_output) { + std::stack id_stack; + id_stack.push(start_traversal_id); + + while (!id_stack.empty()) { + auto current_id = id_stack.top(); + id_stack.pop(); + + if (current_id->uses().size() == 0) { + // If the current_id is TIDx then great, else + // This has to have an extent of 1. 
+ if (current_id->getParallelType() == ParallelType::TIDx) { + break; + } + NVF_ERROR( + current_id->extent()->isConstInt(), + "Expected constant extent for ID in BlockQuantizationOp"); + NVF_ERROR( + current_id->extent()->evaluate().as() == 1, + "Expected extent of 1 for ID in BlockQuantizationOp"); + continue; + } + + NVF_ERROR( + current_id->uses().size() == 1, + "Expected single use for IDs in logical to loop transforms " + "BlockQuantizationOp quantization output", + current_id->toString()); + + auto use_expr = current_id->uses().at(0); + if (auto merge = dynamic_cast(use_expr)) { + NVF_ERROR( + isAcceptableMerge(merge, quantized_output), + "Invalid merge found while tracing forward in the logical to loop " + "transforms for " + "BlockQuantizationOp quantization output ", + quantized_output->toString()); + id_stack.push(merge->out()); + } else if (auto split = dynamic_cast(use_expr)) { + id_stack.push(split->outer()); + id_stack.push(split->inner()); + } else { + NVF_ERROR( + false, + "Unexpected use of an ID found while tracing forward in the " + "logical to loop transforms for " + "BlockQuantizationOp quantization output ", + quantized_output->toString()); + } + } } - // I'd like to check that the inner dimension of the input - // is divisble by 16. // Basic tests: // Input is in local memory. // Block scaling factor is in global memory and @@ -443,15 +600,15 @@ class ExprValidator : public OptOutDispatch { // Any loop ID that is not TID(x/y). BID(x/y) or Group // has an extent of 1. // The Group ID has an extent of 4/8 depending on the data type. - // There are TIDz/BIDz IDs. - // TODO: Express the following as validation checks. - // For 1D scheduling (BIDx/TIDx/Group only): - // (1)Chek that the these loop IDs cover the entire logical domain - // (2) Group ID is "innermost" next TIDx ID and then BIDx ID. - // The above is because this op is implemented by a device function - // Which access the block scales output memory using the index - // (blockDix.x * blockDim.x + threadIdx.x) /4 (for FP32, group is 4). - // We have to do the same for 2D scheduling. + // The following are more complex checks that look at the schedule. + // There are no TIDz/BIDz IDs. We don't support 3D parallelization here. + // Our aims for the following checks are to ensure that the Group ID is + // contiguous and unit stride, and then after Group ID, we have TIDx. + // Such that (G) * ThreadIdx.x + GID is contiguous. + // Checks for Group ID: + // Next we check that the Group ID is contiguous and unit stride. + // We walk up the logical domain to loop domains path starting from the Group + // ID (G). 
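  // Concretely (illustrative, fp32 case): with a Group extent of 4, thread t
  // along TIDx is expected to own elements 4*t .. 4*t+3 of the innermost
  // extent, so consecutive threads cover consecutive 4-element groups and
  // four threads together cover one 16-element quantization block.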
void handle(BlockQuantizationOp* bqop) final { auto inp_tv = bqop->input(0)->as(); auto quantized_output = bqop->quantizedOutput()->as(); @@ -490,8 +647,6 @@ class ExprValidator : public OptOutDispatch { IterDomain* grouped_id = nullptr; IterDomain* thread_x = nullptr; IterDomain* block_x = nullptr; - IterDomain* thread_y = nullptr; - IterDomain* block_y = nullptr; IterDomain* thread_z = nullptr; IterDomain* block_z = nullptr; @@ -523,14 +678,6 @@ class ExprValidator : public OptOutDispatch { parallel_domains_map.end()) { block_x = parallel_domains_map.at(ParallelType::BIDx); } - if (parallel_domains_map.find(ParallelType::TIDy) != - parallel_domains_map.end()) { - thread_y = parallel_domains_map.at(ParallelType::TIDy); - } - if (parallel_domains_map.find(ParallelType::BIDy) != - parallel_domains_map.end()) { - block_y = parallel_domains_map.at(ParallelType::BIDy); - } if (parallel_domains_map.find(ParallelType::TIDz) != parallel_domains_map.end()) { thread_z = parallel_domains_map.at(ParallelType::TIDz); @@ -557,9 +704,6 @@ class ExprValidator : public OptOutDispatch { "BlockQuantizationOp: ", bqop->toString()); - bool is_2d_scheduled = - (thread_y != nullptr || block_y != nullptr) ? true : false; - auto inner_extent = grouped_id->extent()->evaluate().as(); auto input_dtype = inp_tv->dtype(); @@ -574,86 +718,55 @@ class ExprValidator : public OptOutDispatch { ". Expr: ", bqop->toString()); - // Please temporarily ignore from here below. This needs to be updated. - // Find the logical domain IDs that correspond to these loop IDs. - // Then we check that the logical domain IDs are the inner-most - // IDs. - auto input_logical_domains_ids = findLogicalDomainOrigins( - {grouped_id, thread_x, block_x}, block_scaling_factor); - - // Get the size of input logical domains - size_t num_input_logical_domains = input_logical_domains_ids.size(); - - // Get the same number of elements from the innermost logical domain - const auto& logical_domain = block_scaling_factor->getLogicalDomain(); - std::vector innermost_logical_domains; - - // Extract from the rightmost (innermost) positions - for (int64_t i = logical_domain.size() - 1; - i >= 0 && innermost_logical_domains.size() < num_input_logical_domains; - i--) { - auto logical_id = logical_domain[i]; - if (!logical_id->isReduction() && !logical_id->isBroadcast()) { - innermost_logical_domains.insert( - innermost_logical_domains.begin(), logical_id); + // Get the ID marked as Group + IterDomain* new_grouped_id = nullptr; + for (auto loop_id : quantized_output->getLoopDomain()) { + if (loop_id->getParallelType() == ParallelType::Group) { + new_grouped_id = loop_id; } } - // Validate that input_logical_domains_ids and innermost_logical_domains - // contain the same IterDomains - std::unordered_set input_logical_set( - input_logical_domains_ids.begin(), input_logical_domains_ids.end()); - std::unordered_set innermost_logical_set( - innermost_logical_domains.begin(), innermost_logical_domains.end()); - NVF_ERROR( - input_logical_set == innermost_logical_set, - "Input logical domain IDs do not match the innermost logical domains " - "for BlockQuantizationOp: ", - bqop->toString(), - ". Expected innermost domains: ", - toDelimitedString(innermost_logical_domains), - ". Found input logical domains: ", - toDelimitedString(input_logical_domains_ids)); - - // If it's 2D scheduled, the we get the IDs from the logical domain - // that correspond to blockIdx.y and threadIdx.y. 
We make sure the - // IDs from the logical domain don't share any ID with those from the - // thread/block for x-dimension was derived. - if (is_2d_scheduled) { - std::vector input_logical_domains_ids_2d = {}; - for (auto id : {thread_y, block_y}) { - if (id) { - input_logical_domains_ids_2d.push_back(id); - } - } + new_grouped_id != nullptr, + "Expected a valid loop grouped ID for BlockQuantizationOp: ", + bqop->toString()); - auto input_logical_domains_ids_y = findLogicalDomainOrigins( - input_logical_domains_ids_2d, block_scaling_factor); + auto last_split_seen = + checkGroupIDDerivedFromLastLogicalIDs(new_grouped_id, quantized_output); - // Validate that input_logical_domains_ids and input_logical_domains_ids_y - // don't have any elements in common - std::unordered_set input_logical_set_x( - input_logical_domains_ids.begin(), input_logical_domains_ids.end()); - std::unordered_set input_logical_set_y( - input_logical_domains_ids_y.begin(), - input_logical_domains_ids_y.end()); + // if last split seen is null there are two possibilities + // 1) Group ID is directly from logical domain -> valid + // 2) There was a merge right before Group ID + IterDomain* restart_traversal_from = nullptr; - for (const auto& id : input_logical_set_x) { + if (last_split_seen == nullptr) { + auto ids_in_logical = IterVisitor::getInputsTo({new_grouped_id}); + // Check all these ID have constant extents + for (auto id : ids_in_logical) { + auto iter_domain = id->as(); NVF_ERROR( - input_logical_set_y.find(id) == input_logical_set_y.end(), - "Input logical domain IDs for X and Y dimensions have overlapping " - "elements " - "for BlockQuantizationOp: ", - bqop->toString(), - ". Overlapping IterDomain: ", - id->toString(), - ". X logical domains: ", - toDelimitedString(input_logical_domains_ids), - ". Y logical domains: ", - toDelimitedString(input_logical_domains_ids_y)); + iter_domain->extent()->isConstInt(), + "Expected all IDs feeding directly into Group ID to have constant " + "extents for BlockQuantizationOp: ", + quantized_output->toString()); } + + // Check that there are logical IDs left to derive thread IDs + NVF_ERROR( + ids_in_logical.size() < quantized_output->getLogicalDomain().size(), + "There aren't enough logical IDs to derive thread Ids ", + quantized_output->toString()); + + restart_traversal_from = + quantized_output->getLogicalDomain() + [quantized_output->getLogicalDomain().size() - + ids_in_logical.size() - 1]; + } else { + // Go the outer ID, we should have come up from the inner split. 
+ restart_traversal_from = last_split_seen->outer(); } + + traverseFromSplitToThreadX(restart_traversal_from, quantized_output); } }; diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h index 2a5a7e4da52..38d5e877bcd 100644 --- a/csrc/ir/internal_nodes.h +++ b/csrc/ir/internal_nodes.h @@ -3408,6 +3408,7 @@ class BlockQuantizationOp : public Expr { Val* output_scales, Val* output, Val* input, + Val* logical_index = nullptr, Val* global_scale = nullptr, int64_t block_size = 16); @@ -3426,7 +3427,7 @@ class BlockQuantizationOp : public Expr { } int64_t blockSize() const { - return attribute(0); + return attribute(1); } bool hasGlobalScale() const { diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 6b6059ea64d..233fe26920e 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -6417,6 +6417,7 @@ BlockQuantizationOp::BlockQuantizationOp( Val* output_scales, Val* output, Val* input, + Val* logical_index, Val* global_scale, int64_t block_size) : Expr(passkey) { @@ -6426,6 +6427,7 @@ BlockQuantizationOp::BlockQuantizationOp( if (global_scale) { addInput(global_scale); } + addAttribute(logical_index); addDataAttribute(block_size); } diff --git a/runtime/block_quantization_kernels.cu b/runtime/block_quantization_kernels.cu index f67ae04bf44..7d713fe772b 100644 --- a/runtime/block_quantization_kernels.cu +++ b/runtime/block_quantization_kernels.cu @@ -48,7 +48,9 @@ template < __device__ void block_quantize_to_nvfp4( Array& input, Array<__e2m1, ITEMS_PER_THREAD, ALIGNMENT_2>& output, - Tensor<__e4m3, BLOCK_SCALE_DIM, BLOCK_SCALE_ALLOC>& fp8_output) { + Tensor<__e4m3, BLOCK_SCALE_DIM, BLOCK_SCALE_ALLOC>& fp8_output, + nvfuser_index_t logical_index, + int input_logical_inner_dim_size) { constexpr bool is_half_or_bfloat = std::is_same::value || std::is_same::value; constexpr bool is_float = std::is_same::value; @@ -61,7 +63,6 @@ __device__ void block_quantize_to_nvfp4( } else if constexpr (is_half_or_bfloat) { assert(blockDim.x % 2 == 0); } - assert(blockDim.z == 1 && gridDim.z == 1); static_assert( (is_float && ITEMS_PER_THREAD == 4) || @@ -69,6 +70,8 @@ __device__ void block_quantize_to_nvfp4( "ITEMS_PER_THREAD must be 4 for float type or 8 for __bfloat or __half " "type"); + assert(input_logical_inner_dim_size % 16 == 0); + int THREADS_PER_SCALING_FACTOR = 16 / ITEMS_PER_THREAD; Array vec_in; @@ -110,8 +113,7 @@ __device__ void block_quantize_to_nvfp4( int offset_dim_y = threadIdx.y * blockDim.x * gridDim.x; int offset_into_block = blockIdx.x * blockDim.x + threadIdx.x; - int offset = (offset_y_blocks + offset_dim_y + offset_into_block) / - THREADS_PER_SCALING_FACTOR; + int offset = logical_index / 16; // Convert back from FP8 to float using __e4m32float if (threadIdx.x % THREADS_PER_SCALING_FACTOR == 0) { From c729ada9b15108194dec2da4b83fad40ca25662a Mon Sep 17 00:00:00 2001 From: protonu Date: Thu, 23 Oct 2025 16:04:33 -0700 Subject: [PATCH 31/79] remove header --- csrc/device_lower/validation.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index 6d670e03b33..e256b1607d0 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -8,7 +8,6 @@ #include #include -#include #include #include #include From 535075dfabf4812a31e1005d95fde53b0f4353e2 Mon Sep 17 00:00:00 2001 From: Protonu Date: Thu, 30 Oct 2025 20:21:42 -0400 Subject: [PATCH 32/79] Update csrc/codegen.cpp reviewer comments Co-authored-by: Naoya Maruyama --- csrc/codegen.cpp | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 835f024d0cb..68531836135 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -1821,7 +1821,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { auto output = bqop->quantizedOutput()->as()->view(); int64_t group_size = 1; - // Get the loop domain of the TensorView output and check for group/vector + // Get the loop domain of the TensorView output and check for group // parallel types. This assumes that both parallel types aren't present. const auto& loop_domain = output->getLoopDomain(); for (auto* domain : loop_domain) { From 115ddbfeaaaa498b77c83ff901a1110c98d84b1d Mon Sep 17 00:00:00 2001 From: Protonu Date: Fri, 31 Oct 2025 09:50:36 -0400 Subject: [PATCH 33/79] Update runtime/block_quantization_kernels.cu reviewer suggestion Co-authored-by: Naoya Maruyama --- runtime/block_quantization_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtime/block_quantization_kernels.cu b/runtime/block_quantization_kernels.cu index 7d713fe772b..83ca708513c 100644 --- a/runtime/block_quantization_kernels.cu +++ b/runtime/block_quantization_kernels.cu @@ -46,7 +46,7 @@ template < int BLOCK_SCALE_DIM, int BLOCK_SCALE_ALLOC> __device__ void block_quantize_to_nvfp4( - Array& input, + const Array& input, Array<__e2m1, ITEMS_PER_THREAD, ALIGNMENT_2>& output, Tensor<__e4m3, BLOCK_SCALE_DIM, BLOCK_SCALE_ALLOC>& fp8_output, nvfuser_index_t logical_index, From 636303293bf640489a83e28d878b705aa5096569 Mon Sep 17 00:00:00 2001 From: Protonu Date: Fri, 31 Oct 2025 09:50:59 -0400 Subject: [PATCH 34/79] Update runtime/block_quantization_kernels.cu reviewer suggestion. Co-authored-by: Naoya Maruyama --- runtime/block_quantization_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtime/block_quantization_kernels.cu b/runtime/block_quantization_kernels.cu index 83ca708513c..aa0f0d3affb 100644 --- a/runtime/block_quantization_kernels.cu +++ b/runtime/block_quantization_kernels.cu @@ -48,7 +48,7 @@ template < __device__ void block_quantize_to_nvfp4( const Array& input, Array<__e2m1, ITEMS_PER_THREAD, ALIGNMENT_2>& output, - Tensor<__e4m3, BLOCK_SCALE_DIM, BLOCK_SCALE_ALLOC>& fp8_output, + Tensor<__e4m3, BLOCK_SCALE_DIM, BLOCK_SCALE_ALLOC>& block_scales, nvfuser_index_t logical_index, int input_logical_inner_dim_size) { constexpr bool is_half_or_bfloat = From 65c8f25bb2889902f382806b0efac2f589c4166d Mon Sep 17 00:00:00 2001 From: Protonu Date: Fri, 31 Oct 2025 09:51:19 -0400 Subject: [PATCH 35/79] Update runtime/block_quantization_kernels.cu Co-authored-by: Naoya Maruyama --- runtime/block_quantization_kernels.cu | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/runtime/block_quantization_kernels.cu b/runtime/block_quantization_kernels.cu index aa0f0d3affb..3633fa05524 100644 --- a/runtime/block_quantization_kernels.cu +++ b/runtime/block_quantization_kernels.cu @@ -95,9 +95,8 @@ __device__ void block_quantize_to_nvfp4( // Perform block(16 elements)-wide reduction (max) // across 4- threads - float block_max = NEG_INFINITY; localMaxReduction(local_max); - block_max = local_max; + float block_max = local_max; // This division should be replaced with a multiplication // by a reciprocal for better performance. 
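For reference, the quad-wise maximum that localMaxReduction (later renamed reduceAcrossThreads in this series) builds on can be exercised in isolation with a toy kernel like the sketch below. This is illustrative only and not part of any patch above; the kernel name and indexing are made up, and each lane here holds a single value, whereas the real helper runs after every lane has already reduced its own 4 (fp32) or 8 (bf16/fp16) elements.

__global__ void quad_max_demo(const float* in, float* out) {
  // One value per lane; every aligned group of four lanes reduces to a
  // common maximum, mirroring the fp32 path (4 items per thread,
  // 16-element quantization blocks).
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  float v = in[idx];
  const unsigned mask = 0xffffffff;          // all lanes participate
  v = fmaxf(v, __shfl_xor_sync(mask, v, 2)); // exchange with lane two away
  v = fmaxf(v, __shfl_xor_sync(mask, v, 1)); // exchange with adjacent lane
  // All four lanes of the quad now hold the same maximum.
  out[idx] = v;
}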
From 1f9829a691ae45106591082e7a27330344a592c7 Mon Sep 17 00:00:00 2001 From: Protonu Date: Fri, 31 Oct 2025 10:02:54 -0400 Subject: [PATCH 36/79] Update csrc/ops/arith.cpp Co-authored-by: Naoya Maruyama --- csrc/ops/arith.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 90b1e73fd34..6941daefe07 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -2667,13 +2667,13 @@ BlockQuantizationResults blockQuantize( std::vector scales_out_domain; scales_out_domain.reserve(inp_domain.size()); - for (size_t i = 0; i < inp_domain.size(); ++i) { - if (i == inp_domain.size() - 1) { + for (auto inp_id: inp_domain) { + if (inp_id == inp_domain.back()) { scales_out_domain.push_back( IterDomainBuilder( - inp_domain[i]->start(), + inp_id->start(), SimplifyingIrBuilder::divExpr( - inp_domain[i]->extent(), + inp_id->extent(), IrBuilder::create(block_size, DataType::Index))) .build()); From aaa591ec766225c7e60c951a3d30013a7adda08c Mon Sep 17 00:00:00 2001 From: protonu Date: Fri, 31 Oct 2025 08:34:39 -0700 Subject: [PATCH 37/79] address reviewer comments --- csrc/codegen.cpp | 3 - csrc/ir/internal_nodes.h | 7 +++ csrc/ops/arith.cpp | 4 +- runtime/block_quantization_kernels.cu | 79 +++++++++++++++------------ 4 files changed, 54 insertions(+), 39 deletions(-) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 68531836135..26ad8fe6b90 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -1869,9 +1869,6 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { // write out the block scaling factors in the runtime function. func_args.arg(genInline(bqop->attributeVal(0))); - // Fifth argument: extent of the inner-most dimension - func_args.arg(genInline(output->getLoopDomain().back()->extent())); - indent() << genCall("bq::block_quantize_to_nvfp4", template_args, func_args) << ";\n"; } diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h index 38d5e877bcd..d47b731161f 100644 --- a/csrc/ir/internal_nodes.h +++ b/csrc/ir/internal_nodes.h @@ -3403,6 +3403,13 @@ class BlockQuantizationOp : public Expr { public: using Expr::Expr; + // This op takes in a high precision input(input) + // and returns the quantized output(output) along with the block scaling + // factors (output_scales). It can also take as an optional input the global + // scaling factor and block size (though we currently only support 16). + // logical_index is used for internal implemtation. This op is currently + // implemented via a runtime function. During index computation, we compute + // the index of the output_scales and pass it to the runtime function. 
BlockQuantizationOp( IrBuilderPasskey, Val* output_scales, diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 6941daefe07..5781f0d317b 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -2667,7 +2667,7 @@ BlockQuantizationResults blockQuantize( std::vector scales_out_domain; scales_out_domain.reserve(inp_domain.size()); - for (auto inp_id: inp_domain) { + for (auto inp_id : inp_domain) { if (inp_id == inp_domain.back()) { scales_out_domain.push_back( IterDomainBuilder( @@ -2678,7 +2678,7 @@ BlockQuantizationResults blockQuantize( .build()); } else { - scales_out_domain.push_back(inp_domain[i]->cloneWithoutRFactor()); + scales_out_domain.push_back(inp_id->cloneWithoutRFactor()); } } diff --git a/runtime/block_quantization_kernels.cu b/runtime/block_quantization_kernels.cu index 3633fa05524..788372533de 100644 --- a/runtime/block_quantization_kernels.cu +++ b/runtime/block_quantization_kernels.cu @@ -9,8 +9,21 @@ namespace nvf { namespace bq { +// This helper function is templatized of over types float, __half, and +// __bfloat. This assumes that for float, each thread was working on 4 elements. +// Thus 4 threads were working to find the max of 16 elements, and hence we need +// two steps to find the maximum. If the type is __bfloat or __half, then we +// only need a single step to find the maximum of 16 elements as each thread was +// working on 8 elements and 2 threads are required to compute the max of 16 +// elements. +// This function assumes for float each thread has already computed the max of 4 +// elements (8 elements for the other 2 data types) and the block size is 16, so +// we have 4 threads (2 for bf16/fp16) participating in the reduction. +// TODO: For FP32 support the cases where each thread works on 2 or 4 elements. +// TODO: For bf16/fp16 support the cases where each thread works on 2,4, or 8 +// elements. template -__device__ __inline__ void localMaxReduction(float& local_max) { +__device__ __inline__ void reduceAcrossThreads(float& per_thread_computed_max) { // The mask 0xffffffff indicates all 32 threads in the warp are participating. unsigned int mask = 0xffffffff; @@ -19,25 +32,30 @@ __device__ __inline__ void localMaxReduction(float& local_max) { // e.g., thread 0 exchanges with 2; thread 1 with 3. // The XOR pattern naturally keeps the operation within each quad. if (std::is_same::value) { - local_max = fmax(local_max, __shfl_xor_sync(mask, local_max, 2)); + per_thread_computed_max = fmax( + per_thread_computed_max, + __shfl_xor_sync(mask, per_thread_computed_max, 2)); } // --- Reduction Step 2 --- // Exchange and compare with thread 1 lane away. // e.g., thread 0 exchanges with 1; thread 2 with 3. - local_max = fmax(local_max, __shfl_xor_sync(mask, local_max, 1)); + per_thread_computed_max = fmax( + per_thread_computed_max, + __shfl_xor_sync(mask, per_thread_computed_max, 1)); - // At this point, all threads in a quad hold the maximum value for that quad. + // At this point, all threads in a quad hold the maximum value for that + // quad(pair of 2 threads). } -// TODO: Add a template parameter for input type. -// For now we just work on float. -// This also assumes a block of 16. That should be a -// template parameter. - -// This assumes that ITEMS_PER_THREAD is 4. -// This assumes for block quantization, the block size is 16. -// This works for float but will extended to work with bfloat. +// A runtime function to compute quantized nvfp4 output (output) and fp8 block +// scaling (block_scales) factors from fp32, fp16, bf16 inputs (input). 
+// The function is templatized over input type T (float, __half, __bfloat). +// This function assumes that for float, each thread is working on 4 elements. +// Thus 4 threads are working to quantize 16 elements. If the type is __bfloat +// or +// __half, then 2 threads are working to quantize 16 elements as each thread +// is working on 8 elements. template < int ITEMS_PER_THREAD, typename T, @@ -49,8 +67,7 @@ __device__ void block_quantize_to_nvfp4( const Array& input, Array<__e2m1, ITEMS_PER_THREAD, ALIGNMENT_2>& output, Tensor<__e4m3, BLOCK_SCALE_DIM, BLOCK_SCALE_ALLOC>& block_scales, - nvfuser_index_t logical_index, - int input_logical_inner_dim_size) { + nvfuser_index_t logical_index) { constexpr bool is_half_or_bfloat = std::is_same::value || std::is_same::value; constexpr bool is_float = std::is_same::value; @@ -58,25 +75,19 @@ __device__ void block_quantize_to_nvfp4( is_float || is_half_or_bfloat, "Input type must be float, __half or __bfloat"); - if constexpr (is_float) { - assert(blockDim.x % 4 == 0); - } else if constexpr (is_half_or_bfloat) { - assert(blockDim.x % 2 == 0); - } - static_assert( (is_float && ITEMS_PER_THREAD == 4) || (is_half_or_bfloat && ITEMS_PER_THREAD == 8), "ITEMS_PER_THREAD must be 4 for float type or 8 for __bfloat or __half " "type"); - assert(input_logical_inner_dim_size % 16 == 0); - - int THREADS_PER_SCALING_FACTOR = 16 / ITEMS_PER_THREAD; + // Number of threads involved in computing one block scaling factor + constexpr int THREADS_PER_SCALING_FACTOR = 16 / ITEMS_PER_THREAD; Array vec_in; - vec_in.set(0.0f); // Initialize to zero like nvfuser does + vec_in.set(0.0f); +#pragma unroll for (auto i = 0; i < ITEMS_PER_THREAD; i++) { if constexpr (std::is_same::value) { vec_in[i] = input[i]; @@ -93,9 +104,10 @@ __device__ void block_quantize_to_nvfp4( local_max = fmax(local_max, fabsf(vec_in[i])); } - // Perform block(16 elements)-wide reduction (max) - // across 4- threads - localMaxReduction(local_max); + // Compute the max accross 4 threads (float) or 2 threads (bf16/fp16) + // This assumes each thread has already computed is local max of 4 (fp32) or + // 8 (bf16/fp16) elements. + reduceAcrossThreads(local_max); float block_max = local_max; // This division should be replaced with a multiplication @@ -106,17 +118,16 @@ __device__ void block_quantize_to_nvfp4( __e4m3 clamped_max_fp8 = __float2e4m3(clamped_max); + // Convert back from FP8 to float using __e4m32float float clamped_max_converted = __e4m32float(clamped_max_fp8); - int offset_y_blocks = blockIdx.y * blockDim.y * blockDim.x * gridDim.x; - int offset_dim_y = threadIdx.y * blockDim.x * gridDim.x; - int offset_into_block = blockIdx.x * blockDim.x + threadIdx.x; - + // Write out the block scaling factor to global memory. + // This assumes 16 elements in the input were contiguous. + // Only one block scaling factor is written out per 16(assumed block size) + // elements. 
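  // Illustrative trace (fp32, ITEMS_PER_THREAD = 4): the four lanes covering
  // one block see logical_index values of, say, 0, 4, 8 and 12, all of which
  // map to the same offset below; the modulo check then lets only the first
  // of those lanes perform the global store.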
int offset = logical_index / 16; - - // Convert back from FP8 to float using __e4m32float if (threadIdx.x % THREADS_PER_SCALING_FACTOR == 0) { - fp8_output[offset] = clamped_max_fp8; // Broadcast to all threads + block_scales[offset] = clamped_max_fp8; } Array clamped_vals; From 64a921e1691a7de705a78c48123430515446d88b Mon Sep 17 00:00:00 2001 From: protonu Date: Fri, 31 Oct 2025 10:58:34 -0700 Subject: [PATCH 38/79] runtime validation for inner dim size --- csrc/device_lower/pass/index.cpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index 2d02ee62e81..75fcccb4dc6 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -443,6 +443,19 @@ void IndexLowering::handle(const BlockQuantizationOp* bqop) { idx = IrBuilder::addExpr(IrBuilder::mulExpr(logical_index[i], stride), idx); } + // As part of runtime validation + // make sure that the inner dimension of the input is divisible by 16. + auto* inner_id = bqop->in()->as()->getLogicalDomain().back(); + Val* is_divisible = SimplifyingIrBuilder::eqExpr( + SimplifyingIrBuilder::modExpr( + inner_id->extent(), IrBuilder::create(16)), + bqop->fusion()->zeroVal()); + + NVFUSER_LOWER_VALIDATE( + is_divisible, + "Inner dim of input of Block Quantization is not divisble by 16", + bqop->toString()); + pushBack(IrBuilder::create( out_scales, out_quantized, in, idx)); GpuLower::current()->propagateExprInfo(bqop, back()); From 7e3835ccc70b67cf2586ef782843bf2fbc7cf8cd Mon Sep 17 00:00:00 2001 From: protonu Date: Fri, 31 Oct 2025 12:45:56 -0700 Subject: [PATCH 39/79] address reviewer comments --- csrc/device_lower/validation.cpp | 3 +-- csrc/ops/arith.cpp | 6 +++--- csrc/ops/arith.h | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index e256b1607d0..d7e28d1f179 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -1392,8 +1392,7 @@ void validateAndCollectVectorizeInfo(Fusion* fusion) { (def->isA() && def->as()->serialGridReductionRequested()) || (def->isA() && - def->as()->getUnaryOpType() == UnaryOpType::Cast) || - def->isA(), + def->as()->getUnaryOpType() == UnaryOpType::Cast), "Vectorized accesses cannot be inline with computation: ", (def == nullptr ? 
tv->toString() : def->toString())); } diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 5781f0d317b..ec9598e6c27 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -2633,8 +2633,8 @@ BlockQuantizationResults blockQuantize( block_size); NVF_CHECK( - out_dtype == DataType::Float4_e2m1fn_x2, - "Currently only output data type of Float4_e2m1fn_x2 is supported"); + out_dtype == DataType::Float4_e2m1fn, + "Currently only output data type of Float4_e2m1fn is supported"); // Validate input data type // We'll only support FP32 or BF16 @@ -2687,7 +2687,7 @@ BlockQuantizationResults blockQuantize( IrBuilder::create( quantized_out_domain, TensorDomain::getContiguityFilledWith(quantized_out_domain, true)), - DataType::Float4_e2m1fn); // Quantized output using 32-bit integers + out_dtype); TensorView* block_scales = IrBuilder::create( IrBuilder::create( diff --git a/csrc/ops/arith.h b/csrc/ops/arith.h index 49876d6111a..9002e18a114 100644 --- a/csrc/ops/arith.h +++ b/csrc/ops/arith.h @@ -845,6 +845,6 @@ struct BlockQuantizationResults { NVF_API BlockQuantizationResults blockQuantize( TensorView* input, int64_t block_size = 16, - DataType out_dtype = DataType::Float4_e2m1fn_x2); + DataType out_dtype = DataType::Float4_e2m1fn); } // namespace nvfuser From 9dd3a7af392dc9fbb232fdd26198eaef468d0c15 Mon Sep 17 00:00:00 2001 From: protonu Date: Fri, 31 Oct 2025 12:53:21 -0700 Subject: [PATCH 40/79] removing code for validation --- csrc/device_lower/validation.cpp | 356 ------------------------------- 1 file changed, 356 deletions(-) diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index d7e28d1f179..608c3a0330d 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -411,362 +411,6 @@ class ExprValidator : public OptOutDispatch { (*it)->toString()); } } - - // A merge is acceptable its inputs can be traced back to IDs in the logical - // domain, and that these IDs are contiguous in logical domain. 
- bool isAcceptableMerge(Merge* merge, TensorView* quantized_output) { - const auto logical_domain = quantized_output->getLogicalDomain(); - auto ids_in_logical_domain = IterVisitor::getInputsTo({merge->out()}); - - // Check if all elements in ids_in_logical_domain are in logical_domain - // and if they are contiguously located - bool all_in_logical_domain = true; - std::vector logical_positions; - - for (auto id : ids_in_logical_domain) { - auto iter_domain = id->as(); - auto it = - std::find(logical_domain.begin(), logical_domain.end(), iter_domain); - if (it != logical_domain.end()) { - // Found in logical domain, record its position - logical_positions.push_back(std::distance(logical_domain.begin(), it)); - } else { - all_in_logical_domain = false; - break; - } - } - - bool are_contiguous = true; - if (all_in_logical_domain && logical_positions.size() > 1) { - // Sort positions to check contiguity - std::sort(logical_positions.begin(), logical_positions.end()); - for (size_t i = 1; i < logical_positions.size(); ++i) { - if (logical_positions[i] != logical_positions[i - 1] + 1) { - are_contiguous = false; - break; - } - } - } - - return all_in_logical_domain && are_contiguous; - } - - // M K - // │ │ - // ▼ ▼ - // ┌────────────┐ - // │ merge │ - // └─────┬──────┘ - // │ - // ▼ - // M*K - // ┌──────────┐ - // │ split ┼──┐ - // └─┬────────┘ │ - // ▼ ▼ - // (M*K)/4 4(G) - // ┌────────┐ - // │ split ┼────┐ - // └─┬──────┘ │ - // ▼ ▼ - // (M*K)/4 1(U) - // ┌─────────┐ - // │ split │ - // ┌─┼ ┼───┐ - // │ └─────────┘ │ - // ▼ ▼ - // (M*K)/4/128 128(Tx) - // With the above example, we start from G and go up to the ID K. - // We traverse up the split only if we are coming from the inner output and we - // traverse up the merge by going to the inner input. - // While we traverse up we also store the very last split(Sp) we have seen. - // Next we want to verify if Tx follows after G. We start from the last - // split(Sp) and execute a DFS travering the inner-split first then - // outer-split. The first terminating ID we should reach should be Tx. If we - // reach a different terminating ID, then it should have an extent of 1 or - // else this is not valid. - // Details: - // TODO: relax the restriction on merges. - // We only support merge where both the inputs can be traced back to - // logical IDs that are contiguous. - Split* checkGroupIDDerivedFromLastLogicalIDs( - IterDomain* group_id, - TensorView* quantized_output) { - NVF_ERROR( - group_id != nullptr, - "Expected a valid loop grouped ID for BlockQuantizationOp: ", - quantized_output->toString()); - - auto id_val = group_id; - Split* last_split_seen = nullptr; - while (id_val->definition() != nullptr) { - auto def = id_val->definition(); - if (auto merge = dynamic_cast(def)) { - NVF_ERROR( - isAcceptableMerge(merge, quantized_output), - "Invalid merge found while tracing back the grouped ID for " - "BlockQuantizationOp. 
All inputs to merge must be from logical " - "domain or be outputs of other merges", - quantized_output->toString()); - id_val = merge->inner(); - } else if (auto split = dynamic_cast(def)) { - NVF_ERROR( - id_val == split->inner(), - "The grouped ID must correspond to the innermost of all splits " - "from " - "logical domains to loop domains for BlockQuantizationOp: " - "quantized output ", - quantized_output->toString()); - last_split_seen = split; - id_val = split->in(); - } else { - NVF_ERROR( - false, - "Unexpected definition found while tracing back the grouped ID for " - "BlockQuantizationOp: ", - quantized_output->toString()); - } - } - - NVF_ERROR( - id_val->definition() == nullptr && - id_val == quantized_output->getLogicalDomain().back(), - "The grouped ID must be the innermost logical domain ID for " - "BlockQuantizationOp: ", - quantized_output->toString()); - - return last_split_seen; - } - - void traverseFromSplitToThreadX( - IterDomain* start_traversal_id, - TensorView* quantized_output) { - std::stack id_stack; - id_stack.push(start_traversal_id); - - while (!id_stack.empty()) { - auto current_id = id_stack.top(); - id_stack.pop(); - - if (current_id->uses().size() == 0) { - // If the current_id is TIDx then great, else - // This has to have an extent of 1. - if (current_id->getParallelType() == ParallelType::TIDx) { - break; - } - NVF_ERROR( - current_id->extent()->isConstInt(), - "Expected constant extent for ID in BlockQuantizationOp"); - NVF_ERROR( - current_id->extent()->evaluate().as() == 1, - "Expected extent of 1 for ID in BlockQuantizationOp"); - continue; - } - - NVF_ERROR( - current_id->uses().size() == 1, - "Expected single use for IDs in logical to loop transforms " - "BlockQuantizationOp quantization output", - current_id->toString()); - - auto use_expr = current_id->uses().at(0); - if (auto merge = dynamic_cast(use_expr)) { - NVF_ERROR( - isAcceptableMerge(merge, quantized_output), - "Invalid merge found while tracing forward in the logical to loop " - "transforms for " - "BlockQuantizationOp quantization output ", - quantized_output->toString()); - id_stack.push(merge->out()); - } else if (auto split = dynamic_cast(use_expr)) { - id_stack.push(split->outer()); - id_stack.push(split->inner()); - } else { - NVF_ERROR( - false, - "Unexpected use of an ID found while tracing forward in the " - "logical to loop transforms for " - "BlockQuantizationOp quantization output ", - quantized_output->toString()); - } - } - } - - // Basic tests: - // Input is in local memory. - // Block scaling factor is in global memory and - // quantized output is in local memory. - // Any loop ID that is not TID(x/y). BID(x/y) or Group - // has an extent of 1. - // The Group ID has an extent of 4/8 depending on the data type. - // The following are more complex checks that look at the schedule. - // There are no TIDz/BIDz IDs. We don't support 3D parallelization here. - // Our aims for the following checks are to ensure that the Group ID is - // contiguous and unit stride, and then after Group ID, we have TIDx. - // Such that (G) * ThreadIdx.x + GID is contiguous. - // Checks for Group ID: - // Next we check that the Group ID is contiguous and unit stride. - // We walk up the logical domain to loop domains path starting from the Group - // ID (G). 
- void handle(BlockQuantizationOp* bqop) final { - auto inp_tv = bqop->input(0)->as(); - auto quantized_output = bqop->quantizedOutput()->as(); - auto block_scaling_factor = bqop->blockScales()->as(); - - NVF_ERROR_EQ( - inp_tv->getMemoryType(), - MemoryType::Local, - "Input must be a local memory tensor. Found: ", - inp_tv->getMemoryType()); - - NVF_ERROR_EQ( - quantized_output->getMemoryType(), - MemoryType::Local, - "Quantized output must be a local memory tensor. Found: ", - quantized_output->getMemoryType()); - - NVF_ERROR_EQ( - block_scaling_factor->getMemoryType(), - MemoryType::Global, - "Block scaling factor must be a global memory tensor. Found: ", - block_scaling_factor->getMemoryType()); - - // outputs have the same allocation domain - // as the loop domain. This has to be later - // relaxed for the scaling factors. - NVF_ERROR( - quantized_output->hasAllocation() == false, - "Quantized output must not have an allocation domain."); - NVF_ERROR( - block_scaling_factor->hasAllocation() == false, - "Block scaling factor must not have an allocation domain."); - - // Check that it either had vectorized ID or grouped ID - // not both and the extent is either 4(FP32) or 8(BF16) - IterDomain* grouped_id = nullptr; - IterDomain* thread_x = nullptr; - IterDomain* block_x = nullptr; - IterDomain* thread_z = nullptr; - IterDomain* block_z = nullptr; - - for (const auto& loop_id : block_scaling_factor->getLoopDomain()) { - if (loop_id->getParallelType() == ParallelType::Group) { - grouped_id = loop_id; - } - if (loop_id->getParallelType() == ParallelType::Serial || - loop_id->getParallelType() == ParallelType::Unswitch || - loop_id->getParallelType() == ParallelType::Unroll) { - // Check this is ID has a constant extent and is 1 - NVF_ERROR( - loop_id->extent()->isConstInt(), - "Expected constant extent for Serial ID in BlockQuantizationOp"); - NVF_ERROR( - loop_id->extent()->evaluate().as() == 1, - "Expected extent of 1"); - } - } - - auto parallel_domains_map = - ir_utils::getParallelDomains(block_scaling_factor); - - if (parallel_domains_map.find(ParallelType::TIDx) != - parallel_domains_map.end()) { - thread_x = parallel_domains_map.at(ParallelType::TIDx); - } - if (parallel_domains_map.find(ParallelType::BIDx) != - parallel_domains_map.end()) { - block_x = parallel_domains_map.at(ParallelType::BIDx); - } - if (parallel_domains_map.find(ParallelType::TIDz) != - parallel_domains_map.end()) { - thread_z = parallel_domains_map.at(ParallelType::TIDz); - } - if (parallel_domains_map.find(ParallelType::BIDz) != - parallel_domains_map.end()) { - block_z = parallel_domains_map.at(ParallelType::BIDz); - } - - NVF_ERROR( - grouped_id != nullptr, - "One of the output IDs must be grouped for " - "BlockQuantizationOp: ", - bqop->toString()); - - NVF_ERROR( - thread_x != nullptr && block_x != nullptr, - "Need to have both TIDx and BIDx when using BlockQuantizationOp: ", - bqop->toString()); - - NVF_ERROR( - !thread_z && !block_z, - "Parallelization along z axis is not supported for " - "BlockQuantizationOp: ", - bqop->toString()); - - auto inner_extent = grouped_id->extent()->evaluate().as(); - auto input_dtype = inp_tv->dtype(); - - NVF_ERROR( - (inner_extent == 4 && input_dtype == DataType::Float) || - (inner_extent == 8 && - (input_dtype == DataType::BFloat16 || - input_dtype == DataType::Half)), - "The vectorized/grouped dimension must be 4 (FP32) or 8 " - "(BF16). Found: ", - inner_extent, - ". 
Expr: ", - bqop->toString()); - - // Get the ID marked as Group - IterDomain* new_grouped_id = nullptr; - for (auto loop_id : quantized_output->getLoopDomain()) { - if (loop_id->getParallelType() == ParallelType::Group) { - new_grouped_id = loop_id; - } - } - - NVF_ERROR( - new_grouped_id != nullptr, - "Expected a valid loop grouped ID for BlockQuantizationOp: ", - bqop->toString()); - - auto last_split_seen = - checkGroupIDDerivedFromLastLogicalIDs(new_grouped_id, quantized_output); - - // if last split seen is null there are two possibilities - // 1) Group ID is directly from logical domain -> valid - // 2) There was a merge right before Group ID - IterDomain* restart_traversal_from = nullptr; - - if (last_split_seen == nullptr) { - auto ids_in_logical = IterVisitor::getInputsTo({new_grouped_id}); - // Check all these ID have constant extents - for (auto id : ids_in_logical) { - auto iter_domain = id->as(); - NVF_ERROR( - iter_domain->extent()->isConstInt(), - "Expected all IDs feeding directly into Group ID to have constant " - "extents for BlockQuantizationOp: ", - quantized_output->toString()); - } - - // Check that there are logical IDs left to derive thread IDs - NVF_ERROR( - ids_in_logical.size() < quantized_output->getLogicalDomain().size(), - "There aren't enough logical IDs to derive thread Ids ", - quantized_output->toString()); - - restart_traversal_from = - quantized_output->getLogicalDomain() - [quantized_output->getLogicalDomain().size() - - ids_in_logical.size() - 1]; - } else { - // Go the outer ID, we should have come up from the inner split. - restart_traversal_from = last_split_seen->outer(); - } - - traverseFromSplitToThreadX(restart_traversal_from, quantized_output); - } }; } // namespace From a444ea7eda6a5aa816d231edad1bc80b3b1c80ea Mon Sep 17 00:00:00 2001 From: protonu Date: Fri, 31 Oct 2025 14:35:14 -0700 Subject: [PATCH 41/79] update comments --- csrc/ops/arith.cpp | 3 ++- csrc/ops/arith.h | 9 +++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index ec9598e6c27..81db6facb3b 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -2689,11 +2689,12 @@ BlockQuantizationResults blockQuantize( TensorDomain::getContiguityFilledWith(quantized_out_domain, true)), out_dtype); + // Create block scaling factors TensorView* block_scales = IrBuilder::create( IrBuilder::create( scales_out_domain, TensorDomain::getContiguityFilledWith(scales_out_domain, true)), - DataType::Float8_e4m3fn); // Scales maintain input data type + DataType::Float8_e4m3fn); // Create the block quantization operation IrBuilder::create(block_scales, quantized_tensor, input); diff --git a/csrc/ops/arith.h b/csrc/ops/arith.h index 9002e18a114..181ab49d3f1 100644 --- a/csrc/ops/arith.h +++ b/csrc/ops/arith.h @@ -838,10 +838,11 @@ struct BlockQuantizationResults { : quantized_tensor(in_quantized_tensor), block_scales(in_block_scales) {} }; -//! TODO: Expose global scaling factor -// API for block quantization to nvFP4. -// We take FP32 or BF16 input and produce two outputs -// nvFP4(x2) outputs and FP8 block scales. +// API for block quantization. +// Currently We take FP32 or BF16/FP16 input and produce two outputs: +// nvFP4 outputs and FP8 block scales. +// We optionally take a block size as an input but currenlty just support 16. 
+// TODO: Expose global scaling factor NVF_API BlockQuantizationResults blockQuantize( TensorView* input, int64_t block_size = 16, From 375feae88605f490d3e345d9ed8f16e9acc83ccd Mon Sep 17 00:00:00 2001 From: protonu Date: Fri, 31 Oct 2025 15:02:48 -0700 Subject: [PATCH 42/79] edit comments --- csrc/ops/arith.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 81db6facb3b..63c41768fc6 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -2620,9 +2620,9 @@ TensorView* prefixSum(TensorView* tv, int64_t dim) { /*init=*/tv->fusion()->zeroVal(tv->dtype())); } -// Currently this node get lowered to a runtime function, which expects the -// input in registers and write out quantized values or registers and block -// scales to global memory. +// Currently this node gets lowered to a runtime function, which expects the +// inputs in registers and writes out the quantized values to registers and +// block scales to global memory. BlockQuantizationResults blockQuantize( TensorView* input, int64_t block_size, @@ -2637,8 +2637,7 @@ BlockQuantizationResults blockQuantize( "Currently only output data type of Float4_e2m1fn is supported"); // Validate input data type - // We'll only support FP32 or BF16 - // We should check if the inputs are FP or BF16. + // We'll only support FP32 or BF16/FP16 NVF_CHECK( input->getDataType().value() == DataType::Float || input->getDataType().value() == DataType::BFloat16 || From d24562ad78de3bf27e02a752a21a7cc4c12cddd0 Mon Sep 17 00:00:00 2001 From: protonu Date: Sat, 1 Nov 2025 06:09:29 -0700 Subject: [PATCH 43/79] adding validation checks and initial tests --- csrc/device_lower/validation.cpp | 367 ++++++++++++++++++++++++ tests/cpp/test_low_precision_recipe.cpp | 316 ++++++++++++++++++++ 2 files changed, 683 insertions(+) diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index 608c3a0330d..be09b0e9883 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -261,6 +261,246 @@ bool isInnermost(IterDomain* base_id, IterDomain* maybe_innermost_id) { return !frontier.empty() && frontier.back() == maybe_innermost_id; } +// Helper class for BlockQuantization validation +// +// This class validates the scheduling requirements for BlockQuantizationOp: +// 1. The Group ID must be derived from the innermost logical IDs +// 2. Merge operations must combine contiguous logical IDs +// 3. TIDx must follow the Group ID in the schedule +class BlockQuantizationValidationHelper { + public: + BlockQuantizationValidationHelper(const TensorView* tv, IterDomain* group_id) + : tv_(tv), group_id_(group_id) {} + + // Validates the complete scheduling structure for block quantization + void run() { + traceGroupIdToLogicalDomain(); + IterDomain* restart_traversal_from = determineThreadXTraversalStart(); + validateThreadXFollowsGroupId(restart_traversal_from); + } + + private: + // Checks if a merge operation combines contiguous IDs from the logical + // domain. A merge is acceptable if its inputs can be traced back to IDs in + // the logical domain and those IDs are contiguous. 
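+  // For example (illustrative): with logical domain [i0, i1, i2], a merge
+  // whose inputs trace back to {i1, i2} is acceptable, while one tracing back
+  // to {i0, i2} is not, since positions 0 and 2 are not adjacent.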
+ bool isAcceptableMerge(Merge* merge) const { + const auto& logical_domain = tv_->getLogicalDomain(); + auto ids_in_logical_domain = IterVisitor::getInputsTo({merge->out()}); + + // Collect positions of IDs in the logical domain + std::vector logical_positions; + for (auto id : ids_in_logical_domain) { + auto iter_domain = id->as(); + auto it = + std::find(logical_domain.begin(), logical_domain.end(), iter_domain); + if (it == logical_domain.end()) { + return false; // ID not found in logical domain + } + logical_positions.push_back(std::distance(logical_domain.begin(), it)); + } + + // Check contiguity: positions should be consecutive when sorted + if (logical_positions.size() > 1) { + std::sort(logical_positions.begin(), logical_positions.end()); + for (size_t i = 1; i < logical_positions.size(); ++i) { + if (logical_positions[i] != logical_positions[i - 1] + 1) { + return false; // Not contiguous + } + } + } + + return true; + } + + // M K + // │ │ + // ▼ ▼ + // ┌────────────┐ + // │ merge │ + // └─────┬──────┘ + // │ + // ▼ + // M*K + // ┌──────────┐ + // │ split ┼──┐ + // └─┬────────┘ │ + // ▼ ▼ + // (M*K)/4 4(G) + // ┌────────┐ + // │ split ┼────┐ + // └─┬──────┘ │ + // ▼ ▼ + // (M*K)/4 1(U) + // ┌─────────┐ + // │ split │ + // ┌─┼ ┼───┐ + // │ └─────────┘ │ + // ▼ ▼ + // (M*K)/4/128 128(Tx) + // + // Traces the Group ID backwards through splits and merges to ensure it + // derives from the innermost logical IDs. The diagram above shows a typical + // transformation chain from logical IDs (M, K) to the Group ID (G). + // + // Traversal rules: + // - For splits: Only traverse through the inner output (Group ID must be + // innermost) + // - For merges: Traverse through the inner input and validate contiguity + // + // Returns via last_split_seen_: The last split encountered during traversal, + // which is used as the starting point for validating the TIDx path. + void traceGroupIdToLogicalDomain() { + NVF_ERROR( + group_id_ != nullptr, + "Expected a valid loop grouped ID for BlockQuantizationOp: ", + tv_->toString()); + + auto current_id = group_id_; + last_split_seen_ = nullptr; + + while (current_id->definition() != nullptr) { + auto def = current_id->definition(); + + if (auto merge = dynamic_cast(def)) { + NVF_ERROR( + isAcceptableMerge(merge), + "Invalid merge found while tracing back the grouped ID for " + "BlockQuantizationOp. All inputs to merge must be from logical " + "domain or be outputs of other merges. TV: ", + tv_->toString()); + current_id = merge->inner(); + } else if (auto split = dynamic_cast(def)) { + NVF_ERROR( + current_id == split->inner(), + "The grouped ID must correspond to the innermost of all splits " + "from logical domains to loop domains for BlockQuantizationOp. " + "TV: ", + tv_->toString()); + last_split_seen_ = split; + current_id = split->in(); + } else { + NVF_ERROR( + false, + "Unexpected definition found while tracing back the grouped ID for " + "BlockQuantizationOp: ", + tv_->toString()); + } + } + + NVF_ERROR( + current_id->definition() == nullptr && + current_id == tv_->getLogicalDomain().back(), + "The grouped ID must be the innermost logical domain ID for " + "BlockQuantizationOp: ", + tv_->toString()); + } + + // Determines the starting point for validating that TIDx follows the Group + // ID. Two cases: + // 1. If no splits were seen (last_split_seen_ == nullptr): + // Group ID comes directly from logical domain or via merges only. + // Start from the logical ID just before those feeding into Group ID. + // 2. 
Otherwise: + // Start from the outer output of the last split seen. + IterDomain* determineThreadXTraversalStart() const { + if (last_split_seen_ == nullptr) { + // Case 1: Group ID derived directly from logical domain + auto ids_in_logical = IterVisitor::getInputsTo({group_id_}); + + // Validate all IDs feeding into Group ID have constant extents + for (auto id : ids_in_logical) { + auto iter_domain = id->as(); + NVF_ERROR( + iter_domain->extent()->isConstInt(), + "Expected all IDs feeding directly into Group ID to have constant " + "extents for BlockQuantizationOp: ", + tv_->toString()); + } + + // Ensure there are logical IDs left to derive thread IDs + const auto& logical_domain = tv_->getLogicalDomain(); + NVF_ERROR( + ids_in_logical.size() < logical_domain.size(), + "There aren't enough logical IDs to derive thread IDs: ", + tv_->toString()); + + // Return the logical ID just before the ones feeding into Group ID + return logical_domain[logical_domain.size() - ids_in_logical.size() - 1]; + } else { + // Case 2: Start from the outer output of the last split + return last_split_seen_->outer(); + } + } + + // Validates that TIDx follows the Group ID in the schedule using DFS + // traversal. Starting from the given ID, traverses through splits and merges + // to ensure TIDx is reachable. Any terminating IDs that are not TIDx must + // have extent 1. + void validateThreadXFollowsGroupId(IterDomain* start_id) const { + std::stack to_visit; + to_visit.push(start_id); + + while (!to_visit.empty()) { + auto current_id = to_visit.top(); + to_visit.pop(); + + // Check terminating IDs (no uses) + if (current_id->uses().empty()) { + if (current_id->getParallelType() == ParallelType::TIDx) { + return; // Found TIDx - validation successful + } + + // Non-TIDx terminating IDs must have constant extent of 1 + NVF_ERROR( + current_id->extent()->isConstInt(), + "Only constant extent IDs are expected between TIDx and Group ID " + "in BlockQuantizationOp quantized output: ", + tv_->toInlineString()); + NVF_ERROR( + current_id->extent()->evaluate().as() == 1, + "Only constant extent IDs with extent of 1 are expected between " + "TIDx and Group ID in BlockQuantizationOp quantized output: ", + tv_->toInlineString()); + continue; + } + + // Validate single use (no branching in the path to TIDx) + NVF_ERROR( + current_id->uses().size() == 1, + "Expected single use for IDs in logical to loop transforms for " + "BlockQuantizationOp quantization output: ", + current_id->toString()); + + // Process the use expression + auto use_expr = current_id->uses().at(0); + if (auto merge = dynamic_cast(use_expr)) { + NVF_ERROR( + isAcceptableMerge(merge), + "Invalid merge found while tracing forward in the logical to loop " + "transforms for BlockQuantizationOp quantization output: ", + tv_->toString()); + to_visit.push(merge->out()); + } else if (auto split = dynamic_cast(use_expr)) { + // DFS: inner split first, then outer split + to_visit.push(split->outer()); + to_visit.push(split->inner()); + } else { + NVF_ERROR( + false, + "Unexpected use of an ID found while tracing forward in the " + "logical to loop transforms for BlockQuantizationOp quantization " + "output: ", + tv_->toString()); + } + } + } + + private: + const TensorView* tv_; + IterDomain* group_id_; + Split* last_split_seen_ = nullptr; +}; + // Expr-specific validaion // // TODO: Move individual validations to here, e.g., @@ -411,6 +651,133 @@ class ExprValidator : public OptOutDispatch { (*it)->toString()); } } + + // Basic tests: + // Input is in local 
memory. + // Block scaling factor is in global memory and + // quantized output is in local memory. + // Any loop ID that is not TID(x/y). BID(x/y) or Group + // has an extent of 1. + // The Group ID has an extent of 4/8 depending on the data type. + // The following are more complex checks that look at the schedule. + // There are no TIDz/BIDz IDs. We don't support 3D parallelization here. + // Our aims for the following checks are to ensure that the Group ID is + // contiguous and unit stride, and then after Group ID, we have TIDx. + // Such that (G) * ThreadIdx.x + GID is contiguous. + // Checks for Group ID: + // Next we check that the Group ID is contiguous and unit stride. + // We walk up the logical domain to loop domains path starting from the Group + // ID (G). + void handle(BlockQuantizationOp* bqop) final { + auto inp_tv = bqop->input(0)->as(); + auto quantized_output = bqop->quantizedOutput()->as(); + auto block_scaling_factor = bqop->blockScales()->as(); + + NVF_ERROR_EQ( + inp_tv->getMemoryType(), + MemoryType::Local, + "Input must be a local memory tensor. Found: ", + inp_tv->getMemoryType()); + + NVF_ERROR_EQ( + quantized_output->getMemoryType(), + MemoryType::Local, + "Quantized output must be a local memory tensor. Found: ", + quantized_output->getMemoryType()); + + NVF_ERROR_EQ( + block_scaling_factor->getMemoryType(), + MemoryType::Global, + "Block scaling factor must be a global memory tensor. Found: ", + block_scaling_factor->getMemoryType()); + + // outputs have the same allocation domain + // as the loop domain. This has to be later + // relaxed for the scaling factors. + NVF_ERROR( + quantized_output->hasAllocation() == false, + "Quantized output must not have an allocation domain."); + NVF_ERROR( + block_scaling_factor->hasAllocation() == false, + "Block scaling factor must not have an allocation domain."); + + // Check that it either had vectorized ID or grouped ID + // not both and the extent is either 4(FP32) or 8(BF16) + IterDomain* grouped_id = nullptr; + IterDomain* thread_x = nullptr; + IterDomain* block_x = nullptr; + IterDomain* thread_z = nullptr; + IterDomain* block_z = nullptr; + + for (const auto& loop_id : quantized_output->getLoopDomain()) { + if (loop_id->getParallelType() == ParallelType::Group) { + grouped_id = loop_id; + } else if (loop_id->getParallelType() == ParallelType::Vectorize) { + NVF_ERROR(false, "Cannot have vectorized ID in BlockQuantizationOp"); + } else if (loop_id->getParallelType() == ParallelType::TIDx) { + thread_x = loop_id; + } else if (loop_id->getParallelType() == ParallelType::BIDx) { + block_x = loop_id; + } else if (loop_id->getParallelType() == ParallelType::TIDz) { + thread_z = loop_id; + } else if (loop_id->getParallelType() == ParallelType::BIDz) { + block_z = loop_id; + } else if ( + loop_id->getParallelType() == ParallelType::Serial || + loop_id->getParallelType() == ParallelType::Unswitch || + loop_id->getParallelType() == ParallelType::Unroll) { + // Check this is ID has a constant extent and is 1 + NVF_ERROR( + loop_id->extent()->isConstInt(), + "Expected constant extent for Serial/Unswitch/Unroll ID in " + "BlockQuantizationOp"); + NVF_ERROR( + loop_id->extent()->evaluate().as() == 1, + "Expected non-TID/BID/Group ID to have extent of 1 for " + "BlockQuantizationOp: ", + bqop->toString()); + } + } + + NVF_ERROR( + grouped_id != nullptr, + "One of the output IDs must be grouped for " + "BlockQuantizationOp: ", + bqop->toString()); + + NVF_ERROR( + thread_x != nullptr && block_x != nullptr, + "Need to have both 
TIDx and BIDx when using BlockQuantizationOp: ", + bqop->toString()); + + NVF_ERROR( + !thread_z && !block_z, + "Parallelization along z axis is not supported for " + "BlockQuantizationOp: ", + bqop->toString()); + + auto inner_extent = grouped_id->extent()->evaluate().as(); + auto input_dtype = inp_tv->dtype(); + + NVF_ERROR( + (inner_extent == 4 && input_dtype == DataType::Float) || + (inner_extent == 8 && + (input_dtype == DataType::BFloat16 || + input_dtype == DataType::Half)), + "The vectorized/grouped dimension must be 4 (FP32) or 8 " + "(BF16). Found: ", + inner_extent, + ". Expr: ", + bqop->toString()); + + NVF_ERROR( + grouped_id != nullptr, + "Expected a valid loop grouped ID for BlockQuantizationOp: ", + bqop->toString()); + + BlockQuantizationValidationHelper helper(quantized_output, grouped_id); + helper.run(); + } }; } // namespace diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index 260a8e0e3f4..0d1ae5f3acd 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -405,6 +405,322 @@ TEST_P(BlockQuantizationTest, ScheduleAsPointwise2D) { EXPECT_EQ(quantized_tensor_output.dim(), 2); } +class BlockQuantizationValidationTest : public BlackwellBase { + protected: + // Helper function to create test input tensor + at::Tensor createTestInput(int64_t dim = 2) { + if (dim == 2) { + return at::randn({1024, 1024}, at::device(at::kCUDA).dtype(at::kFloat)); + } else if (dim == 3) { + return at::randn({16, 64, 1024}, at::device(at::kCUDA).dtype(at::kFloat)); + } else { + throw std::runtime_error("Unsupported dimension for createTestInput"); + } + } + + // Helper function to assert compilation fails with expected error message + void assertCompilationFails( + Fusion* fusion, + const std::vector& inputs, + const char* expected_error_msg) { + KernelExecutor ke; + try { + ke.compile(fusion, inputs); + FAIL() << "Expected compilation to throw error: " << expected_error_msg; + } catch (const std::exception& e) { + ASSERT_TRUE(strstr(e.what(), expected_error_msg) != nullptr) + << "Expected error message containing: \"" << expected_error_msg + << "\"\nActual error: " << e.what(); + } + } + + // Helper to create a fusion with blockQuantize and apply scheduling + struct FusionSetup { + std::unique_ptr fusion; + TensorView* tv_data_hp; + TensorView* t0; + TensorView* quantized_tensor; + TensorView* block_scales; + TensorView* t_out; + }; + + FusionSetup createBlockQuantizeFusion(int64_t dim = 2) { + FusionSetup setup; + setup.fusion = std::make_unique(); + FusionGuard fg(setup.fusion.get()); + + setup.tv_data_hp = makeContigTensor(dim, DataType::Float); + setup.fusion->addInput(setup.tv_data_hp); + + setup.t0 = set(setup.tv_data_hp); + auto quantization_results = blockQuantize(setup.t0); + setup.quantized_tensor = quantization_results.quantized_tensor; + setup.block_scales = quantization_results.block_scales; + setup.t_out = set(setup.quantized_tensor); + + setup.fusion->addOutput(setup.block_scales); + setup.fusion->addOutput(setup.t_out); + + return setup; + } + + // Helper to apply common merge and split operations + void applyMergeAndSplit( + TensorView* t, + int64_t split_factor, + int64_t inner_split = 1, + int64_t thread_split = 128) { + // Merge all dims + t->merge(-2); + if (t->getLoopDomain().size() >= 2) { + t->merge(-2); + } + + // Apply splits: I -> I/split_factor, split_factor + t->split(-1, split_factor); + // I/split_factor, split_factor -> I/split_factor, inner_split, + // 
split_factor/inner_split + t->split(-2, inner_split); + // I/split_factor, inner_split, split_factor/inner_split -> I/thread_split, + // thread_split/split_factor, inner_split, split_factor/inner_split + t->split(-3, thread_split); + } +}; + +TEST_F(BlockQuantizationValidationTest, InputMustBeInLocalMemory) { + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv_data_hp = makeContigTensor(2, DataType::Float); + fusion->addInput(tv_data_hp); + + // Don't set memory type - remains global (default for inputs) + auto quantization_results = blockQuantize(tv_data_hp); + auto t_out = set(quantization_results.quantized_tensor); + + fusion->addOutput(quantization_results.block_scales); + fusion->addOutput(t_out); + + assertCompilationFails( + fusion.get(), {createTestInput()}, "Input must be a local memory tensor"); +} + +TEST_F(BlockQuantizationValidationTest, QuantizedOutputMustBeInLocalMemory) { + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv_data_hp = makeContigTensor(2, DataType::Float); + fusion->addInput(tv_data_hp); + + tv_data_hp = set(tv_data_hp); + auto quantization_results = blockQuantize(tv_data_hp); + + fusion->addOutput(quantization_results.block_scales); + fusion->addOutput(quantization_results.quantized_tensor); + + assertCompilationFails( + fusion.get(), + {createTestInput()}, + "Quantized output must be a local memory tensor"); +} + +TEST_F( + BlockQuantizationValidationTest, + BlockScalingFactorMustBeInGlobalMemory) { + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv_data_hp = makeContigTensor(2, DataType::Float); + fusion->addInput(tv_data_hp); + + tv_data_hp = set(tv_data_hp); + auto quantization_results = blockQuantize(tv_data_hp); + auto tv_block_scales = set(quantization_results.block_scales); + auto tv_quantized_out = set(quantization_results.quantized_tensor); + + fusion->addOutput(tv_block_scales); + fusion->addOutput(tv_quantized_out); + + assertCompilationFails( + fusion.get(), + {createTestInput()}, + "Block scaling factor must be a global memory tensor"); +} + +TEST_F( + BlockQuantizationValidationTest, + QuantizedOutputCannotHaveVectorizedDimension) { + auto setup = createBlockQuantizeFusion(); + FusionGuard fg(setup.fusion.get()); + + std::vector tensors = { + setup.tv_data_hp, + setup.t0, + setup.quantized_tensor, + setup.block_scales, + setup.t_out}; + + for (auto t : tensors) { + applyMergeAndSplit(t, /*split_factor=*/4); + + // Vectorize all non-input tensors (this should fail) + // as quantized output cannot be vectorized + if (t != setup.tv_data_hp) { + t->axis(-1)->parallelize(ParallelType::Vectorize); + t->axis(-3)->parallelize(ParallelType::TIDx); + t->axis(-4)->parallelize(ParallelType::BIDx); + } + } + + assertCompilationFails( + setup.fusion.get(), + {createTestInput()}, + "Cannot have vectorized ID in BlockQuantizationOp"); +} + +TEST_F(BlockQuantizationValidationTest, GroupIDMustBeInnermost) { + auto setup = createBlockQuantizeFusion(); + FusionGuard fg(setup.fusion.get()); + + std::vector tensors = { + setup.tv_data_hp, + setup.t0, + setup.quantized_tensor, + setup.block_scales, + setup.t_out}; + + for (auto t : tensors) { + applyMergeAndSplit( + t, /*split_factor=*/128, /*inner_split=*/1, /*thread_split=*/4); + + if (t != setup.tv_data_hp) { + // Mark outer ID as Group for quantized outputs (should fail) + // instead of innermost ID + if (t == setup.block_scales || t == setup.quantized_tensor) { + 
t->axis(-3)->parallelize(ParallelType::Group); + t->axis(-1)->parallelize(ParallelType::TIDx); + } else { + t->axis(-1)->parallelize(ParallelType::Vectorize); + t->axis(-3)->parallelize(ParallelType::TIDx); + } + t->axis(-4)->parallelize(ParallelType::BIDx); + } + } + + assertCompilationFails( + setup.fusion.get(), + {createTestInput()}, + "The grouped ID must correspond to the innermost of all splits from " + "logical domains to loop domains for BlockQuantizationOp"); +} + +TEST_F(BlockQuantizationValidationTest, NonParallelizedIDsMustHaveExtentOfOne) { + auto setup = createBlockQuantizeFusion(); + FusionGuard fg(setup.fusion.get()); + + std::vector tensors = { + setup.tv_data_hp, + setup.t0, + setup.quantized_tensor, + setup.block_scales, + setup.t_out}; + + for (auto t : tensors) { + applyMergeAndSplit(t, /*split_factor=*/4, /*inner_split=*/2); + + if (t != setup.tv_data_hp) { + if (t == setup.block_scales || t == setup.quantized_tensor) { + t->axis(-1)->parallelize(ParallelType::Group); + } else { + t->axis(-1)->parallelize(ParallelType::Vectorize); + } + t->axis(-3)->parallelize(ParallelType::TIDx); + t->axis(-4)->parallelize(ParallelType::BIDx); + } + } + + assertCompilationFails( + setup.fusion.get(), + {createTestInput()}, + "Expected non-TID/BID/Group ID to have extent of 1 for " + "BlockQuantizationOp"); +} + +TEST_F(BlockQuantizationValidationTest, TIDxMustBeSecondInnermostAfterGroupID) { + auto setup = createBlockQuantizeFusion(); + FusionGuard fg(setup.fusion.get()); + + std::vector tensors = { + setup.tv_data_hp, + setup.t0, + setup.quantized_tensor, + setup.block_scales, + setup.t_out}; + + for (auto t : tensors) { + applyMergeAndSplit(t, /*split_factor=*/4); + + if (t != setup.tv_data_hp) { + if (t == setup.block_scales || t == setup.quantized_tensor) { + t->axis(-1)->parallelize(ParallelType::Group); + } else { + t->axis(-1)->parallelize(ParallelType::Vectorize); + } + t->axis(-3)->parallelize(ParallelType::BIDx); + t->axis(-4)->parallelize(ParallelType::TIDx); + } + } + + assertCompilationFails( + setup.fusion.get(), + {createTestInput()}, + "Only constant extent IDs with extent of 1 are expected between TIDx " + "and Group ID in BlockQuantizationOp quantized output"); +} + +TEST_F(BlockQuantizationValidationTest, MergesMustBeContiguous) { + auto setup = createBlockQuantizeFusion(/*dim=*/3); + FusionGuard fg(setup.fusion.get()); + + std::vector tensors = { + setup.tv_data_hp, + setup.t0, + setup.quantized_tensor, + setup.block_scales, + setup.t_out}; + + for (auto t : tensors) { + // Merge first two dims instead of last two + t->reorder({{0, 1}, {1, 0}}); // (i0, i1, i2) -> (i1, i0, i2) + t->merge(1); + + // split I1 by 4 + t->split(-1, 4); + // I/4, 4 -> I/4, 1, 4 + t->split(-2, 1); + // I/4, 1, 4 -> I/512, 128, 1, 4 + t->split(-3, 128); + + if (t != setup.tv_data_hp) { + if (t == setup.block_scales || t == setup.quantized_tensor) { + t->axis(-1)->parallelize(ParallelType::Group); + } else { + t->axis(-1)->parallelize(ParallelType::Vectorize); + } + t->axis(-3)->parallelize(ParallelType::TIDx); + t->axis(-4)->parallelize(ParallelType::BIDx); + t->axis(-5)->parallelize(ParallelType::BIDy); + } + } + + assertCompilationFails( + setup.fusion.get(), + {createTestInput(/*dim=*/3)}, + "Invalid merge found while tracing back the grouped ID for " + "BlockQuantizationOp. 
All inputs to merge must be from logical domain " + "or be outputs of other merges"); +} + TEST_P(NVFP4QuantizeTest, SwizzledOuputAndWithoutPerTensorAmax) { auto data_hp_dtype = GetParam(); From 9ac82614b44372ae84bd54c3c1499093c993f448 Mon Sep 17 00:00:00 2001 From: protonu Date: Sun, 2 Nov 2025 05:55:40 -0800 Subject: [PATCH 44/79] merge --- tests/cpp/test_low_precision_recipe.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index 653ec2ecf72..0d1ae5f3acd 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -405,7 +405,6 @@ TEST_P(BlockQuantizationTest, ScheduleAsPointwise2D) { EXPECT_EQ(quantized_tensor_output.dim(), 2); } -<<<<<<< HEAD class BlockQuantizationValidationTest : public BlackwellBase { protected: // Helper function to create test input tensor @@ -722,8 +721,6 @@ TEST_F(BlockQuantizationValidationTest, MergesMustBeContiguous) { "or be outputs of other merges"); } -======= ->>>>>>> main TEST_P(NVFP4QuantizeTest, SwizzledOuputAndWithoutPerTensorAmax) { auto data_hp_dtype = GetParam(); From a1612c8889644ecc0dee3dd7444a1b908000b003 Mon Sep 17 00:00:00 2001 From: protonu Date: Mon, 3 Nov 2025 03:57:08 -0800 Subject: [PATCH 45/79] adding comments to tests --- tests/cpp/test_low_precision_recipe.cpp | 40 +++++++++++++++++++------ 1 file changed, 31 insertions(+), 9 deletions(-) diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index 0d1ae5f3acd..0e275d85a31 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -434,7 +434,7 @@ class BlockQuantizationValidationTest : public BlackwellBase { } } - // Helper to create a fusion with blockQuantize and apply scheduling + // Helper function to create a fusion with blockQuantize and apply scheduling struct FusionSetup { std::unique_ptr fusion; TensorView* tv_data_hp; @@ -465,28 +465,29 @@ class BlockQuantizationValidationTest : public BlackwellBase { } // Helper to apply common merge and split operations + // This is limited for the tests with 2D tv inputs. 
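+  // Illustrative shape of the resulting loop domain, e.g. with split_factor=4
+  // and the default inner_split=1, thread_split=128:
+  //   merge:  [I0, I1] -> [I0*I1]
+  //   splits: [I0*I1]  -> [(I0*I1)/512, 128, 1, 4]
+  // which the tests below parallelize as [BIDx, TIDx, Serial, Group/Vectorize].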
void applyMergeAndSplit( TensorView* t, int64_t split_factor, int64_t inner_split = 1, int64_t thread_split = 128) { // Merge all dims + // (I0, I1) -> (I0*I1) == (I) t->merge(-2); - if (t->getLoopDomain().size() >= 2) { - t->merge(-2); - } // Apply splits: I -> I/split_factor, split_factor t->split(-1, split_factor); - // I/split_factor, split_factor -> I/split_factor, inner_split, - // split_factor/inner_split + // I/split_factor, split_factor -> I/split_factor/inner_split, inner_split, + // I/split_factor t->split(-2, inner_split); - // I/split_factor, inner_split, split_factor/inner_split -> I/thread_split, - // thread_split/split_factor, inner_split, split_factor/inner_split + // I/split_factor/inner_split, inner_split, I/split_factor -> + // I/split_factor/inner_split/thread_split, thread_split, inner_split, + // I/split_factor t->split(-3, thread_split); } }; +// Input is in global memory - not valid TEST_F(BlockQuantizationValidationTest, InputMustBeInLocalMemory) { std::unique_ptr fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -494,7 +495,6 @@ TEST_F(BlockQuantizationValidationTest, InputMustBeInLocalMemory) { auto tv_data_hp = makeContigTensor(2, DataType::Float); fusion->addInput(tv_data_hp); - // Don't set memory type - remains global (default for inputs) auto quantization_results = blockQuantize(tv_data_hp); auto t_out = set(quantization_results.quantized_tensor); @@ -505,6 +505,7 @@ TEST_F(BlockQuantizationValidationTest, InputMustBeInLocalMemory) { fusion.get(), {createTestInput()}, "Input must be a local memory tensor"); } +// Quantized output is written to global memory - not valid TEST_F(BlockQuantizationValidationTest, QuantizedOutputMustBeInLocalMemory) { std::unique_ptr fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -524,6 +525,7 @@ TEST_F(BlockQuantizationValidationTest, QuantizedOutputMustBeInLocalMemory) { "Quantized output must be a local memory tensor"); } +// Block scaling factor is written to local memory - not valid TEST_F( BlockQuantizationValidationTest, BlockScalingFactorMustBeInGlobalMemory) { @@ -547,6 +549,8 @@ TEST_F( "Block scaling factor must be a global memory tensor"); } +// Quantized output when scheduled cannot have a vectorized dimension +// but should have a group dim - this is not valid. 
TEST_F( BlockQuantizationValidationTest, QuantizedOutputCannotHaveVectorizedDimension) { @@ -578,6 +582,8 @@ TEST_F( "Cannot have vectorized ID in BlockQuantizationOp"); } +// Group ID must be the innermost of all splits from logical domains to loop +// domains TEST_F(BlockQuantizationValidationTest, GroupIDMustBeInnermost) { auto setup = createBlockQuantizeFusion(); FusionGuard fg(setup.fusion.get()); @@ -614,6 +620,9 @@ TEST_F(BlockQuantizationValidationTest, GroupIDMustBeInnermost) { "logical domains to loop domains for BlockQuantizationOp"); } +// We do not allow IDs of types serial, unroll, unswitch to have extent > 1 +// We do not want the runtime kernel which implement block quantization to be +// called multiple times in a kernel as yet TEST_F(BlockQuantizationValidationTest, NonParallelizedIDsMustHaveExtentOfOne) { auto setup = createBlockQuantizeFusion(); FusionGuard fg(setup.fusion.get()); @@ -626,6 +635,7 @@ TEST_F(BlockQuantizationValidationTest, NonParallelizedIDsMustHaveExtentOfOne) { setup.t_out}; for (auto t : tensors) { + // There will be a non-parallelized ID with a trip count of 2 applyMergeAndSplit(t, /*split_factor=*/4, /*inner_split=*/2); if (t != setup.tv_data_hp) { @@ -646,6 +656,12 @@ TEST_F(BlockQuantizationValidationTest, NonParallelizedIDsMustHaveExtentOfOne) { "BlockQuantizationOp"); } +// The runtime kernel for block quantization expects TIDx to access contiguous +// memory locations - just 16, but to be safe we enfore all memory locations of +// TIDx are contiguous. To enfore this, TIDx must be the second innermost ID +// after Group ID. By that we mean if we derive this ID from the logical domain, +// there should be no other IDs between Group ID and TIDx except for IDs with +// extent of 1. TEST_F(BlockQuantizationValidationTest, TIDxMustBeSecondInnermostAfterGroupID) { auto setup = createBlockQuantizeFusion(); FusionGuard fg(setup.fusion.get()); @@ -666,6 +682,7 @@ TEST_F(BlockQuantizationValidationTest, TIDxMustBeSecondInnermostAfterGroupID) { } else { t->axis(-1)->parallelize(ParallelType::Vectorize); } + // TIDx is "outer" compared to BIDx causing a failure t->axis(-3)->parallelize(ParallelType::BIDx); t->axis(-4)->parallelize(ParallelType::TIDx); } @@ -678,6 +695,10 @@ TEST_F(BlockQuantizationValidationTest, TIDxMustBeSecondInnermostAfterGroupID) { "and Group ID in BlockQuantizationOp quantized output"); } +// When running validation checks we traverse from loop to logical domain +// and vice-versa. During this traversal, when we encounter a merge operation, +// we find all input IDs to the merge (traced back to the logical domain of the +// quantized output). The input IDs in the logical domain need to be contiguous. 
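+// For instance (illustrative): with a [i0, i1, i2] input, merging i1 with i2
+// is fine, while the reorder below makes the merge combine i0 and i2
+// (non-adjacent logical IDs), which validation is expected to reject.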
TEST_F(BlockQuantizationValidationTest, MergesMustBeContiguous) { auto setup = createBlockQuantizeFusion(/*dim=*/3); FusionGuard fg(setup.fusion.get()); @@ -691,6 +712,7 @@ TEST_F(BlockQuantizationValidationTest, MergesMustBeContiguous) { for (auto t : tensors) { // Merge first two dims instead of last two + // This will cause a failure as the merged IDs are not contiguous t->reorder({{0, 1}, {1, 0}}); // (i0, i1, i2) -> (i1, i0, i2) t->merge(1); From fb6bd53633847f03f0b0c71f913654b97977a3fd Mon Sep 17 00:00:00 2001 From: protonu Date: Mon, 3 Nov 2025 05:50:48 -0800 Subject: [PATCH 46/79] more comments for validation --- csrc/device_lower/validation.cpp | 55 +++++++++++++++++++------------- 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index be09b0e9883..636d77c25fc 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -262,20 +262,25 @@ bool isInnermost(IterDomain* base_id, IterDomain* maybe_innermost_id) { } // Helper class for BlockQuantization validation -// -// This class validates the scheduling requirements for BlockQuantizationOp: +// This class validates the following scheduling requirements for +// BlockQuantizationOp: // 1. The Group ID must be derived from the innermost logical IDs // 2. Merge operations must combine contiguous logical IDs -// 3. TIDx must follow the Group ID in the schedule +// 3. TIDx must follow the Group ID in the schedule -- that is when derived from +// the logical domain, group ID must be inner-most, the next "inner-most" should +// be TIDx (unless there is an ID with a unit trip count) class BlockQuantizationValidationHelper { public: BlockQuantizationValidationHelper(const TensorView* tv, IterDomain* group_id) : tv_(tv), group_id_(group_id) {} - // Validates the complete scheduling structure for block quantization void run() { + // Trace back from Group ID to logical domain to see group is "inner-most" traceGroupIdToLogicalDomain(); + // Find where to start the search for TIDx IterDomain* restart_traversal_from = determineThreadXTraversalStart(); + // We start a DFS for TIDx taking the inner outputs of Split nodes first. + // We should not find a terminating ID that is not TIDx and has extent > 1. validateThreadXFollowsGroupId(restart_traversal_from); } @@ -347,7 +352,7 @@ class BlockQuantizationValidationHelper { // innermost) // - For merges: Traverse through the inner input and validate contiguity // - // Returns via last_split_seen_: The last split encountered during traversal, + // Stores last_split_seen_: The last split encountered during traversal, // which is used as the starting point for validating the TIDx path. void traceGroupIdToLogicalDomain() { NVF_ERROR( @@ -401,7 +406,8 @@ class BlockQuantizationValidationHelper { // Group ID comes directly from logical domain or via merges only. // Start from the logical ID just before those feeding into Group ID. // 2. Otherwise: - // Start from the outer output of the last split seen. + // Start from the outer output of the last split seen - as we must come up + // to the split from the group ID via the inner output. IterDomain* determineThreadXTraversalStart() const { if (last_split_seen_ == nullptr) { // Case 1: Group ID derived directly from logical domain @@ -432,10 +438,10 @@ class BlockQuantizationValidationHelper { } } - // Validates that TIDx follows the Group ID in the schedule using DFS - // traversal. 
Starting from the given ID, traverses through splits and merges - // to ensure TIDx is reachable. Any terminating IDs that are not TIDx must - // have extent 1. + // Validates that TIDx follows the Group ID in the schedule using DFS. + // Starting from the given ID, traverses through splits and merges + // to ensure TIDx is reachable. We traverse through the split by taking the + // inner path first. Any terminating IDs that are not TIDx must have extent 1. void validateThreadXFollowsGroupId(IterDomain* start_id) const { std::stack to_visit; to_visit.push(start_id); @@ -464,7 +470,6 @@ class BlockQuantizationValidationHelper { continue; } - // Validate single use (no branching in the path to TIDx) NVF_ERROR( current_id->uses().size() == 1, "Expected single use for IDs in logical to loop transforms for " @@ -652,22 +657,24 @@ class ExprValidator : public OptOutDispatch { } } - // Basic tests: + // Basic checks: // Input is in local memory. // Block scaling factor is in global memory and // quantized output is in local memory. // Any loop ID that is not TID(x/y). BID(x/y) or Group // has an extent of 1. // The Group ID has an extent of 4/8 depending on the data type. - // The following are more complex checks that look at the schedule. // There are no TIDz/BIDz IDs. We don't support 3D parallelization here. - // Our aims for the following checks are to ensure that the Group ID is - // contiguous and unit stride, and then after Group ID, we have TIDx. - // Such that (G) * ThreadIdx.x + GID is contiguous. - // Checks for Group ID: - // Next we check that the Group ID is contiguous and unit stride. - // We walk up the logical domain to loop domains path starting from the Group - // ID (G). + // The following are more complex checks that look at the schedule. + // These checks are implemented using the helper class + // BlockQuantizationValidationHelper. + // Our aims for the following checks are to ensure that the group ID is + // contiguous and unit stride, and then after group ID, we have TIDx. Such + // that (G -- extent of GID) * ThreadIdx.x + GID is contiguous. + // We do so by checking that the group ID unit stride. It is derived from the + // innermost logical IDs via merges and inner splits only. Next we check that + // TIDx in the next inner-most ID, and if there was any other ID between TIDx + // and group ID then it must have an extent of 1. void handle(BlockQuantizationOp* bqop) final { auto inp_tv = bqop->input(0)->as(); auto quantized_output = bqop->quantizedOutput()->as(); @@ -691,12 +698,13 @@ class ExprValidator : public OptOutDispatch { "Block scaling factor must be a global memory tensor. Found: ", block_scaling_factor->getMemoryType()); - // outputs have the same allocation domain - // as the loop domain. This has to be later - // relaxed for the scaling factors. + // Outputs have the same allocation domain + // as the logical domain - no allocation domain. NVF_ERROR( quantized_output->hasAllocation() == false, "Quantized output must not have an allocation domain."); + + // TODO: Relax this for swizzled block scaling factor outputs NVF_ERROR( block_scaling_factor->hasAllocation() == false, "Block scaling factor must not have an allocation domain."); @@ -775,6 +783,7 @@ class ExprValidator : public OptOutDispatch { "Expected a valid loop grouped ID for BlockQuantizationOp: ", bqop->toString()); + // Helper to check to the most involved scheduling requirements. 
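+    // (i.e. that the Group ID is innermost and unit stride, and that TIDx is
+    // the next innermost ID, as described in the comment block above.)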
BlockQuantizationValidationHelper helper(quantized_output, grouped_id); helper.run(); } From fe787e0d9a006ef7f3c1374448a091b48ea23d99 Mon Sep 17 00:00:00 2001 From: Protonu Date: Mon, 3 Nov 2025 08:58:02 -0500 Subject: [PATCH 47/79] Update tests/cpp/test_low_precision_recipe.cpp Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- tests/cpp/test_low_precision_recipe.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index 0e275d85a31..46d9e347d4a 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -657,7 +657,7 @@ TEST_F(BlockQuantizationValidationTest, NonParallelizedIDsMustHaveExtentOfOne) { } // The runtime kernel for block quantization expects TIDx to access contiguous -// memory locations - just 16, but to be safe we enfore all memory locations of +// memory locations - just 16, but to be safe we enforce all memory locations of // TIDx are contiguous. To enfore this, TIDx must be the second innermost ID // after Group ID. By that we mean if we derive this ID from the logical domain, // there should be no other IDs between Group ID and TIDx except for IDs with From 5b11a746a4238fb69278798e89d7bbe48fbbaeea Mon Sep 17 00:00:00 2001 From: Protonu Date: Mon, 3 Nov 2025 08:58:18 -0500 Subject: [PATCH 48/79] Apply suggestion from @greptile-apps[bot] Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- csrc/device_lower/validation.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index 636d77c25fc..bb8f45ffcba 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -778,10 +778,9 @@ class ExprValidator : public OptOutDispatch { ". Expr: ", bqop->toString()); - NVF_ERROR( - grouped_id != nullptr, - "Expected a valid loop grouped ID for BlockQuantizationOp: ", - bqop->toString()); + // Helper to check to the most involved scheduling requirements. + BlockQuantizationValidationHelper helper(quantized_output, grouped_id); + helper.run(); // Helper to check to the most involved scheduling requirements. BlockQuantizationValidationHelper helper(quantized_output, grouped_id); From 4ea7bae8659aff61e9c5174158c55aff4581cb80 Mon Sep 17 00:00:00 2001 From: protonu Date: Mon, 3 Nov 2025 06:02:25 -0800 Subject: [PATCH 49/79] fix weird duplicated code that showed up --- csrc/device_lower/validation.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index bb8f45ffcba..af835e671d5 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -781,10 +781,6 @@ class ExprValidator : public OptOutDispatch { // Helper to check to the most involved scheduling requirements. BlockQuantizationValidationHelper helper(quantized_output, grouped_id); helper.run(); - - // Helper to check to the most involved scheduling requirements. 
- BlockQuantizationValidationHelper helper(quantized_output, grouped_id); - helper.run(); } }; From bb21123739f630f7ab7d2a06a5abcb0ed9b6e2c7 Mon Sep 17 00:00:00 2001 From: protonu Date: Mon, 3 Nov 2025 07:03:31 -0800 Subject: [PATCH 50/79] removing stale comment --- csrc/device_lower/validation.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index af835e671d5..dba638716de 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -709,8 +709,6 @@ class ExprValidator : public OptOutDispatch { block_scaling_factor->hasAllocation() == false, "Block scaling factor must not have an allocation domain."); - // Check that it either had vectorized ID or grouped ID - // not both and the extent is either 4(FP32) or 8(BF16) IterDomain* grouped_id = nullptr; IterDomain* thread_x = nullptr; IterDomain* block_x = nullptr; From 918bd94229e0899bdc506b76fa77d584e1010471 Mon Sep 17 00:00:00 2001 From: protonu Date: Tue, 4 Nov 2025 05:01:34 -0800 Subject: [PATCH 51/79] changes to tests based on reviewer comments --- tests/cpp/test_low_precision_recipe.cpp | 53 ++++++------------------- 1 file changed, 12 insertions(+), 41 deletions(-) diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index 46d9e347d4a..00522352c05 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -424,14 +424,18 @@ class BlockQuantizationValidationTest : public BlackwellBase { const std::vector& inputs, const char* expected_error_msg) { KernelExecutor ke; - try { - ke.compile(fusion, inputs); - FAIL() << "Expected compilation to throw error: " << expected_error_msg; - } catch (const std::exception& e) { - ASSERT_TRUE(strstr(e.what(), expected_error_msg) != nullptr) - << "Expected error message containing: \"" << expected_error_msg - << "\"\nActual error: " << e.what(); - } + EXPECT_THROW( + { + try { + ke.compile(fusion, inputs); + } catch (const std::exception& e) { + EXPECT_THAT(e.what(), ::testing::HasSubstr(expected_error_msg)) + << "Expected error message containing: \"" << expected_error_msg + << "\"\nActual error: " << e.what(); + throw; // Re-throw for EXPECT_THROW to catch + } + }, + std::exception); } // Helper function to create a fusion with blockQuantize and apply scheduling @@ -549,39 +553,6 @@ TEST_F( "Block scaling factor must be a global memory tensor"); } -// Quantized output when scheduled cannot have a vectorized dimension -// but should have a group dim - this is not valid. 
-TEST_F( - BlockQuantizationValidationTest, - QuantizedOutputCannotHaveVectorizedDimension) { - auto setup = createBlockQuantizeFusion(); - FusionGuard fg(setup.fusion.get()); - - std::vector tensors = { - setup.tv_data_hp, - setup.t0, - setup.quantized_tensor, - setup.block_scales, - setup.t_out}; - - for (auto t : tensors) { - applyMergeAndSplit(t, /*split_factor=*/4); - - // Vectorize all non-input tensors (this should fail) - // as quantized output cannot be vectorized - if (t != setup.tv_data_hp) { - t->axis(-1)->parallelize(ParallelType::Vectorize); - t->axis(-3)->parallelize(ParallelType::TIDx); - t->axis(-4)->parallelize(ParallelType::BIDx); - } - } - - assertCompilationFails( - setup.fusion.get(), - {createTestInput()}, - "Cannot have vectorized ID in BlockQuantizationOp"); -} - // Group ID must be the innermost of all splits from logical domains to loop // domains TEST_F(BlockQuantizationValidationTest, GroupIDMustBeInnermost) { From 82990c4315c399e09d23e8d7613b4e4d6afaf0f5 Mon Sep 17 00:00:00 2001 From: Protonu Date: Tue, 4 Nov 2025 08:39:09 -0500 Subject: [PATCH 52/79] Apply suggestion from @greptile-apps[bot] Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- tests/cpp/test_low_precision_recipe.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index 00522352c05..65bc822a3c5 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -629,7 +629,7 @@ TEST_F(BlockQuantizationValidationTest, NonParallelizedIDsMustHaveExtentOfOne) { // The runtime kernel for block quantization expects TIDx to access contiguous // memory locations - just 16, but to be safe we enforce all memory locations of -// TIDx are contiguous. To enfore this, TIDx must be the second innermost ID +// TIDx are contiguous. To enforce this, TIDx must be the second innermost ID // after Group ID. By that we mean if we derive this ID from the logical domain, // there should be no other IDs between Group ID and TIDx except for IDs with // extent of 1. From b87aa9a8f671085cb97a95ebb62bc0b95df579ad Mon Sep 17 00:00:00 2001 From: protonu Date: Tue, 4 Nov 2025 06:22:08 -0800 Subject: [PATCH 53/79] remove vectorization testt and address less involved reviewer comments --- csrc/device_lower/validation.cpp | 15 +++++++-------- tests/cpp/test_low_precision_recipe.cpp | 6 +++--- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index dba638716de..770b0f4788f 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -384,7 +384,7 @@ class BlockQuantizationValidationHelper { last_split_seen_ = split; current_id = split->in(); } else { - NVF_ERROR( + NVF_THROW( false, "Unexpected definition found while tracing back the grouped ID for " "BlockQuantizationOp: ", @@ -501,7 +501,7 @@ class BlockQuantizationValidationHelper { } private: - const TensorView* tv_; + const TensorView* tv_ = nullptr; IterDomain* group_id_; Split* last_split_seen_ = nullptr; }; @@ -701,12 +701,12 @@ class ExprValidator : public OptOutDispatch { // Outputs have the same allocation domain // as the logical domain - no allocation domain. 
NVF_ERROR( - quantized_output->hasAllocation() == false, + !quantized_output->hasAllocation(), "Quantized output must not have an allocation domain."); // TODO: Relax this for swizzled block scaling factor outputs NVF_ERROR( - block_scaling_factor->hasAllocation() == false, + !block_scaling_factor->hasAllocation(), "Block scaling factor must not have an allocation domain."); IterDomain* grouped_id = nullptr; @@ -718,8 +718,6 @@ class ExprValidator : public OptOutDispatch { for (const auto& loop_id : quantized_output->getLoopDomain()) { if (loop_id->getParallelType() == ParallelType::Group) { grouped_id = loop_id; - } else if (loop_id->getParallelType() == ParallelType::Vectorize) { - NVF_ERROR(false, "Cannot have vectorized ID in BlockQuantizationOp"); } else if (loop_id->getParallelType() == ParallelType::TIDx) { thread_x = loop_id; } else if (loop_id->getParallelType() == ParallelType::BIDx) { @@ -737,8 +735,9 @@ class ExprValidator : public OptOutDispatch { loop_id->extent()->isConstInt(), "Expected constant extent for Serial/Unswitch/Unroll ID in " "BlockQuantizationOp"); - NVF_ERROR( - loop_id->extent()->evaluate().as() == 1, + NVF_ERROR_EQ( + loop_id->extent()->evaluate().as(), + 1, "Expected non-TID/BID/Group ID to have extent of 1 for " "BlockQuantizationOp: ", bqop->toString()); diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index 65bc822a3c5..3042bec8e00 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -482,11 +482,11 @@ class BlockQuantizationValidationTest : public BlackwellBase { // Apply splits: I -> I/split_factor, split_factor t->split(-1, split_factor); // I/split_factor, split_factor -> I/split_factor/inner_split, inner_split, - // I/split_factor + // split_factor t->split(-2, inner_split); - // I/split_factor/inner_split, inner_split, I/split_factor -> + // I/split_factor/inner_split, inner_split, split_factor -> // I/split_factor/inner_split/thread_split, thread_split, inner_split, - // I/split_factor + // split_factor t->split(-3, thread_split); } }; From 8daae2fd509a9cd82db38234825f128fb1eafe06 Mon Sep 17 00:00:00 2001 From: protonu Date: Tue, 4 Nov 2025 11:15:45 -0800 Subject: [PATCH 54/79] adding new validation --- csrc/device_lower/validation.cpp | 387 ++++++++---------------- csrc/scheduler/utils.cpp | 11 +- csrc/scheduler/utils.h | 6 + tests/cpp/test_low_precision_recipe.cpp | 9 +- 4 files changed, 145 insertions(+), 268 deletions(-) diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index 770b0f4788f..c6e28f1c7c2 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -207,17 +208,15 @@ void validateCpAsyncBulk(const std::vector& tvs) { } } -// Check if maybe_innermost_id is derived from base_id and corresponds to the -// innermost subregion of base_id. The split/merge exprs between -// based_id and id must not include any ID that is not produced from -// base_id. -bool isInnermost(IterDomain* base_id, IterDomain* maybe_innermost_id) { - auto exprs = - DependencyCheck::getAllExprsBetween({base_id}, {maybe_innermost_id}); - - std::deque frontier; - frontier.push_back(base_id); - +// Traverse through the expressions, updating the frontier based on merge and +// split operations. Returns true if all merges encountered are contiguous. 
+// If stop_on_noncontiguous is true, stops traversal and returns false on first +// non-contiguous merge. Otherwise, removes non-contiguous merges from frontier +// and continues. +bool traverseFrontierWithContiguityCheck( + std::deque& frontier, + const std::vector& exprs, + bool stop_on_noncontiguous) { for (auto expr : exprs) { // expr is skipped if any of the inputs is missing. if (auto merge = dynamic_cast(expr)) { @@ -235,6 +234,12 @@ bool isInnermost(IterDomain* base_id, IterDomain* maybe_innermost_id) { auto inner_pos = std::distance(frontier.begin(), inner_it); bool is_contig = outer_pos + 1 == inner_pos; + + if (!is_contig && stop_on_noncontiguous) { + // Found a non-contiguous merge + return false; + } + frontier.erase(inner_it); // If it's contig, we can continue the analysis by proceeding to @@ -254,6 +259,23 @@ bool isInnermost(IterDomain* base_id, IterDomain* maybe_innermost_id) { *in_it = split->outer(); } } + return true; +} + +// Check if maybe_innermost_id is derived from base_id and corresponds to the +// innermost subregion of base_id. The split/merge exprs between +// based_id and id must not include any ID that is not produced from +// base_id. +bool isInnermost(IterDomain* base_id, IterDomain* maybe_innermost_id) { + auto exprs = + DependencyCheck::getAllExprsBetween({base_id}, {maybe_innermost_id}); + + std::deque frontier; + frontier.push_back(base_id); + + // Don't stop on non-contiguous merges; remove them from frontier and continue + traverseFrontierWithContiguityCheck( + frontier, exprs, /*stop_on_noncontiguous=*/false); // Once the traversal is done, if the target id located at the // rightmost position of the frontier list, it is guaranteed to @@ -261,250 +283,20 @@ bool isInnermost(IterDomain* base_id, IterDomain* maybe_innermost_id) { return !frontier.empty() && frontier.back() == maybe_innermost_id; } -// Helper class for BlockQuantization validation -// This class validates the following scheduling requirements for -// BlockQuantizationOp: -// 1. The Group ID must be derived from the innermost logical IDs -// 2. Merge operations must combine contiguous logical IDs -// 3. TIDx must follow the Group ID in the schedule -- that is when derived from -// the logical domain, group ID must be inner-most, the next "inner-most" should -// be TIDx (unless there is an ID with a unit trip count) -class BlockQuantizationValidationHelper { - public: - BlockQuantizationValidationHelper(const TensorView* tv, IterDomain* group_id) - : tv_(tv), group_id_(group_id) {} - - void run() { - // Trace back from Group ID to logical domain to see group is "inner-most" - traceGroupIdToLogicalDomain(); - // Find where to start the search for TIDx - IterDomain* restart_traversal_from = determineThreadXTraversalStart(); - // We start a DFS for TIDx taking the inner outputs of Split nodes first. - // We should not find a terminating ID that is not TIDx and has extent > 1. - validateThreadXFollowsGroupId(restart_traversal_from); - } - - private: - // Checks if a merge operation combines contiguous IDs from the logical - // domain. A merge is acceptable if its inputs can be traced back to IDs in - // the logical domain and those IDs are contiguous. 
- bool isAcceptableMerge(Merge* merge) const { - const auto& logical_domain = tv_->getLogicalDomain(); - auto ids_in_logical_domain = IterVisitor::getInputsTo({merge->out()}); - - // Collect positions of IDs in the logical domain - std::vector logical_positions; - for (auto id : ids_in_logical_domain) { - auto iter_domain = id->as(); - auto it = - std::find(logical_domain.begin(), logical_domain.end(), iter_domain); - if (it == logical_domain.end()) { - return false; // ID not found in logical domain - } - logical_positions.push_back(std::distance(logical_domain.begin(), it)); - } - - // Check contiguity: positions should be consecutive when sorted - if (logical_positions.size() > 1) { - std::sort(logical_positions.begin(), logical_positions.end()); - for (size_t i = 1; i < logical_positions.size(); ++i) { - if (logical_positions[i] != logical_positions[i - 1] + 1) { - return false; // Not contiguous - } - } - } - - return true; - } - - // M K - // │ │ - // ▼ ▼ - // ┌────────────┐ - // │ merge │ - // └─────┬──────┘ - // │ - // ▼ - // M*K - // ┌──────────┐ - // │ split ┼──┐ - // └─┬────────┘ │ - // ▼ ▼ - // (M*K)/4 4(G) - // ┌────────┐ - // │ split ┼────┐ - // └─┬──────┘ │ - // ▼ ▼ - // (M*K)/4 1(U) - // ┌─────────┐ - // │ split │ - // ┌─┼ ┼───┐ - // │ └─────────┘ │ - // ▼ ▼ - // (M*K)/4/128 128(Tx) - // - // Traces the Group ID backwards through splits and merges to ensure it - // derives from the innermost logical IDs. The diagram above shows a typical - // transformation chain from logical IDs (M, K) to the Group ID (G). - // - // Traversal rules: - // - For splits: Only traverse through the inner output (Group ID must be - // innermost) - // - For merges: Traverse through the inner input and validate contiguity - // - // Stores last_split_seen_: The last split encountered during traversal, - // which is used as the starting point for validating the TIDx path. - void traceGroupIdToLogicalDomain() { - NVF_ERROR( - group_id_ != nullptr, - "Expected a valid loop grouped ID for BlockQuantizationOp: ", - tv_->toString()); - - auto current_id = group_id_; - last_split_seen_ = nullptr; - - while (current_id->definition() != nullptr) { - auto def = current_id->definition(); +// Check if all merges leading to id from the logical domain are contiguous +// merges +bool areAllMergesContiguous(const TensorView* tv, IterDomain* id) { + const auto& logical_domain = tv->getLogicalDomain(); + std::deque frontier( + logical_domain.begin(), logical_domain.end()); - if (auto merge = dynamic_cast(def)) { - NVF_ERROR( - isAcceptableMerge(merge), - "Invalid merge found while tracing back the grouped ID for " - "BlockQuantizationOp. All inputs to merge must be from logical " - "domain or be outputs of other merges. TV: ", - tv_->toString()); - current_id = merge->inner(); - } else if (auto split = dynamic_cast(def)) { - NVF_ERROR( - current_id == split->inner(), - "The grouped ID must correspond to the innermost of all splits " - "from logical domains to loop domains for BlockQuantizationOp. 
" - "TV: ", - tv_->toString()); - last_split_seen_ = split; - current_id = split->in(); - } else { - NVF_THROW( - false, - "Unexpected definition found while tracing back the grouped ID for " - "BlockQuantizationOp: ", - tv_->toString()); - } - } - - NVF_ERROR( - current_id->definition() == nullptr && - current_id == tv_->getLogicalDomain().back(), - "The grouped ID must be the innermost logical domain ID for " - "BlockQuantizationOp: ", - tv_->toString()); - } - - // Determines the starting point for validating that TIDx follows the Group - // ID. Two cases: - // 1. If no splits were seen (last_split_seen_ == nullptr): - // Group ID comes directly from logical domain or via merges only. - // Start from the logical ID just before those feeding into Group ID. - // 2. Otherwise: - // Start from the outer output of the last split seen - as we must come up - // to the split from the group ID via the inner output. - IterDomain* determineThreadXTraversalStart() const { - if (last_split_seen_ == nullptr) { - // Case 1: Group ID derived directly from logical domain - auto ids_in_logical = IterVisitor::getInputsTo({group_id_}); - - // Validate all IDs feeding into Group ID have constant extents - for (auto id : ids_in_logical) { - auto iter_domain = id->as(); - NVF_ERROR( - iter_domain->extent()->isConstInt(), - "Expected all IDs feeding directly into Group ID to have constant " - "extents for BlockQuantizationOp: ", - tv_->toString()); - } - - // Ensure there are logical IDs left to derive thread IDs - const auto& logical_domain = tv_->getLogicalDomain(); - NVF_ERROR( - ids_in_logical.size() < logical_domain.size(), - "There aren't enough logical IDs to derive thread IDs: ", - tv_->toString()); - - // Return the logical ID just before the ones feeding into Group ID - return logical_domain[logical_domain.size() - ids_in_logical.size() - 1]; - } else { - // Case 2: Start from the outer output of the last split - return last_split_seen_->outer(); - } - } - - // Validates that TIDx follows the Group ID in the schedule using DFS. - // Starting from the given ID, traverses through splits and merges - // to ensure TIDx is reachable. We traverse through the split by taking the - // inner path first. Any terminating IDs that are not TIDx must have extent 1. 
- void validateThreadXFollowsGroupId(IterDomain* start_id) const { - std::stack to_visit; - to_visit.push(start_id); - - while (!to_visit.empty()) { - auto current_id = to_visit.top(); - to_visit.pop(); - - // Check terminating IDs (no uses) - if (current_id->uses().empty()) { - if (current_id->getParallelType() == ParallelType::TIDx) { - return; // Found TIDx - validation successful - } + auto all_exprs = DependencyCheck::getAllExprsBetween( + {logical_domain.begin(), logical_domain.end()}, {id}); - // Non-TIDx terminating IDs must have constant extent of 1 - NVF_ERROR( - current_id->extent()->isConstInt(), - "Only constant extent IDs are expected between TIDx and Group ID " - "in BlockQuantizationOp quantized output: ", - tv_->toInlineString()); - NVF_ERROR( - current_id->extent()->evaluate().as() == 1, - "Only constant extent IDs with extent of 1 are expected between " - "TIDx and Group ID in BlockQuantizationOp quantized output: ", - tv_->toInlineString()); - continue; - } - - NVF_ERROR( - current_id->uses().size() == 1, - "Expected single use for IDs in logical to loop transforms for " - "BlockQuantizationOp quantization output: ", - current_id->toString()); - - // Process the use expression - auto use_expr = current_id->uses().at(0); - if (auto merge = dynamic_cast(use_expr)) { - NVF_ERROR( - isAcceptableMerge(merge), - "Invalid merge found while tracing forward in the logical to loop " - "transforms for BlockQuantizationOp quantization output: ", - tv_->toString()); - to_visit.push(merge->out()); - } else if (auto split = dynamic_cast(use_expr)) { - // DFS: inner split first, then outer split - to_visit.push(split->outer()); - to_visit.push(split->inner()); - } else { - NVF_ERROR( - false, - "Unexpected use of an ID found while tracing forward in the " - "logical to loop transforms for BlockQuantizationOp quantization " - "output: ", - tv_->toString()); - } - } - } - - private: - const TensorView* tv_ = nullptr; - IterDomain* group_id_; - Split* last_split_seen_ = nullptr; -}; + // Stop on first non-contiguous merge and return false + return traverseFrontierWithContiguityCheck( + frontier, all_exprs, /*stop_on_noncontiguous=*/true); +} // Expr-specific validaion // @@ -665,16 +457,16 @@ class ExprValidator : public OptOutDispatch { // has an extent of 1. // The Group ID has an extent of 4/8 depending on the data type. // There are no TIDz/BIDz IDs. We don't support 3D parallelization here. + // The following are more complex checks that look at the schedule. - // These checks are implemented using the helper class - // BlockQuantizationValidationHelper. // Our aims for the following checks are to ensure that the group ID is // contiguous and unit stride, and then after group ID, we have TIDx. Such // that (G -- extent of GID) * ThreadIdx.x + GID is contiguous. - // We do so by checking that the group ID unit stride. It is derived from the - // innermost logical IDs via merges and inner splits only. Next we check that + // We do so by checking that the group ID is unit stride. It is derived from + // the innermost logical IDs via contiguous merges only. Next we check that // TIDx in the next inner-most ID, and if there was any other ID between TIDx - // and group ID then it must have an extent of 1. + // and group ID then it must have an extent of 1. TIDx must also be derived + // from contiguous merges of the logical IDs. 
void handle(BlockQuantizationOp* bqop) final { auto inp_tv = bqop->input(0)->as(); auto quantized_output = bqop->quantizedOutput()->as(); @@ -775,9 +567,82 @@ class ExprValidator : public OptOutDispatch { ". Expr: ", bqop->toString()); - // Helper to check to the most involved scheduling requirements. - BlockQuantizationValidationHelper helper(quantized_output, grouped_id); - helper.run(); + // M K + // │ │ + // ▼ ▼ + // ┌────────────┐ + // │ merge │ + // └─────┬──────┘ + // │ + // ▼ + // M*K + // ┌──────────┐ + // │ split ┼──┐ + // └─┬────────┘ │ + // ▼ ▼ + // (M*K)/4 4(G) + // ┌────────┐ + // │ split ┼────┐ + // └─┬──────┘ │ + // ▼ ▼ + // (M*K)/4 1(U) + // ┌─────────┐ + // │ split │ + // ┌─┼ ┼───┐ + // │ └─────────┘ │ + // ▼ ▼ + // (M*K)/4/128 128(Tx) + + // Next we check the following scheduling requirements for + // BlockQuantizationOp - the above figure is an example of such a schedule. + // 1. The Group ID must be derived from the innermost logical IDs + // 2. TIDx must follow the Group ID in the schedule -- that is when derived + // from the logical domain, group ID must be inner-most, the next + // "inner-most" should be TIDx (unless there is an ID with a unit trip + // count) + // 3. All merges involved from logical domains to group and thread ID must + // combine contiguous logical IDs + + // This will get the xforms from logical to loop and apply then on the + // logical domain. We will get a loop domain minus the reordering. + std::vector ids_to_transform = + scheduler_utils::computeLoopDomainFromLogical(quantized_output); + + // The grouped ID must correspond to the innermost + NVF_ERROR( + ids_to_transform.back() == grouped_id, + "The grouped ID must correspond to the innermost of all splits " + "from logical domains to loop domains for BlockQuantizationOp. 
" + "TV: ", + quantized_output->toString()); + + // Iterate from the back to find TIDx, skipping group_id (last element) + // Ensure all IDs between group_id and TIDx have extent 1 + for (auto it = ids_to_transform.rbegin() + 1; it != ids_to_transform.rend(); + ++it) { + if ((*it)->getParallelType() == ParallelType::TIDx) { + break; + } + // All non-TIDx IDs between Group ID and TIDx must have extent of 1 + NVF_ERROR( + (*it)->extent()->isConstInt() && + (*it)->extent()->evaluate().as() == 1, + "Expected IDs between Group ID and TIDx to have extent of 1 for " + "BlockQuantizationOp: ", + quantized_output->toString()); + } + + NVF_ERROR( + areAllMergesContiguous(quantized_output, grouped_id), + "All merge operations deriving the grouped ID must combine " + "contiguous IDs from the logical domain for BlockQuantizationOp: ", + quantized_output->toString()); + + NVF_ERROR( + areAllMergesContiguous(quantized_output, thread_x), + "All merge operations deriving the TIDx ID must combine " + "contiguous IDs from the logical domain for BlockQuantizationOp: ", + quantized_output->toString()); } }; diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index 790be9165fc..32ef047117f 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -2282,14 +2282,21 @@ void applyTransforms( } } -// Returns a permutation reordering the loop domain of the tensor view as the +// Compute a new loop domain (without the reordering) by applying transforms to // logical domain -std::vector domainReorderAsLogicalMap(TensorView* tv) { +std::vector computeLoopDomainFromLogical(TensorView* tv) { auto transform_exprs = DependencyCheck::getAllExprsBetween( {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()}, {tv->getLoopDomain().begin(), tv->getLoopDomain().end()}); std::vector ids_to_transform = tv->getLogicalDomain(); applyTransforms(ids_to_transform, transform_exprs); + return ids_to_transform; +} + +// Returns a permutation reordering the loop domain of the tensor view as the +// logical domain +std::vector domainReorderAsLogicalMap(TensorView* tv) { + std::vector ids_to_transform = computeLoopDomainFromLogical(tv); std::optional> permutation = ir_utils::computePermutation(ids_to_transform, tv->getLoopDomain()); NVF_ERROR( diff --git a/csrc/scheduler/utils.h b/csrc/scheduler/utils.h index c93bb04bae5..024989a026c 100644 --- a/csrc/scheduler/utils.h +++ b/csrc/scheduler/utils.h @@ -700,6 +700,12 @@ std::vector domainReorderAsLogicalMap(TensorView* tv); std::unordered_map reorderLogicalAsAllocationMap( TensorView* tv); +// Helper function that computes the loop domain by applying all transformations +// from the logical domain to the loop domain. Returns a vector of IterDomains +// representing what the loop domain minus reorderings would look like after +// applying the transformations to the logical domain. +std::vector computeLoopDomainFromLogical(TensorView* tv); + // Generates an old to new map to reorder tv's loop domain as its allocation // order. Allocation domain is transformed to find a permutation of the loop // domain that satisfies the order in allocation domain. 
diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index 3042bec8e00..35c60480340 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -662,8 +662,8 @@ TEST_F(BlockQuantizationValidationTest, TIDxMustBeSecondInnermostAfterGroupID) { assertCompilationFails( setup.fusion.get(), {createTestInput()}, - "Only constant extent IDs with extent of 1 are expected between TIDx " - "and Group ID in BlockQuantizationOp quantized output"); + "Expected IDs between Group ID and TIDx to have extent of 1 for " + "BlockQuantizationOp:"); } // When running validation checks we traverse from loop to logical domain @@ -709,9 +709,8 @@ TEST_F(BlockQuantizationValidationTest, MergesMustBeContiguous) { assertCompilationFails( setup.fusion.get(), {createTestInput(/*dim=*/3)}, - "Invalid merge found while tracing back the grouped ID for " - "BlockQuantizationOp. All inputs to merge must be from logical domain " - "or be outputs of other merges"); + "All merge operations deriving the grouped ID must combine contiguous " + "IDs from the logical domain for BlockQuantizationOp"); } TEST_P(NVFP4QuantizeTest, SwizzledOuputAndWithoutPerTensorAmax) { From dfd75536fd707072f62e42986b1d2e38a232f24d Mon Sep 17 00:00:00 2001 From: protonu Date: Tue, 4 Nov 2025 11:57:50 -0800 Subject: [PATCH 55/79] foramt --- csrc/device_lower/validation.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index c6e28f1c7c2..ebd7919973e 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -594,7 +594,7 @@ class ExprValidator : public OptOutDispatch { // (M*K)/4/128 128(Tx) // Next we check the following scheduling requirements for - // BlockQuantizationOp - the above figure is an example of such a schedule. + // BlockQuantizationOp - the above figure is an example of a valid schedule. // 1. The Group ID must be derived from the innermost logical IDs // 2. TIDx must follow the Group ID in the schedule -- that is when derived // from the logical domain, group ID must be inner-most, the next @@ -603,7 +603,7 @@ class ExprValidator : public OptOutDispatch { // 3. All merges involved from logical domains to group and thread ID must // combine contiguous logical IDs - // This will get the xforms from logical to loop and apply then on the + // This will get the xforms from logical to loop and apply them on the // logical domain. We will get a loop domain minus the reordering. 
std::vector ids_to_transform = scheduler_utils::computeLoopDomainFromLogical(quantized_output); From 00f5f946039af4b04f922d560a69bac37b86a5aa Mon Sep 17 00:00:00 2001 From: protonu Date: Tue, 4 Nov 2025 12:24:49 -0800 Subject: [PATCH 56/79] address greptile comments --- csrc/device_lower/validation.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index ebd7919973e..fc10098182f 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -618,9 +618,11 @@ class ExprValidator : public OptOutDispatch { // Iterate from the back to find TIDx, skipping group_id (last element) // Ensure all IDs between group_id and TIDx have extent 1 + bool found_tidx = false; for (auto it = ids_to_transform.rbegin() + 1; it != ids_to_transform.rend(); ++it) { - if ((*it)->getParallelType() == ParallelType::TIDx) { + if (*it == thread_x) { + found_tidx = true; break; } // All non-TIDx IDs between Group ID and TIDx must have extent of 1 @@ -632,6 +634,12 @@ class ExprValidator : public OptOutDispatch { quantized_output->toString()); } + NVF_ERROR( + found_tidx, + "TIDx must follow the Group ID in the schedule for " + "BlockQuantizationOp: ", + quantized_output->toString()); + NVF_ERROR( areAllMergesContiguous(quantized_output, grouped_id), "All merge operations deriving the grouped ID must combine " From 6b610609b10b571cd18d4aab36fa20eb2a337b4e Mon Sep 17 00:00:00 2001 From: protonu Date: Tue, 4 Nov 2025 16:45:23 -0800 Subject: [PATCH 57/79] edit assert error handler --- tests/cpp/test_low_precision_recipe.cpp | 51 ++++++------------------- 1 file changed, 11 insertions(+), 40 deletions(-) diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index 35c60480340..9c5c95b431a 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -407,35 +407,14 @@ TEST_P(BlockQuantizationTest, ScheduleAsPointwise2D) { class BlockQuantizationValidationTest : public BlackwellBase { protected: - // Helper function to create test input tensor - at::Tensor createTestInput(int64_t dim = 2) { - if (dim == 2) { - return at::randn({1024, 1024}, at::device(at::kCUDA).dtype(at::kFloat)); - } else if (dim == 3) { - return at::randn({16, 64, 1024}, at::device(at::kCUDA).dtype(at::kFloat)); - } else { - throw std::runtime_error("Unsupported dimension for createTestInput"); - } - } - - // Helper function to assert compilation fails with expected error message - void assertCompilationFails( - Fusion* fusion, - const std::vector& inputs, - const char* expected_error_msg) { + // Helper function to assert compilation fails + void assertCompilationFails(Fusion* fusion, const char* expected_error_msg) { KernelExecutor ke; - EXPECT_THROW( - { - try { - ke.compile(fusion, inputs); - } catch (const std::exception& e) { - EXPECT_THAT(e.what(), ::testing::HasSubstr(expected_error_msg)) - << "Expected error message containing: \"" << expected_error_msg - << "\"\nActual error: " << e.what(); - throw; // Re-throw for EXPECT_THROW to catch - } - }, - std::exception); + + EXPECT_THAT( + [&]() { GpuLower(fusion).run(); }, + testing::ThrowsMessage( + testing::HasSubstr(expected_error_msg))); } // Helper function to create a fusion with blockQuantize and apply scheduling @@ -505,8 +484,7 @@ TEST_F(BlockQuantizationValidationTest, InputMustBeInLocalMemory) { fusion->addOutput(quantization_results.block_scales); fusion->addOutput(t_out); - 
assertCompilationFails( - fusion.get(), {createTestInput()}, "Input must be a local memory tensor"); + assertCompilationFails(fusion.get(), "Input must be a local memory tensor"); } // Quantized output is written to global memory - not valid @@ -524,9 +502,7 @@ TEST_F(BlockQuantizationValidationTest, QuantizedOutputMustBeInLocalMemory) { fusion->addOutput(quantization_results.quantized_tensor); assertCompilationFails( - fusion.get(), - {createTestInput()}, - "Quantized output must be a local memory tensor"); + fusion.get(), "Quantized output must be a local memory tensor"); } // Block scaling factor is written to local memory - not valid @@ -548,9 +524,7 @@ TEST_F( fusion->addOutput(tv_quantized_out); assertCompilationFails( - fusion.get(), - {createTestInput()}, - "Block scaling factor must be a global memory tensor"); + fusion.get(), "Block scaling factor must be a global memory tensor"); } // Group ID must be the innermost of all splits from logical domains to loop @@ -586,7 +560,6 @@ TEST_F(BlockQuantizationValidationTest, GroupIDMustBeInnermost) { assertCompilationFails( setup.fusion.get(), - {createTestInput()}, "The grouped ID must correspond to the innermost of all splits from " "logical domains to loop domains for BlockQuantizationOp"); } @@ -622,7 +595,7 @@ TEST_F(BlockQuantizationValidationTest, NonParallelizedIDsMustHaveExtentOfOne) { assertCompilationFails( setup.fusion.get(), - {createTestInput()}, + "Expected non-TID/BID/Group ID to have extent of 1 for " "BlockQuantizationOp"); } @@ -661,7 +634,6 @@ TEST_F(BlockQuantizationValidationTest, TIDxMustBeSecondInnermostAfterGroupID) { assertCompilationFails( setup.fusion.get(), - {createTestInput()}, "Expected IDs between Group ID and TIDx to have extent of 1 for " "BlockQuantizationOp:"); } @@ -708,7 +680,6 @@ TEST_F(BlockQuantizationValidationTest, MergesMustBeContiguous) { assertCompilationFails( setup.fusion.get(), - {createTestInput(/*dim=*/3)}, "All merge operations deriving the grouped ID must combine contiguous " "IDs from the logical domain for BlockQuantizationOp"); } From 490b6a778bb460c69d6a78cf780a38bc0a5c6b16 Mon Sep 17 00:00:00 2001 From: protonu Date: Tue, 4 Nov 2025 16:53:44 -0800 Subject: [PATCH 58/79] cleanup --- tests/cpp/test_low_precision_recipe.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index 9c5c95b431a..199e03b80dc 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -409,8 +409,6 @@ class BlockQuantizationValidationTest : public BlackwellBase { protected: // Helper function to assert compilation fails void assertCompilationFails(Fusion* fusion, const char* expected_error_msg) { - KernelExecutor ke; - EXPECT_THAT( [&]() { GpuLower(fusion).run(); }, testing::ThrowsMessage( From 25da8be93018fafbfe68c848c6c0cdf101475202 Mon Sep 17 00:00:00 2001 From: protonu Date: Tue, 4 Nov 2025 16:59:40 -0800 Subject: [PATCH 59/79] more clean up --- tests/cpp/test_low_precision_recipe.cpp | 63 +++++++++++++------------ 1 file changed, 33 insertions(+), 30 deletions(-) diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index 199e03b80dc..e1e75ce867c 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -407,14 +407,6 @@ TEST_P(BlockQuantizationTest, ScheduleAsPointwise2D) { class BlockQuantizationValidationTest : public BlackwellBase { protected: - // Helper function to assert compilation 
fails - void assertCompilationFails(Fusion* fusion, const char* expected_error_msg) { - EXPECT_THAT( - [&]() { GpuLower(fusion).run(); }, - testing::ThrowsMessage( - testing::HasSubstr(expected_error_msg))); - } - // Helper function to create a fusion with blockQuantize and apply scheduling struct FusionSetup { std::unique_ptr fusion; @@ -482,7 +474,10 @@ TEST_F(BlockQuantizationValidationTest, InputMustBeInLocalMemory) { fusion->addOutput(quantization_results.block_scales); fusion->addOutput(t_out); - assertCompilationFails(fusion.get(), "Input must be a local memory tensor"); + EXPECT_THAT( + [&]() { GpuLower(fusion.get()).run(); }, + testing::ThrowsMessage( + testing::HasSubstr("Input must be a local memory tensor"))); } // Quantized output is written to global memory - not valid @@ -499,8 +494,10 @@ TEST_F(BlockQuantizationValidationTest, QuantizedOutputMustBeInLocalMemory) { fusion->addOutput(quantization_results.block_scales); fusion->addOutput(quantization_results.quantized_tensor); - assertCompilationFails( - fusion.get(), "Quantized output must be a local memory tensor"); + EXPECT_THAT( + [&]() { GpuLower(fusion.get()).run(); }, + testing::ThrowsMessage(testing::HasSubstr( + "Quantized output must be a local memory tensor"))); } // Block scaling factor is written to local memory - not valid @@ -521,8 +518,10 @@ TEST_F( fusion->addOutput(tv_block_scales); fusion->addOutput(tv_quantized_out); - assertCompilationFails( - fusion.get(), "Block scaling factor must be a global memory tensor"); + EXPECT_THAT( + [&]() { GpuLower(fusion.get()).run(); }, + testing::ThrowsMessage(testing::HasSubstr( + "Block scaling factor must be a global memory tensor"))); } // Group ID must be the innermost of all splits from logical domains to loop @@ -556,10 +555,11 @@ TEST_F(BlockQuantizationValidationTest, GroupIDMustBeInnermost) { } } - assertCompilationFails( - setup.fusion.get(), - "The grouped ID must correspond to the innermost of all splits from " - "logical domains to loop domains for BlockQuantizationOp"); + EXPECT_THAT( + [&]() { GpuLower(setup.fusion.get()).run(); }, + testing::ThrowsMessage(testing::HasSubstr( + "The grouped ID must correspond to the innermost of all splits from " + "logical domains to loop domains for BlockQuantizationOp"))); } // We do not allow IDs of types serial, unroll, unswitch to have extent > 1 @@ -591,11 +591,11 @@ TEST_F(BlockQuantizationValidationTest, NonParallelizedIDsMustHaveExtentOfOne) { } } - assertCompilationFails( - setup.fusion.get(), - - "Expected non-TID/BID/Group ID to have extent of 1 for " - "BlockQuantizationOp"); + EXPECT_THAT( + [&]() { GpuLower(setup.fusion.get()).run(); }, + testing::ThrowsMessage(testing::HasSubstr( + "Expected non-TID/BID/Group ID to have extent of 1 for " + "BlockQuantizationOp"))); } // The runtime kernel for block quantization expects TIDx to access contiguous @@ -630,10 +630,11 @@ TEST_F(BlockQuantizationValidationTest, TIDxMustBeSecondInnermostAfterGroupID) { } } - assertCompilationFails( - setup.fusion.get(), - "Expected IDs between Group ID and TIDx to have extent of 1 for " - "BlockQuantizationOp:"); + EXPECT_THAT( + [&]() { GpuLower(setup.fusion.get()).run(); }, + testing::ThrowsMessage(testing::HasSubstr( + "Expected IDs between Group ID and TIDx to have extent of 1 for " + "BlockQuantizationOp:"))); } // When running validation checks we traverse from loop to logical domain @@ -676,10 +677,12 @@ TEST_F(BlockQuantizationValidationTest, MergesMustBeContiguous) { } } - assertCompilationFails( - setup.fusion.get(), - 
"All merge operations deriving the grouped ID must combine contiguous " - "IDs from the logical domain for BlockQuantizationOp"); + EXPECT_THAT( + [&]() { GpuLower(setup.fusion.get()).run(); }, + testing::ThrowsMessage(testing::HasSubstr( + "All merge operations deriving the grouped ID must combine " + "contiguous " + "IDs from the logical domain for BlockQuantizationOp"))); } TEST_P(NVFP4QuantizeTest, SwizzledOuputAndWithoutPerTensorAmax) { From 2059878ad0ae219f3a9a79f449bfafe7cc8649b3 Mon Sep 17 00:00:00 2001 From: protonu Date: Wed, 5 Nov 2025 05:33:10 -0800 Subject: [PATCH 60/79] refactor --- csrc/device_lower/validation.cpp | 122 +++++++++++++++---------------- csrc/scheduler/utils.cpp | 17 ++--- csrc/scheduler/utils.h | 3 +- 3 files changed, 68 insertions(+), 74 deletions(-) diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index fc10098182f..652225c34cd 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -213,53 +213,44 @@ void validateCpAsyncBulk(const std::vector& tvs) { // If stop_on_noncontiguous is true, stops traversal and returns false on first // non-contiguous merge. Otherwise, removes non-contiguous merges from frontier // and continues. -bool traverseFrontierWithContiguityCheck( +void traverseFrontierWithContiguityCheck( std::deque& frontier, - const std::vector& exprs, - bool stop_on_noncontiguous) { - for (auto expr : exprs) { - // expr is skipped if any of the inputs is missing. - if (auto merge = dynamic_cast(expr)) { - // Check if this merge is logically contiguous merge, that is, - // both of the two inputs are adjacent to each other - auto outer_it = std::ranges::find(frontier, merge->outer()); - if (outer_it == frontier.end()) { - continue; - } - auto inner_it = std::ranges::find(frontier, merge->inner()); - if (inner_it == frontier.end()) { - continue; - } - auto outer_pos = std::distance(frontier.begin(), outer_it); - auto inner_pos = std::distance(frontier.begin(), inner_it); - - bool is_contig = outer_pos + 1 == inner_pos; + Expr* expr) { + // expr is skipped if any of the inputs is missing. + if (auto merge = dynamic_cast(expr)) { + // Check if this merge is logically contiguous merge, that is, + // both of the two inputs are adjacent to each other + auto outer_it = std::ranges::find(frontier, merge->outer()); + if (outer_it == frontier.end()) { + return; + } + auto inner_it = std::ranges::find(frontier, merge->inner()); + if (inner_it == frontier.end()) { + return; + } + auto outer_pos = std::distance(frontier.begin(), outer_it); + auto inner_pos = std::distance(frontier.begin(), inner_it); - if (!is_contig && stop_on_noncontiguous) { - // Found a non-contiguous merge - return false; - } + bool is_contig = outer_pos + 1 == inner_pos; - frontier.erase(inner_it); + frontier.erase(inner_it); - // If it's contig, we can continue the analysis by proceeding to - // the output. If not, no further analysis is possible, so the - // two inputs are just removed from the frontier list - if (is_contig) { - frontier[outer_pos] = merge->out(); - } else { - frontier.erase(outer_it); - } - } else if (auto split = dynamic_cast(expr)) { - auto in_it = std::ranges::find(frontier, split->in()); - if (in_it == frontier.end()) { - continue; - } - frontier.insert(in_it + 1, split->inner()); - *in_it = split->outer(); + // If it's contig, we can continue the analysis by proceeding to + // the output. 
If not, no further analysis is possible, so the + // two inputs are just removed from the frontier list + if (is_contig) { + frontier[outer_pos] = merge->out(); + } else { + frontier.erase(outer_it); + } + } else if (auto split = dynamic_cast(expr)) { + auto in_it = std::ranges::find(frontier, split->in()); + if (in_it == frontier.end()) { + return; } + frontier.insert(in_it + 1, split->inner()); + *in_it = split->outer(); } - return true; } // Check if maybe_innermost_id is derived from base_id and corresponds to the @@ -273,9 +264,9 @@ bool isInnermost(IterDomain* base_id, IterDomain* maybe_innermost_id) { std::deque frontier; frontier.push_back(base_id); - // Don't stop on non-contiguous merges; remove them from frontier and continue - traverseFrontierWithContiguityCheck( - frontier, exprs, /*stop_on_noncontiguous=*/false); + for (auto expr : exprs) { + traverseFrontierWithContiguityCheck(frontier, expr); + } // Once the traversal is done, if the target id located at the // rightmost position of the frontier list, it is guaranteed to @@ -283,21 +274,6 @@ bool isInnermost(IterDomain* base_id, IterDomain* maybe_innermost_id) { return !frontier.empty() && frontier.back() == maybe_innermost_id; } -// Check if all merges leading to id from the logical domain are contiguous -// merges -bool areAllMergesContiguous(const TensorView* tv, IterDomain* id) { - const auto& logical_domain = tv->getLogicalDomain(); - std::deque frontier( - logical_domain.begin(), logical_domain.end()); - - auto all_exprs = DependencyCheck::getAllExprsBetween( - {logical_domain.begin(), logical_domain.end()}, {id}); - - // Stop on first non-contiguous merge and return false - return traverseFrontierWithContiguityCheck( - frontier, all_exprs, /*stop_on_noncontiguous=*/true); -} - // Expr-specific validaion // // TODO: Move individual validations to here, e.g., @@ -605,8 +581,23 @@ class ExprValidator : public OptOutDispatch { // This will get the xforms from logical to loop and apply them on the // logical domain. We will get a loop domain minus the reordering. + + auto transform_exprs = DependencyCheck::getAllExprsBetween( + {quantized_output->getLogicalDomain().begin(), + quantized_output->getLogicalDomain().end()}, + {quantized_output->getLoopDomain().begin(), + quantized_output->getLoopDomain().end()}); + std::vector ids_to_transform = - scheduler_utils::computeLoopDomainFromLogical(quantized_output); + quantized_output->getLogicalDomain(); + + std::deque frontier( + quantized_output->getLogicalDomain().begin(), + quantized_output->getLogicalDomain().end()); + scheduler_utils::applyTransforms( + ids_to_transform, transform_exprs, [&frontier](Expr* expr) { + traverseFrontierWithContiguityCheck(frontier, expr); + }); // The grouped ID must correspond to the innermost NVF_ERROR( @@ -640,14 +631,19 @@ class ExprValidator : public OptOutDispatch { "BlockQuantizationOp: ", quantized_output->toString()); + // Check if grouped_is in frontier. 
+ auto grouped_it = + std::ranges::find(frontier.begin(), frontier.end(), grouped_id); NVF_ERROR( - areAllMergesContiguous(quantized_output, grouped_id), + grouped_it != frontier.end(), "All merge operations deriving the grouped ID must combine " "contiguous IDs from the logical domain for BlockQuantizationOp: ", quantized_output->toString()); - + // Do the same for thread_x + auto threadx_it = + std::ranges::find(frontier.begin(), frontier.end(), thread_x); NVF_ERROR( - areAllMergesContiguous(quantized_output, thread_x), + threadx_it != frontier.end(), "All merge operations deriving the TIDx ID must combine " "contiguous IDs from the logical domain for BlockQuantizationOp: ", quantized_output->toString()); diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index 32ef047117f..b66b2eda8b0 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -2267,7 +2267,8 @@ void applyResizeTransform(Resize* resize, std::vector& ids) { void applyTransforms( std::vector& ids_to_transform, - const std::vector& transform_exprs) { + const std::vector& transform_exprs, + std::optional> after_transform) { for (auto* expr : transform_exprs) { if (Split* split = dynamic_cast(expr)) { applySplitTransform(split, ids_to_transform); @@ -2279,24 +2280,20 @@ void applyTransforms( NVF_ERROR(expr != nullptr); NVF_THROW("Unexpected expression: ", expr->toString()); } + if (after_transform) { + (*after_transform)(expr); + } } } -// Compute a new loop domain (without the reordering) by applying transforms to +// Returns a permutation reordering the loop domain of the tensor view as the // logical domain -std::vector computeLoopDomainFromLogical(TensorView* tv) { +std::vector domainReorderAsLogicalMap(TensorView* tv) { auto transform_exprs = DependencyCheck::getAllExprsBetween( {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()}, {tv->getLoopDomain().begin(), tv->getLoopDomain().end()}); std::vector ids_to_transform = tv->getLogicalDomain(); applyTransforms(ids_to_transform, transform_exprs); - return ids_to_transform; -} - -// Returns a permutation reordering the loop domain of the tensor view as the -// logical domain -std::vector domainReorderAsLogicalMap(TensorView* tv) { - std::vector ids_to_transform = computeLoopDomainFromLogical(tv); std::optional> permutation = ir_utils::computePermutation(ids_to_transform, tv->getLoopDomain()); NVF_ERROR( diff --git a/csrc/scheduler/utils.h b/csrc/scheduler/utils.h index 024989a026c..5f811eeb369 100644 --- a/csrc/scheduler/utils.h +++ b/csrc/scheduler/utils.h @@ -684,7 +684,8 @@ bool breakIsDisjoint(std::vector group_ids, int64_t pos); // inner dimension. void applyTransforms( std::vector& ids_to_transform, - const std::vector& transform_exprs); + const std::vector& transform_exprs, + std::optional> after_transform = std::nullopt); // Generates a permutation to reorder tv's domain as the logical order. 
// Priority is given to inner most dimensions for example: From b67f20507f377439c0d7a8fd64ed0e0f203ff89e Mon Sep 17 00:00:00 2001 From: protonu Date: Wed, 5 Nov 2025 06:36:49 -0800 Subject: [PATCH 61/79] refactor using lambda and edit comments --- csrc/device_lower/validation.cpp | 27 +++++++++++++------------ csrc/scheduler/utils.cpp | 6 +++--- csrc/scheduler/utils.h | 11 +++------- tests/cpp/test_low_precision_recipe.cpp | 2 +- 4 files changed, 21 insertions(+), 25 deletions(-) diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index 652225c34cd..3c58760d891 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -208,11 +208,8 @@ void validateCpAsyncBulk(const std::vector& tvs) { } } -// Traverse through the expressions, updating the frontier based on merge and -// split operations. Returns true if all merges encountered are contiguous. -// If stop_on_noncontiguous is true, stops traversal and returns false on first -// non-contiguous merge. Otherwise, removes non-contiguous merges from frontier -// and continues. +// For each expressions, update the frontier based on merge and +// split operations. Removes non-contiguous merges from frontier. void traverseFrontierWithContiguityCheck( std::deque& frontier, Expr* expr) { @@ -232,7 +229,6 @@ void traverseFrontierWithContiguityCheck( auto inner_pos = std::distance(frontier.begin(), inner_it); bool is_contig = outer_pos + 1 == inner_pos; - frontier.erase(inner_it); // If it's contig, we can continue the analysis by proceeding to @@ -250,12 +246,15 @@ void traverseFrontierWithContiguityCheck( } frontier.insert(in_it + 1, split->inner()); *in_it = split->outer(); + } else { + NVF_ERROR(expr != nullptr); + NVF_THROW("Unexpected expression: ", expr->toString()); } } // Check if maybe_innermost_id is derived from base_id and corresponds to the // innermost subregion of base_id. The split/merge exprs between -// based_id and id must not include any ID that is not produced from +// base_id and id must not include any ID that is not produced from // base_id. bool isInnermost(IterDomain* base_id, IterDomain* maybe_innermost_id) { auto exprs = @@ -577,10 +576,7 @@ class ExprValidator : public OptOutDispatch { // "inner-most" should be TIDx (unless there is an ID with a unit trip // count) // 3. All merges involved from logical domains to group and thread ID must - // combine contiguous logical IDs - - // This will get the xforms from logical to loop and apply them on the - // logical domain. We will get a loop domain minus the reordering. + // combine contiguous IDs auto transform_exprs = DependencyCheck::getAllExprsBetween( {quantized_output->getLogicalDomain().begin(), @@ -594,12 +590,17 @@ class ExprValidator : public OptOutDispatch { std::deque frontier( quantized_output->getLogicalDomain().begin(), quantized_output->getLogicalDomain().end()); + + // This will get the xforms from logical to loop and apply them on the + // logical domain. We will get a loop domain minus the reordering. + // This pass also removes all IDs from frontier that were derived using + // non-contiguous merges. 
scheduler_utils::applyTransforms( ids_to_transform, transform_exprs, [&frontier](Expr* expr) { traverseFrontierWithContiguityCheck(frontier, expr); }); - // The grouped ID must correspond to the innermost + // The grouped ID must correspond to the innermost loop-like domain NVF_ERROR( ids_to_transform.back() == grouped_id, "The grouped ID must correspond to the innermost of all splits " @@ -631,7 +632,7 @@ class ExprValidator : public OptOutDispatch { "BlockQuantizationOp: ", quantized_output->toString()); - // Check if grouped_is in frontier. + // Check if grouped_is in frontier auto grouped_it = std::ranges::find(frontier.begin(), frontier.end(), grouped_id); NVF_ERROR( diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index b66b2eda8b0..ac422bc252d 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -2268,7 +2268,7 @@ void applyResizeTransform(Resize* resize, std::vector& ids) { void applyTransforms( std::vector& ids_to_transform, const std::vector& transform_exprs, - std::optional> after_transform) { + std::optional> post_transform) { for (auto* expr : transform_exprs) { if (Split* split = dynamic_cast(expr)) { applySplitTransform(split, ids_to_transform); @@ -2280,8 +2280,8 @@ void applyTransforms( NVF_ERROR(expr != nullptr); NVF_THROW("Unexpected expression: ", expr->toString()); } - if (after_transform) { - (*after_transform)(expr); + if (post_transform) { + (*post_transform)(expr); } } } diff --git a/csrc/scheduler/utils.h b/csrc/scheduler/utils.h index 5f811eeb369..0167a32b51e 100644 --- a/csrc/scheduler/utils.h +++ b/csrc/scheduler/utils.h @@ -681,11 +681,12 @@ bool breakIsDisjoint(std::vector group_ids, int64_t pos); // Update the vector of ids_to_transform as progressing through the // `transform_exprs`. We'll always insert the result of split in the // location of the input, and insert the merge result in the position of the -// inner dimension. +// inner dimension. Optionally accepts a callback after each transform is +// applied for analysis of the expr nodes. void applyTransforms( std::vector& ids_to_transform, const std::vector& transform_exprs, - std::optional> after_transform = std::nullopt); + std::optional> post_transform = std::nullopt); // Generates a permutation to reorder tv's domain as the logical order. // Priority is given to inner most dimensions for example: @@ -701,12 +702,6 @@ std::vector domainReorderAsLogicalMap(TensorView* tv); std::unordered_map reorderLogicalAsAllocationMap( TensorView* tv); -// Helper function that computes the loop domain by applying all transformations -// from the logical domain to the loop domain. Returns a vector of IterDomains -// representing what the loop domain minus reorderings would look like after -// applying the transformations to the logical domain. -std::vector computeLoopDomainFromLogical(TensorView* tv); - // Generates an old to new map to reorder tv's loop domain as its allocation // order. Allocation domain is transformed to find a permutation of the loop // domain that satisfies the order in allocation domain. 
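To make the frontier-based contiguity pass above concrete, here is a small self-contained model of the per-transform update that the lambda performs (toy Merge/Split records instead of nvFuser Expr nodes; the names are made up for illustration): a merge of adjacent frontier entries is replaced by its output, a non-adjacent merge drops both of its inputs, and a split replaces its input with its outer and inner outputs. Any ID still found in the frontier afterwards was produced only through contiguous merges, which is what the grouped ID and TIDx lookups rely on.

// Toy frontier update for the contiguity check (illustrative sketch only).
#include <algorithm>
#include <cassert>
#include <deque>
#include <string>
#include <variant>
#include <vector>

struct Merge { std::string outer, inner, out; };
struct Split { std::string in, outer, inner; };
using Xform = std::variant<Merge, Split>;

void updateFrontier(std::deque<std::string>& frontier, const Xform& xf) {
  if (const auto* m = std::get_if<Merge>(&xf)) {
    auto outer_it = std::find(frontier.begin(), frontier.end(), m->outer);
    auto inner_it = std::find(frontier.begin(), frontier.end(), m->inner);
    if (outer_it == frontier.end() || inner_it == frontier.end()) {
      return;  // an input is already gone; nothing to analyze
    }
    const auto outer_pos = std::distance(frontier.begin(), outer_it);
    const bool is_contig =
        outer_pos + 1 == std::distance(frontier.begin(), inner_it);
    frontier.erase(inner_it);
    if (is_contig) {
      frontier[outer_pos] = m->out;  // keep analyzing through the output
    } else {
      // Non-contiguous merge: drop the other input as well.
      frontier.erase(std::find(frontier.begin(), frontier.end(), m->outer));
    }
  } else if (const auto* s = std::get_if<Split>(&xf)) {
    auto in_it = std::find(frontier.begin(), frontier.end(), s->in);
    if (in_it == frontier.end()) {
      return;
    }
    *in_it = s->outer;
    frontier.insert(in_it + 1, s->inner);
  }
}

int main() {
  // Logical domain [A, B, C]; merging A with C skips B, so it is non-contiguous.
  std::deque<std::string> frontier = {"A", "B", "C"};
  const std::vector<Xform> xforms = {
      Merge{"A", "C", "A*C"}, Split{"B", "B/2", "2"}};
  for (const auto& xf : xforms) {
    updateFrontier(frontier, xf);
  }
  // "A*C" never enters the frontier, so a lookup for it would fail the check.
  assert(std::find(frontier.begin(), frontier.end(), "A*C") == frontier.end());
  assert(std::find(frontier.begin(), frontier.end(), "2") != frontier.end());
  return 0;
}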
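The test changes in the preceding patches drop the hand-rolled try/catch helper in favor of gmock's ThrowsMessage and HasSubstr matchers applied to a callable. A generic, self-contained illustration of that pattern, using std::runtime_error in place of nvFuser's error type (the function name is made up):

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <stdexcept>

namespace {
// Stand-in for a lowering call that is expected to fail validation.
void failingLower() {
  throw std::runtime_error("Quantized output must be a local memory tensor");
}
} // namespace

TEST(ThrowsMessageExample, MatchesSubstring) {
  // EXPECT_THAT is given a callable; ThrowsMessage checks both the exception
  // type and that its what() contains the expected substring.
  EXPECT_THAT(
      []() { failingLower(); },
      testing::ThrowsMessage<std::runtime_error>(
          testing::HasSubstr("local memory tensor")));
}

This keeps each expected-failure test to a single statement while still reporting the actual message when the substring does not match.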
diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index e1e75ce867c..727b20c56d1 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -407,7 +407,6 @@ TEST_P(BlockQuantizationTest, ScheduleAsPointwise2D) { class BlockQuantizationValidationTest : public BlackwellBase { protected: - // Helper function to create a fusion with blockQuantize and apply scheduling struct FusionSetup { std::unique_ptr fusion; TensorView* tv_data_hp; @@ -417,6 +416,7 @@ class BlockQuantizationValidationTest : public BlackwellBase { TensorView* t_out; }; + // Helper function to create a fusion with blockQuantize and apply scheduling FusionSetup createBlockQuantizeFusion(int64_t dim = 2) { FusionSetup setup; setup.fusion = std::make_unique(); From d1f2cbcf3fe910c47399b3c9300256c34c454699 Mon Sep 17 00:00:00 2001 From: Protonu Date: Wed, 5 Nov 2025 13:34:08 -0500 Subject: [PATCH 62/79] Update csrc/device_lower/validation.cpp Co-authored-by: Naoya Maruyama --- csrc/device_lower/validation.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index 3c58760d891..3e48cd778ea 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -634,7 +634,7 @@ class ExprValidator : public OptOutDispatch { // Check if grouped_is in frontier auto grouped_it = - std::ranges::find(frontier.begin(), frontier.end(), grouped_id); + std::ranges::find(frontier, grouped_id); NVF_ERROR( grouped_it != frontier.end(), "All merge operations deriving the grouped ID must combine " From 9f582a8d192acd1b42a3a6ba3fb7912adf650456 Mon Sep 17 00:00:00 2001 From: Protonu Date: Wed, 5 Nov 2025 13:48:05 -0500 Subject: [PATCH 63/79] Apply suggestion from @greptile-apps[bot] Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- csrc/device_lower/validation.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index 3e48cd778ea..effd5324e81 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -632,7 +632,7 @@ class ExprValidator : public OptOutDispatch { "BlockQuantizationOp: ", quantized_output->toString()); - // Check if grouped_is in frontier + // Check if grouped_id in frontier auto grouped_it = std::ranges::find(frontier, grouped_id); NVF_ERROR( From f30ab2e0b5673bcc38152fe72664c3945f6987c5 Mon Sep 17 00:00:00 2001 From: protonu Date: Wed, 5 Nov 2025 13:55:41 -0800 Subject: [PATCH 64/79] edit comments --- csrc/device_lower/validation.cpp | 65 ++++++++++++++++++++++++-------- 1 file changed, 49 insertions(+), 16 deletions(-) diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index effd5324e81..1715523ee84 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -424,24 +424,47 @@ class ExprValidator : public OptOutDispatch { } } - // Basic checks: + // The block quantization operator is implemented via a runtime function. + // This runtime function expects the inputs to be in local memory. The + // quantized output will also be in local memory, but the block scaling + // factors will be written out to global memory. The device runtime currently + // works on 4 elements per thread (8 for bf16/fp16) - this will be expanded + // later to support 2, 4, and 8 (only bf16/fp16) per thread. 
The runtime + // function is based on a parallelization scheme that expects TIDx and BIDx, + // and optionally TIDy and BIDy. 3D parallelization is not supported. Based + // on the above, we have the following basic validation checks: + // Input is in local memory. // Block scaling factor is in global memory and // quantized output is in local memory. - // Any loop ID that is not TID(x/y). BID(x/y) or Group - // has an extent of 1. - // The Group ID has an extent of 4/8 depending on the data type. + // The Group ID has an extent of 4/8 depending on the data + // type. // There are no TIDz/BIDz IDs. We don't support 3D parallelization here. - // The following are more complex checks that look at the schedule. - // Our aims for the following checks are to ensure that the group ID is - // contiguous and unit stride, and then after group ID, we have TIDx. Such - // that (G -- extent of GID) * ThreadIdx.x + GID is contiguous. - // We do so by checking that the group ID is unit stride. It is derived from - // the innermost logical IDs via contiguous merges only. Next we check that - // TIDx in the next inner-most ID, and if there was any other ID between TIDx - // and group ID then it must have an extent of 1. TIDx must also be derived - // from contiguous merges of the logical IDs. + // For this op, the indices for block scaling factor is partially computed + // in nvfuser's index computation. It is done do by linearizing the logical + // index of the quantized outputs and the extents of the allocation domain + // of the quantized output. This index is passed to the runtime function, + // where is it divided by 16 (blocksize) to compute the output index for block + // scaling factor. Because of this indexing scheme we have to put the + // following restrictions. Our aim for the following checks is to ensure that + // the group ID is contiguous and has unit stride, and then after the group + // ID, we have TIDx, such that (G -- extent of GID) * ThreadIdx.x + GID is + // contiguous. We have these restrictions because 4 threads (x) will be + // working on contiguous data in the input (actually #threads * + // #elements_per_thread == blocksize(16)) - so we conservatively want all + // threads(x) to be accessing contiguous data. + + // We do so by checking that the group ID has unit stride. + // It should be derived from the innermost logical IDs via contiguous merges + // only. + // Next, we check that TIDx is the next inner-most ID, and if there is + // any other ID between TIDx and the group ID, then it must have an extent + // of 1. + // TIDx must also be derived from contiguous merges of the logical IDs. + // Any loop ID that is not TIDx, TIDy, BIDx, BIDy, or Group + // has an extent of 1. (we don't want the runtime kernel to be called multiple + // times by a thread). void handle(BlockQuantizationOp* bqop) final { auto inp_tv = bqop->input(0)->as(); auto quantized_output = bqop->quantizedOutput()->as(); @@ -471,10 +494,21 @@ class ExprValidator : public OptOutDispatch { !quantized_output->hasAllocation(), "Quantized output must not have an allocation domain."); - // TODO: Relax this for swizzled block scaling factor outputs + // TODO: Relax these for swizzled block scaling factor outputs + // When scaling will be swizzled we will need to allow these checks + // to be relaxed, but we will need to ensure that the swizzling + // allocation allowed is a fixed pattern: + // 2D logical and 5D allocation domain. 
+ // https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#scale-factor-layouts NVF_ERROR( !block_scaling_factor->hasAllocation(), "Block scaling factor must not have an allocation domain."); + NVF_ERROR( + std::all_of( + block_scaling_factor->getContiguity().begin(), + block_scaling_factor->getContiguity().end(), + [](std::optional c) { return c.value_or(true); }), + "Block scaling factor not contiguous"); IterDomain* grouped_id = nullptr; IterDomain* thread_x = nullptr; @@ -633,8 +667,7 @@ class ExprValidator : public OptOutDispatch { quantized_output->toString()); // Check if grouped_id in frontier - auto grouped_it = - std::ranges::find(frontier, grouped_id); + auto grouped_it = std::ranges::find(frontier, grouped_id); NVF_ERROR( grouped_it != frontier.end(), "All merge operations deriving the grouped ID must combine " From 521461c764dd295fbff15f9f70a30faae912c381 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 5 Nov 2025 20:45:10 -0800 Subject: [PATCH 65/79] Apply suggestions from code review Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- csrc/device_lower/validation.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index 1715523ee84..887d11c09c7 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -237,7 +237,7 @@ void traverseFrontierWithContiguityCheck( if (is_contig) { frontier[outer_pos] = merge->out(); } else { - frontier.erase(outer_it); + frontier.erase(frontier.begin() + outer_pos); } } else if (auto split = dynamic_cast(expr)) { auto in_it = std::ranges::find(frontier, split->in()); From 4510a77a870367f4aa465e0f4160a5b24c0118e5 Mon Sep 17 00:00:00 2001 From: protonu Date: Thu, 6 Nov 2025 06:02:55 -0800 Subject: [PATCH 66/79] allows bq kernel to take 2,4,8 elem per thread --- csrc/codegen.cpp | 13 +++--- csrc/device_lower/validation.cpp | 12 ++--- runtime/block_quantization_kernels.cu | 58 +++++++++---------------- tests/cpp/test_low_precision_recipe.cpp | 31 +++++++++---- 4 files changed, 57 insertions(+), 57 deletions(-) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 865436d499f..c4639d31309 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -1816,8 +1816,9 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { void handle(const BlockQuantizationOp* bqop) final { // This operator is plumbed down to a runtime function call. // One of the assumptions is that the device runtime expects - // 4 consecutive inputs (8 for FB16) per thread. We achieve this by having - // the input tv scheduler to have the inner dimension grouped by 4/8. + // n consecutive inputs per thread. Where n can be 2 or 4 for Float, and 2, + // 4, or 8 for Half. We achieve this by having the input tv scheduler to + // have the inner dimension grouped by 4/8. 
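    // Illustrative example (not emitted verbatim; tensor names are
    // placeholders): if the quantized output's innermost loop ID is
    // ParallelType::Group with extent 4 on a Float input, the runtime call
    // emitted by this function is instantiated roughly as
    //   bq::block_quantize_to_nvfp4<4>(T_in, T_quantized, T_block_scales);
    // so 16 / 4 = 4 threads cooperate on each 16-element scaling block.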
auto output = bqop->quantizedOutput()->as()->view(); int64_t group_size = 1; @@ -1838,15 +1839,15 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { if (input_dtype == DataType::BFloat16 || input_dtype == DataType::Half) { NVF_ERROR( - group_size == 8, - "Group size should be 8 for " + group_size == 8 || group_size == 4 || group_size == 2, + "Group size should be 2, 4 or 8 for " "BlockQuantizationOp: ", bqop->toString()); } else { NVF_ERROR( - group_size == 4, - "Group size should be 4 for " + group_size == 4 || group_size == 2, + "Group size should be 2 or 4 for " "BlockQuantizationOp: ", bqop->toString()); } diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index 887d11c09c7..1f3e0a8c84e 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -428,8 +428,7 @@ class ExprValidator : public OptOutDispatch { // This runtime function expects the inputs to be in local memory. The // quantized output will also be in local memory, but the block scaling // factors will be written out to global memory. The device runtime currently - // works on 4 elements per thread (8 for bf16/fp16) - this will be expanded - // later to support 2, 4, and 8 (only bf16/fp16) per thread. The runtime + // works on 2/4 elements per thread (also 8 for bf16/fp16). The runtime // function is based on a parallelization scheme that expects TIDx and BIDx, // and optionally TIDy and BIDy. 3D parallelization is not supported. Based // on the above, we have the following basic validation checks: @@ -437,7 +436,7 @@ class ExprValidator : public OptOutDispatch { // Input is in local memory. // Block scaling factor is in global memory and // quantized output is in local memory. - // The Group ID has an extent of 4/8 depending on the data + // The Group ID has an extent of 2/4/8 depending on the data // type. // There are no TIDz/BIDz IDs. We don't support 3D parallelization here. @@ -566,11 +565,12 @@ class ExprValidator : public OptOutDispatch { auto input_dtype = inp_tv->dtype(); NVF_ERROR( - (inner_extent == 4 && input_dtype == DataType::Float) || - (inner_extent == 8 && + ((inner_extent == 4 || inner_extent == 2) && + input_dtype == DataType::Float) || + ((inner_extent == 8 || inner_extent == 4 || inner_extent == 2) && (input_dtype == DataType::BFloat16 || input_dtype == DataType::Half)), - "The vectorized/grouped dimension must be 4 (FP32) or 8 " + "Thegroup dimension must be 2/4 (FP32) or 2/4/8 " "(BF16). Found: ", inner_extent, ". Expr: ", diff --git a/runtime/block_quantization_kernels.cu b/runtime/block_quantization_kernels.cu index 788372533de..59ef21c8e80 100644 --- a/runtime/block_quantization_kernels.cu +++ b/runtime/block_quantization_kernels.cu @@ -9,53 +9,34 @@ namespace nvf { namespace bq { -// This helper function is templatized of over types float, __half, and -// __bfloat. This assumes that for float, each thread was working on 4 elements. -// Thus 4 threads were working to find the max of 16 elements, and hence we need -// two steps to find the maximum. If the type is __bfloat or __half, then we -// only need a single step to find the maximum of 16 elements as each thread was -// working on 8 elements and 2 threads are required to compute the max of 16 -// elements. -// This function assumes for float each thread has already computed the max of 4 -// elements (8 elements for the other 2 data types) and the block size is 16, so -// we have 4 threads (2 for bf16/fp16) participating in the reduction. 
-// TODO: For FP32 support the cases where each thread works on 2 or 4 elements. -// TODO: For bf16/fp16 support the cases where each thread works on 2,4, or 8 -// elements. -template +template __device__ __inline__ void reduceAcrossThreads(float& per_thread_computed_max) { // The mask 0xffffffff indicates all 32 threads in the warp are participating. unsigned int mask = 0xffffffff; - // --- Reduction Step 1 --- - // Exchange and compare with thread 2 lanes away within the quad. - // e.g., thread 0 exchanges with 2; thread 1 with 3. - // The XOR pattern naturally keeps the operation within each quad. - if (std::is_same::value) { + // Perform reduction across threads in log2(NUM_ELEMENTS) stages + // The reduction happens by progressively halving the distance between + // threads that exchange values using XOR shuffle. + // For NUM_ELEMENTS=8 (e.g., ITEMS_PER_THREAD=2): 3 stages (XOR with 4, 2, 1) + // For NUM_ELEMENTS=4 (e.g., ITEMS_PER_THREAD=4): 2 stages (XOR with 2, 1) + // For NUM_ELEMENTS=2 (e.g., ITEMS_PER_THREAD=8): 1 stage (XOR with 1) +#pragma unroll + for (int offset = NUM_ELEMENTS / 2; offset > 0; offset /= 2) { per_thread_computed_max = fmax( per_thread_computed_max, - __shfl_xor_sync(mask, per_thread_computed_max, 2)); + __shfl_xor_sync(mask, per_thread_computed_max, offset)); } - // --- Reduction Step 2 --- - // Exchange and compare with thread 1 lane away. - // e.g., thread 0 exchanges with 1; thread 2 with 3. - per_thread_computed_max = fmax( - per_thread_computed_max, - __shfl_xor_sync(mask, per_thread_computed_max, 1)); - - // At this point, all threads in a quad hold the maximum value for that - // quad(pair of 2 threads). + // At this point, all threads involved hold the maximum value for the + // (quantization) block. } // A runtime function to compute quantized nvfp4 output (output) and fp8 block // scaling (block_scales) factors from fp32, fp16, bf16 inputs (input). // The function is templatized over input type T (float, __half, __bfloat). -// This function assumes that for float, each thread is working on 4 elements. -// Thus 4 threads are working to quantize 16 elements. If the type is __bfloat -// or -// __half, then 2 threads are working to quantize 16 elements as each thread -// is working on 8 elements. +// This function assumes that for float, each thread is working on 2, 4 or 8 +// elements (ITEMS_PER_THREAD). Thus n threads are working to quantize 16 +// elements, where n = 16 / ITEMS_PER_THREAD. template < int ITEMS_PER_THREAD, typename T, @@ -76,8 +57,10 @@ __device__ void block_quantize_to_nvfp4( "Input type must be float, __half or __bfloat"); static_assert( - (is_float && ITEMS_PER_THREAD == 4) || - (is_half_or_bfloat && ITEMS_PER_THREAD == 8), + (is_float && (ITEMS_PER_THREAD == 4 || ITEMS_PER_THREAD == 2)) || + (is_half_or_bfloat && + (ITEMS_PER_THREAD == 8 || ITEMS_PER_THREAD == 4 || + ITEMS_PER_THREAD == 2)), "ITEMS_PER_THREAD must be 4 for float type or 8 for __bfloat or __half " "type"); @@ -107,7 +90,8 @@ __device__ void block_quantize_to_nvfp4( // Compute the max accross 4 threads (float) or 2 threads (bf16/fp16) // This assumes each thread has already computed is local max of 4 (fp32) or // 8 (bf16/fp16) elements. 
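    // Illustrative trace (not part of the kernel): for NUM_ELEMENTS = 4 the
    // XOR butterfly above runs two stages. Lane numbers are relative to the
    // 4-lane group and the values are made up:
    //   offset = 2: lane0<->lane2, lane1<->lane3   {1, 7, 3, 5} -> {3, 7, 3, 7}
    //   offset = 1: lane0<->lane1, lane2<->lane3   {3, 7, 3, 7} -> {7, 7, 7, 7}
    // After log2(4) = 2 stages every lane in the group holds the group max.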
- reduceAcrossThreads(local_max); + constexpr int NUM_ELEMENTS = 16 / ITEMS_PER_THREAD; + reduceAcrossThreads(local_max); float block_max = local_max; // This division should be replaced with a multiplication diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index 727b20c56d1..212dd03e1ef 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -176,11 +176,13 @@ TEST_P(NVFP4QuantizeTest, WithoutPerTensorAmax) { HeuristicIs(SchedulerType::InnerPersistent))); } -class BlockQuantizationTest : public BlackwellBase, - public ::testing::WithParamInterface {}; +class BlockQuantizationTest + : public BlackwellBase, + public ::testing::WithParamInterface> {}; TEST_P(BlockQuantizationTest, ScheduleAsPointwise) { - auto data_hp_dtype = GetParam(); + auto data_hp_dtype = std::get<0>(GetParam()); + auto group_width = std::get<1>(GetParam()); // Baseline implementation std::unique_ptr fusion = std::make_unique(); @@ -221,7 +223,7 @@ TEST_P(BlockQuantizationTest, ScheduleAsPointwise) { fusion_new_op->addOutput(quantization_results.block_scales); fusion_new_op->addOutput(t_out); - auto vectorization_factor = data_hp_dtype == DataType::Float ? 4 : 8; + auto vectorization_factor = group_width; for (auto t : {tv_data_hp, @@ -289,7 +291,8 @@ TEST_P(BlockQuantizationTest, ScheduleAsPointwise) { } TEST_P(BlockQuantizationTest, ScheduleAsPointwise2D) { - auto data_hp_dtype = GetParam(); + auto data_hp_dtype = std::get<0>(GetParam()); + auto group_width = std::get<1>(GetParam()); // Baseline implementation std::unique_ptr fusion = std::make_unique(); @@ -335,7 +338,7 @@ TEST_P(BlockQuantizationTest, ScheduleAsPointwise2D) { t0->setMemoryType(MemoryType::Local); - auto vectorization_factor = data_hp_dtype == DataType::Float ? 4 : 8; + auto vectorization_factor = group_width; for (auto t : {tv_data_hp, @@ -836,7 +839,19 @@ INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P( , BlockQuantizationTest, - ::testing::Values(DataType::BFloat16, DataType::Float, DataType::Half), - testing::PrintToStringParamName()); + ::testing::Values( + std::make_tuple(DataType::Float, 2), + std::make_tuple(DataType::Float, 4), + std::make_tuple(DataType::BFloat16, 2), + std::make_tuple(DataType::BFloat16, 4), + std::make_tuple(DataType::BFloat16, 8), + std::make_tuple(DataType::Half, 2), + std::make_tuple(DataType::Half, 4), + std::make_tuple(DataType::Half, 8)), + [](const testing::TestParamInfo>& info) { + std::ostringstream os; + os << std::get<0>(info.param) << "_GroupWidth" << std::get<1>(info.param); + return os.str(); + }); } // namespace nvfuser From 9afefd150b2fa85ca05c07858c8f77a0334cb65e Mon Sep 17 00:00:00 2001 From: protonu Date: Thu, 6 Nov 2025 06:26:32 -0800 Subject: [PATCH 67/79] update comments --- csrc/device_lower/validation.cpp | 2 +- runtime/block_quantization_kernels.cu | 11 +++++++---- tests/cpp/test_low_precision_recipe.cpp | 5 +++-- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index 1f3e0a8c84e..889d3f0033c 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -570,7 +570,7 @@ class ExprValidator : public OptOutDispatch { ((inner_extent == 8 || inner_extent == 4 || inner_extent == 2) && (input_dtype == DataType::BFloat16 || input_dtype == DataType::Half)), - "Thegroup dimension must be 2/4 (FP32) or 2/4/8 " + "The group dimension must be 2/4 (FP32) or 2/4/8 " "(BF16). Found: ", inner_extent, ". 
Expr: ", diff --git a/runtime/block_quantization_kernels.cu b/runtime/block_quantization_kernels.cu index 59ef21c8e80..232c466e99a 100644 --- a/runtime/block_quantization_kernels.cu +++ b/runtime/block_quantization_kernels.cu @@ -9,6 +9,8 @@ namespace nvf { namespace bq { +// This helper function finds the max of NUM_ELEMENTS (2, 4, or 8) values +// using the same number of threads. template __device__ __inline__ void reduceAcrossThreads(float& per_thread_computed_max) { // The mask 0xffffffff indicates all 32 threads in the warp are participating. @@ -61,7 +63,8 @@ __device__ void block_quantize_to_nvfp4( (is_half_or_bfloat && (ITEMS_PER_THREAD == 8 || ITEMS_PER_THREAD == 4 || ITEMS_PER_THREAD == 2)), - "ITEMS_PER_THREAD must be 4 for float type or 8 for __bfloat or __half " + "ITEMS_PER_THREAD must be 2, 4 for float type or 2, 4, or 8 for __bfloat " + "or __half " "type"); // Number of threads involved in computing one block scaling factor @@ -87,9 +90,9 @@ __device__ void block_quantize_to_nvfp4( local_max = fmax(local_max, fabsf(vec_in[i])); } - // Compute the max accross 4 threads (float) or 2 threads (bf16/fp16) - // This assumes each thread has already computed is local max of 4 (fp32) or - // 8 (bf16/fp16) elements. + // Compute the max accross 16/ITEMS_PER_THREAD threads + // This assumes each thread has already computed is local max of 2, 4 (fp32) + // or 2,4, 8 (bf16/fp16) elements. constexpr int NUM_ELEMENTS = 16 / ITEMS_PER_THREAD; reduceAcrossThreads(local_max); float block_max = local_max; diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index 212dd03e1ef..2be12491a2c 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -237,7 +237,7 @@ TEST_P(BlockQuantizationTest, ScheduleAsPointwise) { t->merge(-2); } - // split by 4 (or 8). + // split by 4 (or 2, 8). // I -> I/4, 4 t->split(-1, vectorization_factor); // I//4, 4 -> I/4, 1, 4 @@ -346,7 +346,8 @@ TEST_P(BlockQuantizationTest, ScheduleAsPointwise2D) { quantization_results.quantized_tensor, quantization_results.block_scales, t_out}) { - // (m, n) -> (m, n/4, 4) (or (m, n/8, 8) if bfloat16) + // We split by 4 as an example, but can also be 2 or 8(fp16/bf16 on;y) + // (m, n) -> (m, n/4, 4) // (m, n/4, 4) -> (m, n/128, 32, 4) t->split(-1, vectorization_factor); // V t->split(-2, 32); // BDx From 6a6b46dee17b6796fd8e5d773fc707370062004f Mon Sep 17 00:00:00 2001 From: protonu Date: Thu, 6 Nov 2025 08:28:45 -0800 Subject: [PATCH 68/79] extend pointwise scheduler to accept block quantization op --- csrc/scheduler/pointwise.cpp | 50 ++++++++++++ csrc/scheduler/registry_utils.cpp | 15 ++++ csrc/scheduler/registry_utils.h | 4 + csrc/scheduler/tools/domain_map.cpp | 12 ++- csrc/scheduler/utils.cpp | 8 +- tests/cpp/test_low_precision_recipe.cpp | 101 ++++++++++++++++++++++++ 6 files changed, 188 insertions(+), 2 deletions(-) diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index 7ff7932211c..bc87ce691f4 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -524,6 +524,16 @@ std::unique_ptr getPointwiseHeuristics( divisible_split, vectorizable_inputs_outputs_entry.get()); + // Check if fusion has BlockQuantizationOp + auto fusion_has_block_quantization = + ir_utils::getOpsOfType(fusion).size() > 0; + + // Limit unroll factor for fusions with BlockQuantizationOp. 
The runtime + // function which implements quantization assumes no unrolling + if (fusion_has_block_quantization && unroll_factor > 1) { + unroll_factor = 1; + } + if (is_outer_broadcast_dominated) { params->unroll_factor_outer = unroll_factor; } else { @@ -827,6 +837,17 @@ bool PointWiseScheduler::canScheduleCompileTime(Fusion* fusion) { return false; } + // The block scales output of the Block Quantization Op + // should be a segment output as it is written to the global + // memory. + if (registry_utils::hasNonTerminalBlockQuantizeOp(fusion)) { + scheduler_debug_utils::canScheduleRejectReason( + schedulerType(), + "no support for block quantization where block scales is not a fusion " + "output"); + return false; + } + return true; } @@ -1219,6 +1240,19 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { unswitch_pos = 2; } + auto bq_ops = ir_utils::getOpsOfType(fusion); + std::vector nvfp4_quantized_outputs = {}; + for (auto bq_op : bq_ops) { + nvfp4_quantized_outputs.push_back( + bq_op->quantizedOutput()->as()); + } + + if (bq_ops.size() > 0 && pparams->vectorization_factor < 2) { + NVF_THROW( + "Unable to schedule BlockQuantization since we were not able to " + "vectorize the reference tensor"); + } + TransformPropagator propagator(reference_tv); MaxLogicalDomainInfoSpanningTree spanning_tree(reference_tv); spanning_tree.traverse(&propagator); @@ -1255,6 +1289,13 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { } } } + + // Vectorize nvfp4 quantized outputs. + // We will later change the vectorized ID to group ID + for (auto quantized_output : nvfp4_quantized_outputs) { + vectorized_tvs.emplace_back(quantized_output); + } + if (!vectorized_tvs.empty()) { // Aggressively mark with vectorized and cleanup later. That way we // don't have to manually specify parallelization outside the reference. @@ -1267,6 +1308,15 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { } } + // Change vectorized IDs to group IDs for quantized outputs + for (auto quantized_output : nvfp4_quantized_outputs) { + for (auto id : quantized_output->getLoopDomain()) { + if (id->getParallelType() == ParallelType::Vectorize) { + id->parallelize(ParallelType::Group); + } + } + } + // Begin by inlining at the unswitch position for the entire DAG. The cached // inputs, and outputs will keep this inline position, but other tensors will // get a higher position in later inline propagation. We need this separate diff --git a/csrc/scheduler/registry_utils.cpp b/csrc/scheduler/registry_utils.cpp index a82500e7411..2bb309695df 100644 --- a/csrc/scheduler/registry_utils.cpp +++ b/csrc/scheduler/registry_utils.cpp @@ -840,6 +840,21 @@ bool SchedulerTopologyChecker::hasNonNormalizePostReductionBCast( return false; } +// Returns true if the output of the block quantization op +// is not the fusion/segment output. 
+bool hasNonTerminalBlockQuantizeOp(Fusion* fusion) { + for (auto expr : fusion->exprs()) { + if (expr->isA()) { + auto block_scales = + expr->as()->blockScales()->as(); + if (!block_scales->isFusionOutput()) { + return true; + } + } + } + return false; +} + // Checks if any broadcasts are resolved after a reduction, this shouldn't be // accepted in the single reduction or multi-reduction scheduler bool SchedulerTopologyChecker::hasPostReductionBCast(Fusion* fusion) { diff --git a/csrc/scheduler/registry_utils.h b/csrc/scheduler/registry_utils.h index 968047b5c7e..f3d5dcf711c 100644 --- a/csrc/scheduler/registry_utils.h +++ b/csrc/scheduler/registry_utils.h @@ -72,6 +72,10 @@ PrimDataType getIndexTypeOfKernel( const KernelArgumentHolder& inputs, ExpressionEvaluator& ee); +// Check if the block scales output of Block Quantization Op +// is a segment output. +bool hasNonTerminalBlockQuantizeOp(Fusion* fusion); + class SchedulerTopologyChecker { public: // Checks if any broadcasts are resolved after a reduction that don't follow diff --git a/csrc/scheduler/tools/domain_map.cpp b/csrc/scheduler/tools/domain_map.cpp index c36b1a6de22..7610a8e16b6 100644 --- a/csrc/scheduler/tools/domain_map.cpp +++ b/csrc/scheduler/tools/domain_map.cpp @@ -409,7 +409,17 @@ bool DomainMap::isValidReference(TensorView* tv, bool check_inputs) const { for (auto output_tv : ir_utils::filterByType(fusion_->outputs())) { // no need to check for self. - if (output_tv == tv) { + // If this is the block scaling factor ouptut of a BlockQuantizationOp, + // then we skip the check as we only consider the quantized ouptut of the + // BlockQuantizationOp when looking for a reference tensor. This is because + // the two outputs of block quantization op are not symmtetrcal and the + // logical domains of the scaling factor is not completely mapped. + if (output_tv == tv || + (output_tv->definition()->isA() && + output_tv == + output_tv->definition() + ->as() + ->blockScales())) { continue; } if (!areAllTargetIdsCoveredBy(output_tv, tv)) { diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index ac422bc252d..2f0212ed745 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -1380,8 +1380,14 @@ std::vector> cacheAndForkOutputs( // the output of ScatterOp must on the global memory due to the random // or atomic access. Similarly, PreprocessGroupedMatmulInputSf requires // direct write to global memory because of random access. + // The output of block quantization has to be in global memory. This is + // because this op is implemented via a runtime function that write the + // scaling factors to global memory. 
output->definition() - ->isOneOf()) { + ->isOneOf() || + (output->definition()->isA() && + output->definition()->as()->blockScales() == + output)) { continue; } if (!output->uses().empty()) { diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index 2be12491a2c..b483d67f21f 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -689,6 +689,87 @@ TEST_F(BlockQuantizationValidationTest, MergesMustBeContiguous) { "IDs from the logical domain for BlockQuantizationOp"))); } +struct BlockQuantizationSchedulingTestParams { + DataType data_type; + int m; + int n; +}; + +class BlockQuantizationSchedulingTest + : public BlackwellBase, + public ::testing::WithParamInterface< + BlockQuantizationSchedulingTestParams> {}; + +TEST_P(BlockQuantizationSchedulingTest, AutoScheduleSingleOp) { + auto params = GetParam(); + auto data_type = params.data_type; + const int m = params.m; + const int n = params.n; + + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + createNVFP4QunatizationFusion(fusion.get(), data_type); + + FusionExecutorCache fec(std::move(fusion)); + + std::vector inputs; + inputs.push_back(at::randn({m, n}, at::device(at::kCUDA).dtype(at::kFloat)) + .to(data_type_to_aten(data_type))); + auto outputs_baseline = fec.runFusionWithInputs(inputs); + + auto baseline_block_scales = outputs_baseline[0].as(); + auto baseline_quantized_tensor = outputs_baseline[1].as(); + + auto baseline_block_scales_cpu = baseline_block_scales.cpu(); + auto baseline_quantized_tensor_cpu = baseline_quantized_tensor.cpu(); + + const uint8_t* baseline_block_scales_data = + static_cast(baseline_block_scales_cpu.data_ptr()); + const uint8_t* baseline_quantized_data = + static_cast(baseline_quantized_tensor_cpu.data_ptr()); + + std::unique_ptr fusion_new_op = std::make_unique(); + FusionGuard fg2(fusion_new_op.get()); + + auto tv_in_1 = makeContigTensor(2, data_type); + fusion_new_op->addInput(tv_in_1); + + auto quantization_results = blockQuantize(tv_in_1); + + fusion_new_op->addOutput(quantization_results.block_scales); + fusion_new_op->addOutput(quantization_results.quantized_tensor); + + FusionExecutorCache executor_cache(std::move(fusion_new_op)); + auto outputs_new_op = executor_cache.runFusionWithInputs(inputs); + + // Verify we got the expected outputs + auto block_scales_output = outputs_new_op[0].as(); + auto quantized_tensor_output = outputs_new_op[1].as(); + + // Move tensors from GPU to CPU + auto block_scales_cpu = block_scales_output.cpu(); + auto quantized_tensor_cpu = quantized_tensor_output.cpu(); + + auto block_scales_bytes = (m * n) / block_size; + auto quantized_tensor_bytes = (m * n) / 2; + + const uint8_t* block_scales_data = + static_cast(block_scales_cpu.data_ptr()); + for (int i = 0; i < block_scales_bytes; ++i) { + EXPECT_EQ( + block_scales_data[i], + baseline_block_scales_data[i]); // Compare with baseline + } + + const uint8_t* quantized_data = + static_cast(quantized_tensor_cpu.data_ptr()); + for (int i = 0; i < quantized_tensor_bytes; ++i) { + EXPECT_EQ( + quantized_data[i], + baseline_quantized_data[i]); // Compare with baseline + } +} + TEST_P(NVFP4QuantizeTest, SwizzledOuputAndWithoutPerTensorAmax) { auto data_hp_dtype = GetParam(); @@ -855,4 +936,24 @@ INSTANTIATE_TEST_SUITE_P( return os.str(); }); +INSTANTIATE_TEST_SUITE_P( + , + BlockQuantizationSchedulingTest, + ::testing::Values( + BlockQuantizationSchedulingTestParams{DataType::Float, 1024, 1024}, + 
BlockQuantizationSchedulingTestParams{DataType::Float, 128, 64}, + BlockQuantizationSchedulingTestParams{DataType::Float, 2048, 128}, + BlockQuantizationSchedulingTestParams{DataType::Float, 2048, 2048}, + BlockQuantizationSchedulingTestParams{DataType::BFloat16, 1024, 1024}, + BlockQuantizationSchedulingTestParams{DataType::BFloat16, 128, 64}, + BlockQuantizationSchedulingTestParams{DataType::BFloat16, 2048, 128}, + BlockQuantizationSchedulingTestParams{DataType::BFloat16, 2048, 2048}), + [](const testing::TestParamInfo& + info) { + std::ostringstream name; + name << info.param.data_type << "_" << info.param.m << "x" + << info.param.n; + return name.str(); + }); + } // namespace nvfuser From 5cb3e50336c2454ad55584654dbda011c3af0fb2 Mon Sep 17 00:00:00 2001 From: protonu Date: Thu, 6 Nov 2025 12:26:18 -0800 Subject: [PATCH 69/79] edit comments --- csrc/codegen.cpp | 4 ++-- runtime/block_quantization_kernels.cu | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index c4639d31309..a27fd3dc0b1 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -1817,8 +1817,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { // This operator is plumbed down to a runtime function call. // One of the assumptions is that the device runtime expects // n consecutive inputs per thread. Where n can be 2 or 4 for Float, and 2, - // 4, or 8 for Half. We achieve this by having the input tv scheduler to - // have the inner dimension grouped by 4/8. + // 4, or 8 for Half. We achieve this by having the quantized output tv + // scheduled to have the inner dimension grouped by 2/4/8. auto output = bqop->quantizedOutput()->as()->view(); int64_t group_size = 1; diff --git a/runtime/block_quantization_kernels.cu b/runtime/block_quantization_kernels.cu index 232c466e99a..d941fa82e05 100644 --- a/runtime/block_quantization_kernels.cu +++ b/runtime/block_quantization_kernels.cu @@ -36,7 +36,7 @@ __device__ __inline__ void reduceAcrossThreads(float& per_thread_computed_max) { // A runtime function to compute quantized nvfp4 output (output) and fp8 block // scaling (block_scales) factors from fp32, fp16, bf16 inputs (input). // The function is templatized over input type T (float, __half, __bfloat). -// This function assumes that for float, each thread is working on 2, 4 or 8 +// This function assumes that for float, each thread is working on 2,4 or 8 // elements (ITEMS_PER_THREAD). Thus n threads are working to quantize 16 // elements, where n = 16 / ITEMS_PER_THREAD. template < From 63cf7b923fb1f0e562377d7b90eadf76a7439d6d Mon Sep 17 00:00:00 2001 From: protonu Date: Thu, 6 Nov 2025 13:04:05 -0800 Subject: [PATCH 70/79] address comments from greptile --- csrc/scheduler/tools/domain_map.cpp | 7 ++++--- tests/cpp/test_low_precision_recipe.cpp | 10 +++++----- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/csrc/scheduler/tools/domain_map.cpp b/csrc/scheduler/tools/domain_map.cpp index 7610a8e16b6..73e18d55583 100644 --- a/csrc/scheduler/tools/domain_map.cpp +++ b/csrc/scheduler/tools/domain_map.cpp @@ -410,12 +410,13 @@ bool DomainMap::isValidReference(TensorView* tv, bool check_inputs) const { ir_utils::filterByType(fusion_->outputs())) { // no need to check for self. // If this is the block scaling factor ouptut of a BlockQuantizationOp, - // then we skip the check as we only consider the quantized ouptut of the + // then we skip the check as we only consider the quantized output of the // BlockQuantizationOp when looking for a reference tensor. 
This is because - // the two outputs of block quantization op are not symmtetrcal and the + // the two outputs of block quantization op are not symmetrical and the // logical domains of the scaling factor is not completely mapped. if (output_tv == tv || - (output_tv->definition()->isA() && + (output_tv->definition() && + output_tv->definition()->isA() && output_tv == output_tv->definition() ->as() diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index b483d67f21f..e5477b5cfb1 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -108,7 +108,7 @@ constexpr double F8E4M3_MAX = 448.0; class NVFP4QuantizeTest : public BlackwellBase, public ::testing::WithParamInterface {}; namespace { -void createNVFP4QunatizationFusion(Fusion* fusion, DataType data_hp_dtype) { +void createNVFP4QuantizationFusion(Fusion* fusion, DataType data_hp_dtype) { auto tv_data_hp = makeContigTensor(2, data_hp_dtype); fusion->addInput(tv_data_hp); @@ -155,7 +155,7 @@ TEST_P(NVFP4QuantizeTest, WithoutPerTensorAmax) { std::unique_ptr fusion = std::make_unique(); FusionGuard fg(fusion.get()); - createNVFP4QunatizationFusion(fusion.get(), data_hp_dtype); + createNVFP4QuantizationFusion(fusion.get(), data_hp_dtype); FusionExecutorCache fec(std::move(fusion)); @@ -187,7 +187,7 @@ TEST_P(BlockQuantizationTest, ScheduleAsPointwise) { // Baseline implementation std::unique_ptr fusion = std::make_unique(); FusionGuard fg(fusion.get()); - createNVFP4QunatizationFusion(fusion.get(), data_hp_dtype); + createNVFP4QuantizationFusion(fusion.get(), data_hp_dtype); FusionExecutorCache fec(std::move(fusion)); @@ -297,7 +297,7 @@ TEST_P(BlockQuantizationTest, ScheduleAsPointwise2D) { // Baseline implementation std::unique_ptr fusion = std::make_unique(); FusionGuard fg(fusion.get()); - createNVFP4QunatizationFusion(fusion.get(), data_hp_dtype); + createNVFP4QuantizationFusion(fusion.get(), data_hp_dtype); FusionExecutorCache fec(std::move(fusion)); @@ -708,7 +708,7 @@ TEST_P(BlockQuantizationSchedulingTest, AutoScheduleSingleOp) { std::unique_ptr fusion = std::make_unique(); FusionGuard fg(fusion.get()); - createNVFP4QunatizationFusion(fusion.get(), data_type); + createNVFP4QuantizationFusion(fusion.get(), data_type); FusionExecutorCache fec(std::move(fusion)); From 4ca4d8ea18608e38ced2c1e911dc77436bc6993e Mon Sep 17 00:00:00 2001 From: protonu Date: Tue, 11 Nov 2025 09:28:48 -0800 Subject: [PATCH 71/79] address reviewer comment - move codea around --- csrc/scheduler/pointwise.cpp | 41 +++++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index bc87ce691f4..fabd7ab2f45 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -508,6 +508,19 @@ std::unique_ptr getPointwiseHeuristics( logical_reorder_map)); params->vectorization_factor = vectorization_factor; + // Check if fusion has BlockQuantizationOp + auto fusion_has_block_quantization = + ir_utils::getOpsOfType(fusion).size() > 0; + + // The runtime function implementing block quantization op needs + // at least 2 elements per thread to work. 
+ if (fusion_has_block_quantization && params->vectorization_factor < 2) { + NVF_THROW( + "Unable to schedule a fusion iwth BlockQuantization since we were not " + "able to " + "vectorize by at least a factor of 2 "); + } + // get unroll factor: int64_t total_blocks = break_point > 0 @@ -524,10 +537,6 @@ std::unique_ptr getPointwiseHeuristics( divisible_split, vectorizable_inputs_outputs_entry.get()); - // Check if fusion has BlockQuantizationOp - auto fusion_has_block_quantization = - ir_utils::getOpsOfType(fusion).size() > 0; - // Limit unroll factor for fusions with BlockQuantizationOp. The runtime // function which implements quantization assumes no unrolling if (fusion_has_block_quantization && unroll_factor > 1) { @@ -1240,6 +1249,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { unswitch_pos = 2; } + // We first vectorize the the quantized outputs of the block quantization ops. + // We then convert the vectorized ID to group ID. + // We do so as the runtime function for block quantization expects 2/4/8 + // elements per thread. auto bq_ops = ir_utils::getOpsOfType(fusion); std::vector nvfp4_quantized_outputs = {}; for (auto bq_op : bq_ops) { @@ -1247,12 +1260,6 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { bq_op->quantizedOutput()->as()); } - if (bq_ops.size() > 0 && pparams->vectorization_factor < 2) { - NVF_THROW( - "Unable to schedule BlockQuantization since we were not able to " - "vectorize the reference tensor"); - } - TransformPropagator propagator(reference_tv); MaxLogicalDomainInfoSpanningTree spanning_tree(reference_tv); spanning_tree.traverse(&propagator); @@ -1305,14 +1312,14 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { if (!should_vectorize_reference_tv) { vectorize_id->parallelize(ParallelType::Serial); } - } - } - // Change vectorized IDs to group IDs for quantized outputs - for (auto quantized_output : nvfp4_quantized_outputs) { - for (auto id : quantized_output->getLoopDomain()) { - if (id->getParallelType() == ParallelType::Vectorize) { - id->parallelize(ParallelType::Group); + // Change vectorized IDs to group IDs for quantized outputs + for (auto quantized_output : nvfp4_quantized_outputs) { + for (auto id : quantized_output->getLoopDomain()) { + if (id->getParallelType() == ParallelType::Vectorize) { + id->parallelize(ParallelType::Group); + } + } } } } From 031c28006cbfdbcf21f25b73f7a817e4037b478a Mon Sep 17 00:00:00 2001 From: Protonu Date: Tue, 11 Nov 2025 12:43:25 -0500 Subject: [PATCH 72/79] Update csrc/scheduler/pointwise.cpp Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- csrc/scheduler/pointwise.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index fabd7ab2f45..839c82753e3 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -516,7 +516,7 @@ std::unique_ptr getPointwiseHeuristics( // at least 2 elements per thread to work. 
if (fusion_has_block_quantization && params->vectorization_factor < 2) { NVF_THROW( - "Unable to schedule a fusion iwth BlockQuantization since we were not " + "Unable to schedule a fusion with BlockQuantization since we were not " "able to " "vectorize by at least a factor of 2 "); } From 479691be8b334cc6cc4982d83a15f5b3bca966b7 Mon Sep 17 00:00:00 2001 From: protonu Date: Wed, 12 Nov 2025 09:22:33 -0800 Subject: [PATCH 73/79] correct typo --- csrc/scheduler/tools/domain_map.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/scheduler/tools/domain_map.cpp b/csrc/scheduler/tools/domain_map.cpp index 73e18d55583..55abe2d0914 100644 --- a/csrc/scheduler/tools/domain_map.cpp +++ b/csrc/scheduler/tools/domain_map.cpp @@ -409,7 +409,7 @@ bool DomainMap::isValidReference(TensorView* tv, bool check_inputs) const { for (auto output_tv : ir_utils::filterByType(fusion_->outputs())) { // no need to check for self. - // If this is the block scaling factor ouptut of a BlockQuantizationOp, + // If this is the block scaling factor output of a BlockQuantizationOp, // then we skip the check as we only consider the quantized output of the // BlockQuantizationOp when looking for a reference tensor. This is because // the two outputs of block quantization op are not symmetrical and the From 6a5a553c99feca886b998781a3ee0e865752f9da Mon Sep 17 00:00:00 2001 From: protonu Date: Thu, 13 Nov 2025 12:46:43 -0800 Subject: [PATCH 74/79] move checks to canScheduleRunTime --- csrc/scheduler/pointwise.cpp | 53 +++++++++++++++++++------ csrc/scheduler/pointwise.h | 4 +- tests/cpp/test_low_precision_recipe.cpp | 53 +++++++++++++++++++++++++ 3 files changed, 94 insertions(+), 16 deletions(-) diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index 839c82753e3..3d88550b456 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -508,19 +508,6 @@ std::unique_ptr getPointwiseHeuristics( logical_reorder_map)); params->vectorization_factor = vectorization_factor; - // Check if fusion has BlockQuantizationOp - auto fusion_has_block_quantization = - ir_utils::getOpsOfType(fusion).size() > 0; - - // The runtime function implementing block quantization op needs - // at least 2 elements per thread to work. - if (fusion_has_block_quantization && params->vectorization_factor < 2) { - NVF_THROW( - "Unable to schedule a fusion with BlockQuantization since we were not " - "able to " - "vectorize by at least a factor of 2 "); - } - // get unroll factor: int64_t total_blocks = break_point > 0 @@ -539,6 +526,10 @@ std::unique_ptr getPointwiseHeuristics( // Limit unroll factor for fusions with BlockQuantizationOp. The runtime // function which implements quantization assumes no unrolling + // Check if fusion has BlockQuantizationOp + auto fusion_has_block_quantization = + ir_utils::getOpsOfType(fusion).size() > 0; + if (fusion_has_block_quantization && unroll_factor > 1) { unroll_factor = 1; } @@ -860,6 +851,42 @@ bool PointWiseScheduler::canScheduleCompileTime(Fusion* fusion) { return true; } +bool PointWiseScheduler::canScheduleRunTime( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicDataCache* data_cache) { + FUSER_PERF_SCOPE("PointWiseScheduler::canScheduleRunTime"); + // Check if the fusion has a Block Quantization Op + // If so, ensure that the vectorization factor is at least 2 + // and that the grid y dimension is not split. + // These are requirements of the current implementation of the + // Block Quantization Op runtime function. 
+ auto has_block_quant_op = + !ir_utils::getOpsOfType(fusion).empty(); + if (has_block_quant_op) { + auto pparams = getPointwiseHeuristics(fusion, runtime_info, data_cache); + NVF_ERROR(pparams != nullptr); + if (pparams->vectorization_factor < 2) { + scheduler_debug_utils::canScheduleRejectReason( + schedulerType(), + "Block Quantization Op requires vectorization factor to be at least " + "2."); + return false; + } + + if (pparams->split_grid_y_dim) { + scheduler_debug_utils::canScheduleRejectReason( + schedulerType(), + "Block Quantization Op is not supported when splitting grid y " + "dimension. This is because this will create a serial ID with an " + "extent > 1. The runtime function implementing block quantization " + "will currently not be able to handle that."); + return false; + } + } + return true; +} + // TODO: Inline intermediate operations (avoid inlining unrolled/vectorized // input/output caches) void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { diff --git a/csrc/scheduler/pointwise.h b/csrc/scheduler/pointwise.h index 9326689f4eb..daba3c464ca 100644 --- a/csrc/scheduler/pointwise.h +++ b/csrc/scheduler/pointwise.h @@ -164,9 +164,7 @@ class PointWiseScheduler : public SchedulerEntry { bool canScheduleRunTime( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, - HeuristicDataCache* data_cache = nullptr) override { - return true; - } + HeuristicDataCache* data_cache = nullptr) override; std::unique_ptr computeHeuristics( Fusion* fusion, diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index e5477b5cfb1..c8e05104ef3 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -770,6 +770,59 @@ TEST_P(BlockQuantizationSchedulingTest, AutoScheduleSingleOp) { } } +class BlockQuantizationCanScheduleTests : public BlackwellBase {}; + +TEST_F( + BlockQuantizationCanScheduleTests, + CanRuntimeScheduleFailFromNoVectorization) { + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv_data_hp = makeContigTensor(2, DataType::Float); + fusion->addInput(tv_data_hp); + + auto t0 = set(tv_data_hp); + auto quantization_results = blockQuantize(t0); + auto t_out = set(quantization_results.quantized_tensor); + + fusion->addOutput(quantization_results.block_scales); + fusion->addOutput(t_out); + + // Create misaligned tensor directly on GPU using custom CUDA allocation + size_t element_size = 4; + int m = 1024; + int n = 1024; + + size_t total_elements = m * n; + size_t buffer_size = + total_elements * element_size + 16; // Extra bytes for misalignment + + // Allocate GPU memory with extra space + void* gpu_ptr; + cudaMalloc(&gpu_ptr, buffer_size); + + // Create tensor from GPU memory at offset of 4 bytes + void* misaligned_ptr = static_cast(gpu_ptr) + 4; + auto misaligned_gpu_tensor = at::from_blob( + misaligned_ptr, + {m, n}, + at::TensorOptions() + .dtype(data_type_to_aten(DataType::Float)) + .device(at::kCUDA)); + + auto good_input = at::randn({m, n}, at::device(at::kCUDA).dtype(at::kFloat)); + + // Expect failure as the input tensor can't be vectorized + // and we need vectorization > 2 + SchedulerRuntimeInfo runtime_info(fusion.get(), {misaligned_gpu_tensor}); + ASSERT_FALSE(Schedule::canSchedule( + SchedulerType::PointWise, fusion.get(), runtime_info)); + + SchedulerRuntimeInfo runtime_info_new(fusion.get(), {good_input}); + ASSERT_TRUE(Schedule::canSchedule( + SchedulerType::PointWise, fusion.get(), runtime_info_new)); +} + 
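The 4-byte offset in the test above is what defeats vectorization: loading four floats at once requires a 16-byte-aligned address, and an address 4 bytes past a cudaMalloc result can never satisfy that, so the heuristic is left with a vectorization factor below 2 and canScheduleRunTime rejects the fusion. A small host-side sketch of that alignment condition (a hypothetical helper, not part of the test):

    #include <cstdint>

    // True only when ptr could legally be accessed as float4 (4 x 4 B = 16 B).
    inline bool isFloat4Aligned(const void* ptr) {
      return reinterpret_cast<std::uintptr_t>(ptr) % (4 * sizeof(float)) == 0;
    }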
TEST_P(NVFP4QuantizeTest, SwizzledOuputAndWithoutPerTensorAmax) { auto data_hp_dtype = GetParam(); From 15140228016ff4d6f503ae2cdbb984703f01aa4d Mon Sep 17 00:00:00 2001 From: Protonu Date: Thu, 13 Nov 2025 15:54:37 -0500 Subject: [PATCH 75/79] Apply suggestion from @greptile-apps[bot] Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- csrc/scheduler/pointwise.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index 3d88550b456..7c5b1c8764c 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -1276,7 +1276,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { unswitch_pos = 2; } - // We first vectorize the the quantized outputs of the block quantization ops. + // We first vectorize the quantized outputs of the block quantization ops. // We then convert the vectorized ID to group ID. // We do so as the runtime function for block quantization expects 2/4/8 // elements per thread. From 1e92c4d5047a4fe0b71a0846cff5e19cdbd237a4 Mon Sep 17 00:00:00 2001 From: protonu Date: Fri, 14 Nov 2025 10:16:39 -0800 Subject: [PATCH 76/79] cache check for BQ ops --- csrc/scheduler/compile_time_info.h | 13 ++++++++++++- csrc/scheduler/pointwise.cpp | 26 ++++++++++++++++++++------ csrc/scheduler/registry.cpp | 2 ++ 3 files changed, 34 insertions(+), 7 deletions(-) diff --git a/csrc/scheduler/compile_time_info.h b/csrc/scheduler/compile_time_info.h index fdb17478dda..2a60f253bb7 100644 --- a/csrc/scheduler/compile_time_info.h +++ b/csrc/scheduler/compile_time_info.h @@ -48,7 +48,8 @@ enum class CompileTimeEntryType { CAN_SCHEDULE_TRANSPOSE, CAN_SCHEDULE_MUL_SUM_AS_MMA, LOGICAL_REORDER_MAP, - VECTORIZATION_BREAK_POINT_OF_RED_PROD + VECTORIZATION_BREAK_POINT_OF_RED_PROD, + HAS_BLOCK_QUANTIZATION_OPS }; //! Entry type definition class for `DOMAIN_MAP`, @@ -142,6 +143,16 @@ class ReductionTVs { CompileTimeEntryType::REDUCTION_TVS; }; +//! Entry type definition class for `HAS_BLOCK_QUANTIZATION_OPS`, +//! stores a boolean flag indicating whether the fusion contains any +//! BlockQuantizationOp operations. +class HasBlockQuantizationOps { + public: + using DataType = bool; + static const CompileTimeEntryType EntryType = + CompileTimeEntryType::HAS_BLOCK_QUANTIZATION_OPS; +}; + //! Entry type definition class for `PERSISTENT_BUFFER_INFO`, //! stores persistent buffers inferred from topology and scheduling of fusion. class PersistentBufferInfo { diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index 7c5b1c8764c..54a438c0326 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -527,10 +527,16 @@ std::unique_ptr getPointwiseHeuristics( // Limit unroll factor for fusions with BlockQuantizationOp. The runtime // function which implements quantization assumes no unrolling // Check if fusion has BlockQuantizationOp - auto fusion_has_block_quantization = - ir_utils::getOpsOfType(fusion).size() > 0; + auto has_block_quantization_ops = + HeuristicDataCacheEntry( + data_cache, + [fusion]() { + return std::make_unique( + !ir_utils::getOpsOfType(fusion).empty()); + }) + .get(); - if (fusion_has_block_quantization && unroll_factor > 1) { + if (has_block_quantization_ops && unroll_factor > 1) { unroll_factor = 1; } @@ -861,9 +867,17 @@ bool PointWiseScheduler::canScheduleRunTime( // and that the grid y dimension is not split. 
// These are requirements of the current implementation of the // Block Quantization Op runtime function. - auto has_block_quant_op = - !ir_utils::getOpsOfType(fusion).empty(); - if (has_block_quant_op) { + + auto has_block_quantization_ops = + HeuristicDataCacheEntry( + data_cache, + [fusion]() { + return std::make_unique( + !ir_utils::getOpsOfType(fusion).empty()); + }) + .get(); + + if (has_block_quantization_ops) { auto pparams = getPointwiseHeuristics(fusion, runtime_info, data_cache); NVF_ERROR(pparams != nullptr); if (pparams->vectorization_factor < 2) { diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp index 01f15de0554..e16c23d7814 100644 --- a/csrc/scheduler/registry.cpp +++ b/csrc/scheduler/registry.cpp @@ -280,6 +280,8 @@ template class HeuristicDataCacheEntry< template class HeuristicDataCacheEntry< HeuristicCompileTime::UnrollableInputsAndOutputs>; template class HeuristicDataCacheEntry; +template class HeuristicDataCacheEntry< + HeuristicCompileTime::HasBlockQuantizationOps>; template class HeuristicDataCacheEntry< HeuristicCompileTime::PersistentBufferInfo>; template class HeuristicDataCacheEntry< From 7d7458317f192c041c390163b5d011080d8795a1 Mon Sep 17 00:00:00 2001 From: protonu Date: Fri, 14 Nov 2025 10:25:47 -0800 Subject: [PATCH 77/79] move unroll factor computation bypass --- csrc/scheduler/pointwise.cpp | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index 54a438c0326..dbe2e5ce96c 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -141,7 +141,11 @@ int64_t getUnrollFactor( int64_t total_blocks, int64_t vectorization_bits, bool divisible_split, - std::vector vectorizable_io_tvs) { + std::vector vectorizable_io_tvs, + bool has_block_quantization_ops) { + if (has_block_quantization_ops) { + return 1; + } // only consider vectorizable inputs, // needs to check if it's already in the list to avoid duplication since a tv // may be both input and output, e.g. NVFuserTest.FusionIssue2372_CUDA @@ -516,17 +520,10 @@ std::unique_ptr getPointwiseHeuristics( bool divisible_split = break_point > 0 ? (right_elem_count % (params->vectorization_factor * bdimx) == 0) : (n_elems % (params->vectorization_factor * kThreadX) == 0); - int64_t unroll_factor = getUnrollFactor( - fusion, - break_point, - total_blocks, - params->vectorization_factor * max_dtype_size_bit_for_vectorization, - divisible_split, - vectorizable_inputs_outputs_entry.get()); + // Check if fusion has BlockQuantizationOp // Limit unroll factor for fusions with BlockQuantizationOp. 
The runtime // function which implements quantization assumes no unrolling - // Check if fusion has BlockQuantizationOp auto has_block_quantization_ops = HeuristicDataCacheEntry( data_cache, @@ -536,6 +533,15 @@ std::unique_ptr getPointwiseHeuristics( }) .get(); + int64_t unroll_factor = getUnrollFactor( + fusion, + break_point, + total_blocks, + params->vectorization_factor * max_dtype_size_bit_for_vectorization, + divisible_split, + vectorizable_inputs_outputs_entry.get(), + has_block_quantization_ops); + if (has_block_quantization_ops && unroll_factor > 1) { unroll_factor = 1; } From 4a0d7086eb129cde2aadae5e044bf352fba797ac Mon Sep 17 00:00:00 2001 From: protonu Date: Fri, 14 Nov 2025 10:29:04 -0800 Subject: [PATCH 78/79] cleanup redundant code --- csrc/scheduler/pointwise.cpp | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index dbe2e5ce96c..afd211296e1 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -521,8 +521,8 @@ std::unique_ptr getPointwiseHeuristics( ? (right_elem_count % (params->vectorization_factor * bdimx) == 0) : (n_elems % (params->vectorization_factor * kThreadX) == 0); - // Check if fusion has BlockQuantizationOp - // Limit unroll factor for fusions with BlockQuantizationOp. The runtime + // Check if fusion has BlockQuantizationOp(s) + // Limit unroll factor for fusions with BlockQuantizationOp(s). The runtime // function which implements quantization assumes no unrolling auto has_block_quantization_ops = HeuristicDataCacheEntry( @@ -542,10 +542,6 @@ std::unique_ptr getPointwiseHeuristics( vectorizable_inputs_outputs_entry.get(), has_block_quantization_ops); - if (has_block_quantization_ops && unroll_factor > 1) { - unroll_factor = 1; - } - if (is_outer_broadcast_dominated) { params->unroll_factor_outer = unroll_factor; } else { From 6f31bd4689b389e810b3522f38f5fc9348bd77fb Mon Sep 17 00:00:00 2001 From: protonu Date: Fri, 14 Nov 2025 11:57:49 -0800 Subject: [PATCH 79/79] move data_cache access to getUnroll --- csrc/scheduler/pointwise_non_tma.cpp | 30 ++++++++++++++-------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/csrc/scheduler/pointwise_non_tma.cpp b/csrc/scheduler/pointwise_non_tma.cpp index 10cfe257ce3..c1e45ec94b6 100644 --- a/csrc/scheduler/pointwise_non_tma.cpp +++ b/csrc/scheduler/pointwise_non_tma.cpp @@ -138,12 +138,25 @@ int64_t getUnrollFactor( int64_t vectorization_bits, bool divisible_split, std::vector vectorizable_io_tvs, - bool has_block_quantization_ops) { + HeuristicDataCache* data_cache) { + // Check if fusion has BlockQuantizationOp(s) + // Limit unroll factor for fusions with BlockQuantizationOp(s). The runtime + // function which implements quantization assumes no unrolling + auto has_block_quantization_ops = + HeuristicDataCacheEntry( + data_cache, + [fusion]() { + return std::make_unique( + !ir_utils::getOpsOfType(fusion).empty()); + }) + .get(); + if (has_block_quantization_ops) { // Runtime function implementing Block Quantization Op requires unroll // factor to be 1 return 1; } + // only consider vectorizable inputs, // needs to check if it's already in the list to avoid duplication since a tv // may be both input and output, e.g. NVFuserTest.FusionIssue2372_CUDA @@ -518,19 +531,6 @@ std::unique_ptr getPointwiseHeuristics( bool divisible_split = break_point > 0 ? 
(right_elem_count % (params->vectorization_factor * bdimx) == 0) : (n_elems % (params->vectorization_factor * kThreadX) == 0); - - // Check if fusion has BlockQuantizationOp(s) - // Limit unroll factor for fusions with BlockQuantizationOp(s). The runtime - // function which implements quantization assumes no unrolling - auto has_block_quantization_ops = - HeuristicDataCacheEntry( - data_cache, - [fusion]() { - return std::make_unique( - !ir_utils::getOpsOfType(fusion).empty()); - }) - .get(); - int64_t unroll_factor = getUnrollFactor( fusion, break_point, @@ -538,7 +538,7 @@ std::unique_ptr getPointwiseHeuristics( params->vectorization_factor * max_dtype_size_bit_for_vectorization, divisible_split, vectorizable_inputs_outputs_entry.get(), - has_block_quantization_ops); + data_cache); if (is_outer_broadcast_dominated) { params->unroll_factor_outer = unroll_factor;
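The data-cache plumbing above follows the usual three-part recipe for caching a compile-time fact about the fusion; condensed from the hunks in this series (all names are taken from the patch itself, only the surrounding sketch is illustrative):

    // 1) Declare the entry type (csrc/scheduler/compile_time_info.h).
    class HasBlockQuantizationOps {
     public:
      using DataType = bool;
      static const CompileTimeEntryType EntryType =
          CompileTimeEntryType::HAS_BLOCK_QUANTIZATION_OPS;
    };

    // 2) Explicitly instantiate the cache-entry template
    //    (csrc/scheduler/registry.cpp).
    template class HeuristicDataCacheEntry<
        HeuristicCompileTime::HasBlockQuantizationOps>;

    // 3) Query it wherever needed; the lambda runs only when the value is not
    //    already in the cache.
    bool has_bq =
        HeuristicDataCacheEntry<HeuristicCompileTime::HasBlockQuantizationOps>(
            data_cache,
            [fusion]() {
              return std::make_unique<bool>(
                  !ir_utils::getOpsOfType<BlockQuantizationOp>(fusion).empty());
            })
            .get();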