diff --git a/csrc/scheduler/compile_time_info.h b/csrc/scheduler/compile_time_info.h
index fdb17478dda..2a60f253bb7 100644
--- a/csrc/scheduler/compile_time_info.h
+++ b/csrc/scheduler/compile_time_info.h
@@ -48,7 +48,8 @@ enum class CompileTimeEntryType {
   CAN_SCHEDULE_TRANSPOSE,
   CAN_SCHEDULE_MUL_SUM_AS_MMA,
   LOGICAL_REORDER_MAP,
-  VECTORIZATION_BREAK_POINT_OF_RED_PROD
+  VECTORIZATION_BREAK_POINT_OF_RED_PROD,
+  HAS_BLOCK_QUANTIZATION_OPS
 };

 //! Entry type definition class for `DOMAIN_MAP`,
@@ -142,6 +143,16 @@ class ReductionTVs {
       CompileTimeEntryType::REDUCTION_TVS;
 };

+//! Entry type definition class for `HAS_BLOCK_QUANTIZATION_OPS`,
+//! stores a boolean flag indicating whether the fusion contains any
+//! BlockQuantizationOp operations.
+class HasBlockQuantizationOps {
+ public:
+  using DataType = bool;
+  static const CompileTimeEntryType EntryType =
+      CompileTimeEntryType::HAS_BLOCK_QUANTIZATION_OPS;
+};
+
 //! Entry type definition class for `PERSISTENT_BUFFER_INFO`,
 //! stores persistent buffers inferred from topology and scheduling of fusion.
 class PersistentBufferInfo {
diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp
index 3c4e0251e75..bb1f35c827a 100644
--- a/csrc/scheduler/pointwise.cpp
+++ b/csrc/scheduler/pointwise.cpp
@@ -268,6 +268,62 @@ bool PointWiseScheduler::canScheduleCompileTime(Fusion* fusion) {
     return false;
   }

+  // The block scales output of the Block Quantization Op
+  // should be a segment output, as it is written to global
+  // memory.
+  if (registry_utils::hasNonTerminalBlockQuantizeOp(fusion)) {
+    scheduler_debug_utils::canScheduleRejectReason(
+        schedulerType(),
+        "no support for block quantization where block scales is not a fusion "
+        "output");
+    return false;
+  }
+
+  return true;
+}
+
+bool PointWiseScheduler::canScheduleRunTime(
+    Fusion* fusion,
+    SchedulerRuntimeInfo& runtime_info,
+    HeuristicDataCache* data_cache) {
+  FUSER_PERF_SCOPE("PointWiseScheduler::canScheduleRunTime");
+  // Check if the fusion has a Block Quantization Op.
+  // If so, ensure that the vectorization factor is at least 2
+  // and that the grid y dimension is not split.
+  // These are requirements of the current implementation of the
+  // Block Quantization Op runtime function.
+
+  auto has_block_quantization_ops =
+      HeuristicDataCacheEntry<HeuristicCompileTime::HasBlockQuantizationOps>(
+          data_cache,
+          [fusion]() {
+            return std::make_unique<bool>(
+                !ir_utils::getOpsOfType<BlockQuantizationOp>(fusion).empty());
+          })
+          .get();
+
+  if (has_block_quantization_ops) {
+    auto heuristics = computeHeuristics(fusion, runtime_info, data_cache);
+    auto pparams = static_cast<PointwiseParams*>(heuristics.get());
+    NVF_ERROR(pparams != nullptr);
+    if (pparams->vectorization_factor < 2) {
+      scheduler_debug_utils::canScheduleRejectReason(
+          schedulerType(),
+          "Block Quantization Op requires vectorization factor to be at least "
+          "2.");
+      return false;
+    }
+
+    if (pparams->split_grid_y_dim) {
+      scheduler_debug_utils::canScheduleRejectReason(
+          schedulerType(),
+          "Block Quantization Op is not supported when splitting grid y "
+          "dimension. This is because this will create a serial ID with an "
+          "extent > 1. The runtime function implementing block quantization "
+          "will currently not be able to handle that.");
+      return false;
+    }
+  }
   return true;
 }

diff --git a/csrc/scheduler/pointwise.h b/csrc/scheduler/pointwise.h
index 9326689f4eb..daba3c464ca 100644
--- a/csrc/scheduler/pointwise.h
+++ b/csrc/scheduler/pointwise.h
@@ -164,9 +164,7 @@ class PointWiseScheduler : public SchedulerEntry {
   bool canScheduleRunTime(
       Fusion* fusion,
       SchedulerRuntimeInfo& runtime_info,
-      HeuristicDataCache* data_cache = nullptr) override {
-    return true;
-  }
+      HeuristicDataCache* data_cache = nullptr) override;

   std::unique_ptr<HeuristicParams> computeHeuristics(
       Fusion* fusion,
diff --git a/csrc/scheduler/pointwise_non_tma.cpp b/csrc/scheduler/pointwise_non_tma.cpp
index 6de4e569564..9572dd97032 100644
--- a/csrc/scheduler/pointwise_non_tma.cpp
+++ b/csrc/scheduler/pointwise_non_tma.cpp
@@ -137,7 +137,26 @@ int64_t getUnrollFactor(
     int64_t total_blocks,
     int64_t vectorization_bits,
     bool divisible_split,
-    std::vector<TensorView*> vectorizable_io_tvs) {
+    std::vector<TensorView*> vectorizable_io_tvs,
+    HeuristicDataCache* data_cache) {
+  // Check if the fusion has BlockQuantizationOp(s).
+  // Limit the unroll factor for fusions with BlockQuantizationOp(s); the
+  // runtime function which implements quantization assumes no unrolling.
+  auto has_block_quantization_ops =
+      HeuristicDataCacheEntry<HeuristicCompileTime::HasBlockQuantizationOps>(
+          data_cache,
+          [fusion]() {
+            return std::make_unique<bool>(
+                !ir_utils::getOpsOfType<BlockQuantizationOp>(fusion).empty());
+          })
+          .get();
+
+  if (has_block_quantization_ops) {
+    // The runtime function implementing the Block Quantization Op requires an
+    // unroll factor of 1.
+    return 1;
+  }
+
   // only consider vectorizable inputs,
   // needs to check if it's already in the list to avoid duplication since a tv
   // may be both input and output, e.g. NVFuserTest.FusionIssue2372_CUDA
@@ -518,7 +537,8 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
         total_blocks,
         params->vectorization_factor * max_dtype_size_bit_for_vectorization,
         divisible_split,
-        vectorizable_inputs_outputs_entry.get());
+        vectorizable_inputs_outputs_entry.get(),
+        data_cache);

     if (is_outer_broadcast_dominated) {
       params->unroll_factor_outer = unroll_factor;
@@ -967,6 +987,17 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
   spanning_tree.traverse(&propagator);
   scheduler_utils::parallelizeAllLike(reference_tv);

+  // We first vectorize the quantized outputs of the block quantization ops.
+  // We then convert the vectorized IDs to group IDs.
+  // We do so as the runtime function for block quantization expects 2/4/8
+  // elements per thread.
+  auto bq_ops = ir_utils::getOpsOfType<BlockQuantizationOp>(fusion);
+  std::vector<TensorView*> nvfp4_quantized_outputs = {};
+  for (auto bq_op : bq_ops) {
+    nvfp4_quantized_outputs.push_back(
+        bq_op->quantizedOutput()->as<TensorView>());
+  }
+
   if (pparams->vectorization_factor > 1) {
     // Grab all tensor views that should be vectorized
     auto inputs_outputs =
@@ -998,6 +1029,13 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
       }
     }
   }
+
+  // Vectorize nvfp4 quantized outputs.
+  // We will later change the vectorized IDs to group IDs.
+  for (auto quantized_output : nvfp4_quantized_outputs) {
+    vectorized_tvs.emplace_back(quantized_output);
+  }
+
   if (!vectorized_tvs.empty()) {
     // Aggressively mark with vectorized and cleanup later. That way we
     // don't have to manually specify parallelization outside the reference.
@@ -1007,6 +1045,15 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
     if (!should_vectorize_reference_tv) {
       vectorize_id->parallelize(ParallelType::Serial);
     }
+
+    // Change vectorized IDs to group IDs for quantized outputs
+    for (auto quantized_output : nvfp4_quantized_outputs) {
+      for (auto id : quantized_output->getLoopDomain()) {
+        if (id->getParallelType() == ParallelType::Vectorize) {
+          id->parallelize(ParallelType::Group);
+        }
+      }
+    }
   }
 }

diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp
index 01f15de0554..e16c23d7814 100644
--- a/csrc/scheduler/registry.cpp
+++ b/csrc/scheduler/registry.cpp
@@ -280,6 +280,8 @@ template class HeuristicDataCacheEntry<
 template class HeuristicDataCacheEntry<
     HeuristicCompileTime::UnrollableInputsAndOutputs>;
 template class HeuristicDataCacheEntry<HeuristicCompileTime::ReductionTVs>;
+template class HeuristicDataCacheEntry<
+    HeuristicCompileTime::HasBlockQuantizationOps>;
 template class HeuristicDataCacheEntry<
     HeuristicCompileTime::PersistentBufferInfo>;
 template class HeuristicDataCacheEntry<
diff --git a/csrc/scheduler/registry_utils.cpp b/csrc/scheduler/registry_utils.cpp
index a47209a80e5..490ffd6d036 100644
--- a/csrc/scheduler/registry_utils.cpp
+++ b/csrc/scheduler/registry_utils.cpp
@@ -841,6 +841,21 @@ bool SchedulerTopologyChecker::hasNonNormalizePostReductionBCast(
   return false;
 }

+// Returns true if the block scales output of a block quantization op
+// is not a fusion/segment output.
+bool hasNonTerminalBlockQuantizeOp(Fusion* fusion) {
+  for (auto expr : fusion->exprs()) {
+    if (expr->isA<BlockQuantizationOp>()) {
+      auto block_scales =
+          expr->as<BlockQuantizationOp>()->blockScales()->as<TensorView>();
+      if (!block_scales->isFusionOutput()) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
 // Checks if any broadcasts are resolved after a reduction, this shouldn't be
 // accepted in the single reduction or multi-reduction scheduler
 bool SchedulerTopologyChecker::hasPostReductionBCast(Fusion* fusion) {
diff --git a/csrc/scheduler/registry_utils.h b/csrc/scheduler/registry_utils.h
index 8123f2c3639..108dc247301 100644
--- a/csrc/scheduler/registry_utils.h
+++ b/csrc/scheduler/registry_utils.h
@@ -72,6 +72,10 @@ PrimDataType getIndexTypeOfKernel(
     const KernelArgumentHolder& inputs,
     ExpressionEvaluator& ee);

+// Returns true if the block scales output of a Block Quantization Op
+// is not a fusion/segment output.
+bool hasNonTerminalBlockQuantizeOp(Fusion* fusion);
+
 class SchedulerTopologyChecker {
  public:
   // Checks if any broadcasts are resolved after a reduction that don't follow
diff --git a/csrc/scheduler/tools/domain_map.cpp b/csrc/scheduler/tools/domain_map.cpp
index c36b1a6de22..55abe2d0914 100644
--- a/csrc/scheduler/tools/domain_map.cpp
+++ b/csrc/scheduler/tools/domain_map.cpp
@@ -409,7 +409,18 @@ bool DomainMap::isValidReference(TensorView* tv, bool check_inputs) const {
   for (auto output_tv :
        ir_utils::filterByType<TensorView>(fusion_->outputs())) {
     // no need to check for self.
-    if (output_tv == tv) {
+    // If this is the block scaling factor output of a BlockQuantizationOp,
+    // then we skip the check, as we only consider the quantized output of the
+    // BlockQuantizationOp when looking for a reference tensor. This is because
+    // the two outputs of a block quantization op are not symmetrical and the
+    // logical domains of the scaling factor are not completely mapped.
+    if (output_tv == tv ||
+        (output_tv->definition() &&
+         output_tv->definition()->isA<BlockQuantizationOp>() &&
+         output_tv ==
+             output_tv->definition()
+                 ->as<BlockQuantizationOp>()
+                 ->blockScales())) {
       continue;
     }
     if (!areAllTargetIdsCoveredBy(output_tv, tv)) {
diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp
index ebacd705019..37236dcedfb 100644
--- a/csrc/scheduler/utils.cpp
+++ b/csrc/scheduler/utils.cpp
@@ -1380,8 +1380,14 @@ std::vector<std::pair<TensorView*, TensorView*>> cacheAndForkOutputs(
         // the output of ScatterOp must on the global memory due to the random
         // or atomic access. Similarly, PreprocessGroupedMatmulInputSf requires
         // direct write to global memory because of random access.
+        // The block scales output of block quantization has to be in global
+        // memory. This is because this op is implemented via a runtime
+        // function that writes the scaling factors to global memory.
         output->definition()
-            ->isOneOf<ScatterOp, PreprocessGroupedMatmulInputSf>()) {
+            ->isOneOf<ScatterOp, PreprocessGroupedMatmulInputSf>() ||
+        (output->definition()->isA<BlockQuantizationOp>() &&
+         output->definition()->as<BlockQuantizationOp>()->blockScales() ==
+             output)) {
       continue;
     }
     if (!output->uses().empty()) {
diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp
index 2be12491a2c..55069a3d22e 100644
--- a/tests/cpp/test_low_precision_recipe.cpp
+++ b/tests/cpp/test_low_precision_recipe.cpp
@@ -108,7 +108,7 @@ constexpr double F8E4M3_MAX = 448.0;
 class NVFP4QuantizeTest : public BlackwellBase,
                           public ::testing::WithParamInterface<DataType> {};
 namespace {
-void createNVFP4QunatizationFusion(Fusion* fusion, DataType data_hp_dtype) {
+void createNVFP4QuantizationFusion(Fusion* fusion, DataType data_hp_dtype) {
   auto tv_data_hp = makeContigTensor(2, data_hp_dtype);
   fusion->addInput(tv_data_hp);

@@ -155,7 +155,7 @@ TEST_P(NVFP4QuantizeTest, WithoutPerTensorAmax) {
   std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
   FusionGuard fg(fusion.get());

-  createNVFP4QunatizationFusion(fusion.get(), data_hp_dtype);
+  createNVFP4QuantizationFusion(fusion.get(), data_hp_dtype);

   FusionExecutorCache fec(std::move(fusion));

@@ -187,7 +187,7 @@ TEST_P(BlockQuantizationTest, ScheduleAsPointwise) {
   // Baseline implementation
   std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
   FusionGuard fg(fusion.get());
-  createNVFP4QunatizationFusion(fusion.get(), data_hp_dtype);
+  createNVFP4QuantizationFusion(fusion.get(), data_hp_dtype);

   FusionExecutorCache fec(std::move(fusion));

@@ -297,7 +297,7 @@ TEST_P(BlockQuantizationTest, ScheduleAsPointwise2D) {
   // Baseline implementation
   std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
   FusionGuard fg(fusion.get());
-  createNVFP4QunatizationFusion(fusion.get(), data_hp_dtype);
+  createNVFP4QuantizationFusion(fusion.get(), data_hp_dtype);

   FusionExecutorCache fec(std::move(fusion));

@@ -689,6 +689,143 @@ TEST_F(BlockQuantizationValidationTest, MergesMustBeContiguous) {
       "IDs from the logical domain for BlockQuantizationOp")));
 }

+struct BlockQuantizationSchedulingTestParams {
+  DataType data_type;
+  int m;
+  int n;
+};
+
+class BlockQuantizationSchedulingTest
+    : public BlackwellBase,
+      public ::testing::WithParamInterface<
+          BlockQuantizationSchedulingTestParams> {};
+
+TEST_P(BlockQuantizationSchedulingTest, AutoScheduleSingleOp) {
+  auto params = GetParam();
+  auto data_type = params.data_type;
+  const int m = params.m;
+  const int n = params.n;
+
+  std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
+  createNVFP4QuantizationFusion(fusion.get(), data_type);
+
+  FusionExecutorCache fec(std::move(fusion));
+
+  std::vector<at::Tensor> inputs;
+  inputs.push_back(at::randn({m, n}, at::device(at::kCUDA).dtype(at::kFloat))
+                       .to(data_type_to_aten(data_type)));
+
+  auto outputs_baseline = fec.runFusionWithInputs(inputs);
+
+  auto baseline_block_scales = outputs_baseline[0].as<at::Tensor>();
+  auto baseline_quantized_tensor = outputs_baseline[1].as<at::Tensor>();
+
+  auto baseline_block_scales_cpu = baseline_block_scales.cpu();
+  auto baseline_quantized_tensor_cpu = baseline_quantized_tensor.cpu();
+
+  const uint8_t* baseline_block_scales_data =
+      static_cast<const uint8_t*>(baseline_block_scales_cpu.data_ptr());
+  const uint8_t* baseline_quantized_data =
+      static_cast<const uint8_t*>(baseline_quantized_tensor_cpu.data_ptr());
+
+  std::unique_ptr<Fusion> fusion_new_op = std::make_unique<Fusion>();
+  FusionGuard fg2(fusion_new_op.get());
+
+  auto tv_in_1 = makeContigTensor(2, data_type);
+  fusion_new_op->addInput(tv_in_1);
+
+  auto quantization_results = blockQuantize(tv_in_1);
+
+  fusion_new_op->addOutput(quantization_results.block_scales);
+  fusion_new_op->addOutput(quantization_results.quantized_tensor);
+
+  FusionExecutorCache executor_cache(std::move(fusion_new_op));
+  auto outputs_new_op = executor_cache.runFusionWithInputs(inputs);
+
+  // Verify we got the expected outputs
+  auto block_scales_output = outputs_new_op[0].as<at::Tensor>();
+  auto quantized_tensor_output = outputs_new_op[1].as<at::Tensor>();
+
+  // Move tensors from GPU to CPU
+  auto block_scales_cpu = block_scales_output.cpu();
+  auto quantized_tensor_cpu = quantized_tensor_output.cpu();
+
+  auto block_scales_bytes = (m * n) / block_size;
+  auto quantized_tensor_bytes = (m * n) / 2;
+
+  const uint8_t* block_scales_data =
+      static_cast<const uint8_t*>(block_scales_cpu.data_ptr());
+  for (int i = 0; i < block_scales_bytes; ++i) {
+    EXPECT_EQ(
+        block_scales_data[i],
+        baseline_block_scales_data[i]); // Compare with baseline
+  }
+
+  const uint8_t* quantized_data =
+      static_cast<const uint8_t*>(quantized_tensor_cpu.data_ptr());
+  for (int i = 0; i < quantized_tensor_bytes; ++i) {
+    EXPECT_EQ(
+        quantized_data[i],
+        baseline_quantized_data[i]); // Compare with baseline
+  }
+}
+
+class BlockQuantizationCanScheduleTests : public BlackwellBase {};
+
+TEST_F(
+    BlockQuantizationCanScheduleTests,
+    CanRuntimeScheduleFailFromNoVectorization) {
+  std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
+
+  auto tv_data_hp = makeContigTensor(2, DataType::Float);
+  fusion->addInput(tv_data_hp);
+
+  auto t0 = set(tv_data_hp);
+  auto quantization_results = blockQuantize(t0);
+  auto t_out = set(quantization_results.quantized_tensor);
+
+  fusion->addOutput(quantization_results.block_scales);
+  fusion->addOutput(t_out);
+
+  // Create misaligned tensor directly on GPU using custom CUDA allocation
+  size_t element_size = 4;
+  int m = 1024;
+  int n = 1024;
+
+  size_t total_elements = m * n;
+  size_t buffer_size =
+      total_elements * element_size + 16; // Extra bytes for misalignment
+
+  // Allocate GPU memory with extra space
+  void* gpu_ptr;
+  cudaMalloc(&gpu_ptr, buffer_size);
+
+  // Create tensor from GPU memory at offset of 4 bytes
+  void* misaligned_ptr = static_cast<char*>(gpu_ptr) + 4;
+  auto misaligned_gpu_tensor = at::from_blob(
+      misaligned_ptr,
+      {m, n},
+      at::TensorOptions()
+          .dtype(data_type_to_aten(DataType::Float))
+          .device(at::kCUDA));
+
+  auto good_input = at::randn({m, n}, at::device(at::kCUDA).dtype(at::kFloat));
+
+  // Expect failure as the input tensor can't be vectorized
+  // and we need a vectorization factor of at least 2
+  SchedulerRuntimeInfo runtime_info(fusion.get(), {misaligned_gpu_tensor});
+  ASSERT_FALSE(Schedule::canSchedule(
+      SchedulerType::PointWise, fusion.get(), runtime_info));
+
+  SchedulerRuntimeInfo runtime_info_new(fusion.get(), {good_input});
+  ASSERT_TRUE(Schedule::canSchedule(
+      SchedulerType::PointWise, fusion.get(), runtime_info_new));
+
+  if (gpu_ptr)
+    cudaFree(gpu_ptr);
+}
+
 TEST_P(NVFP4QuantizeTest, SwizzledOuputAndWithoutPerTensorAmax) {
   auto data_hp_dtype = GetParam();

@@ -855,4 +992,24 @@ INSTANTIATE_TEST_SUITE_P(
       return os.str();
     });

+INSTANTIATE_TEST_SUITE_P(
+    ,
+    BlockQuantizationSchedulingTest,
+    ::testing::Values(
+        BlockQuantizationSchedulingTestParams{DataType::Float, 1024, 1024},
+        BlockQuantizationSchedulingTestParams{DataType::Float, 128, 64},
+        BlockQuantizationSchedulingTestParams{DataType::Float, 2048, 128},
+        BlockQuantizationSchedulingTestParams{DataType::Float, 2048, 2048},
+        BlockQuantizationSchedulingTestParams{DataType::BFloat16, 1024, 1024},
+        BlockQuantizationSchedulingTestParams{DataType::BFloat16, 128, 64},
+        BlockQuantizationSchedulingTestParams{DataType::BFloat16, 2048, 128},
+        BlockQuantizationSchedulingTestParams{DataType::BFloat16, 2048, 2048}),
+    [](const testing::TestParamInfo<BlockQuantizationSchedulingTestParams>&
+           info) {
+      std::ostringstream name;
+      name << info.param.data_type << "_" << info.param.m << "x"
+           << info.param.n;
+      return name.str();
+    });
+
 } // namespace nvfuser