218 changes: 218 additions & 0 deletions csrc/device_lower/validation.cpp
@@ -373,6 +373,224 @@ class ExprValidator : public OptOutDispatch {
}
}

// Given a set of loop domain IterDomains, find their logical domain origins
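// For illustration: if a logical ID I1 was split into loop IDs {I1/4, 4},
// passing either loop ID here returns the origin {I1} (assuming I1 is in
// tv's logical domain).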
std::vector<IterDomain*> findLogicalDomainOrigins(
const std::vector<IterDomain*>& loop_domain_ids,
const TensorView* tv) {
// Get the logical domain to use as the target/boundary
const auto& logical_domain = tv->getLogicalDomain();

// Use IterVisitor to find inputs to the loop domain IDs,
// bounded by the logical domain
std::vector<Val*> inputs_as_vals = IterVisitor::getInputsTo(
{loop_domain_ids.begin(), loop_domain_ids.end()},
{logical_domain.begin(), logical_domain.end()});

// Convert back to IterDomains
std::vector<IterDomain*> logical_origins;
for (auto val : inputs_as_vals) {
logical_origins.push_back(val->as<IterDomain>());
}

return logical_origins;
}

// Check that the inner dimension of the input is divisible by 16.
void handle(BlockQuantizationOp* bqop) final {
auto inp_tv = bqop->input(0)->as<TensorView>();
auto quantized_output = bqop->quantizedOutput()->as<TensorView>();
auto block_scaling_factor = bqop->blockScales()->as<TensorView>();

NVF_ERROR_EQ(
inp_tv->getMemoryType(),
MemoryType::Local,
"Input must be a local memory tensor. Found: ",
inp_tv->getMemoryType());

NVF_ERROR_EQ(
quantized_output->getMemoryType(),
MemoryType::Local,
"Quantized output must be a local memory tensor. Found: ",
quantized_output->getMemoryType());

NVF_ERROR_EQ(
block_scaling_factor->getMemoryType(),
MemoryType::Global,
"Block scaling factor must be a global memory tensor. Found: ",
block_scaling_factor->getMemoryType());

// Outputs have the same allocation domain as the loop domain.
// This will need to be relaxed later for the scaling factors.
NVF_ERROR(
quantized_output->hasAllocation() == false,
"Quantized output must not have an allocation domain.");
NVF_ERROR(
block_scaling_factor->hasAllocation() == false,
"Block scaling factor must not have an allocation domain.");

// Check that the output has either a vectorized ID or a grouped ID,
// but not both, and that its extent is 4 (FP32) or 8 (BF16).
IterDomain* grouped_or_vector_id = nullptr;
IterDomain* thread_x = nullptr;
IterDomain* block_x = nullptr;
IterDomain* thread_y = nullptr;
IterDomain* block_y = nullptr;
IterDomain* thread_z = nullptr;
IterDomain* block_z = nullptr;

for (const auto& loop_id : block_scaling_factor->getLoopDomain()) {
if (loop_id->getParallelType() == ParallelType::Group ||
loop_id->getParallelType() == ParallelType::Vectorize) {
NVF_ERROR(
grouped_or_vector_id == nullptr,
"Multiple IDs found to be grouped/vectorized");
grouped_or_vector_id = loop_id;
}
}

auto parallel_domains_map =
    ir_utils::getParallelDomains(block_scaling_factor);
auto find_parallel_id = [&](ParallelType pt) -> IterDomain* {
  auto it = parallel_domains_map.find(pt);
  return it != parallel_domains_map.end() ? it->second : nullptr;
};
thread_x = find_parallel_id(ParallelType::TIDx);
block_x = find_parallel_id(ParallelType::BIDx);
thread_y = find_parallel_id(ParallelType::TIDy);
block_y = find_parallel_id(ParallelType::BIDy);
thread_z = find_parallel_id(ParallelType::TIDz);
block_z = find_parallel_id(ParallelType::BIDz);

NVF_ERROR(
grouped_or_vector_id != nullptr,
"One of the output IDs must be grouped or vectorized for "
"BlockQuantizationOp: ",
bqop->toString());

NVF_ERROR(
thread_x != nullptr && block_x != nullptr,
"Need to have both TIDx and BIDx when using BlockQuantizationOp: ",
bqop->toString());

NVF_ERROR(
!thread_z && !block_z,
"Parallelization along z axis is not supported for "
"BlockQuantizationOp: ",
bqop->toString());

bool is_2d_scheduled = thread_y != nullptr || block_y != nullptr;
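// A schedule is treated as 2D when a second dimension is parallelized
// with TIDy and/or BIDy (the z axis was already ruled out above).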

auto inner_extent =
grouped_or_vector_id->extent()->evaluate().as<int64_t>();
auto input_dtype = inp_tv->dtype();

NVF_ERROR(
(inner_extent == 4 && input_dtype == DataType::Float) ||
(inner_extent == 8 && input_dtype == DataType::BFloat16),
"The vectorized/grouped dimension must be 4 (FP32) or 8 "
"(BF16). Found: ",
inner_extent,
". Expr: ",
bqop->toString());
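// Note: 4 FP32 elements and 8 BF16 elements are both 16 bytes,
// presumably matching a single 128-bit vectorized access.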

// Find the logical domain IDs that these loop IDs originate from,
// then check that those logical IDs are the innermost ones.
auto input_logical_domains_ids = findLogicalDomainOrigins(
{grouped_or_vector_id, thread_x, block_x}, block_scaling_factor);

// Get the size of input logical domains
size_t num_input_logical_domains = input_logical_domains_ids.size();

// Get the same number of elements from the innermost logical domain
const auto& logical_domain = block_scaling_factor->getLogicalDomain();
std::vector<IterDomain*> innermost_logical_domains;

// Extract from the rightmost (innermost) positions
for (int64_t i = logical_domain.size() - 1;
i >= 0 && innermost_logical_domains.size() < num_input_logical_domains;
i--) {
auto logical_id = logical_domain[i];
if (!logical_id->isReduction() && !logical_id->isBroadcast()) {
innermost_logical_domains.insert(
innermost_logical_domains.begin(), logical_id);
}
}
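// E.g., with logical domain [I0, I1, I2] (no reductions/broadcasts) and
// two origin IDs, this collects [I1, I2], preserving their order.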

// Validate that input_logical_domains_ids and innermost_logical_domains
// contain the same IterDomains
std::unordered_set<IterDomain*> input_logical_set(
input_logical_domains_ids.begin(), input_logical_domains_ids.end());
std::unordered_set<IterDomain*> innermost_logical_set(
innermost_logical_domains.begin(), innermost_logical_domains.end());

NVF_ERROR(
input_logical_set == innermost_logical_set,
"Input logical domain IDs do not match the innermost logical domains "
"for BlockQuantizationOp: ",
bqop->toString(),
". Expected innermost domains: ",
toDelimitedString(innermost_logical_domains),
". Found input logical domains: ",
toDelimitedString(input_logical_domains_ids));

// If the op is 2D scheduled, get the IDs from the logical domain
// that correspond to blockIdx.y and threadIdx.y. Make sure those
// logical IDs do not share any ID with the ones from which the
// thread/block IDs for the x-dimension were derived.
if (is_2d_scheduled) {
std::vector<IterDomain*> input_logical_domains_ids_2d = {};
for (auto id : {thread_y, block_y}) {
if (id) {
input_logical_domains_ids_2d.push_back(id);
}
}

auto input_logical_domains_ids_y = findLogicalDomainOrigins(
input_logical_domains_ids_2d, block_scaling_factor);

// Validate that input_logical_domains_ids and input_logical_domains_ids_y
// don't have any elements in common
std::unordered_set<IterDomain*> input_logical_set_x(
input_logical_domains_ids.begin(), input_logical_domains_ids.end());
std::unordered_set<IterDomain*> input_logical_set_y(
input_logical_domains_ids_y.begin(),
input_logical_domains_ids_y.end());

for (const auto& id : input_logical_set_x) {
NVF_ERROR(
input_logical_set_y.find(id) == input_logical_set_y.end(),
"Input logical domain IDs for X and Y dimensions have overlapping "
"elements "
"for BlockQuantizationOp: ",
bqop->toString(),
". Overlapping IterDomain: ",
id->toString(),
". X logical domains: ",
toDelimitedString(input_logical_domains_ids),
". Y logical domains: ",
toDelimitedString(input_logical_domains_ids_y));
}
}
}

static void validateUnitStride(
TensorView* tv,
const std::vector<IterDomain*>& alloc_domain,
2 changes: 1 addition & 1 deletion csrc/logical_domain_map.cpp
@@ -157,7 +157,7 @@ std::pair<std::unordered_set<IterDomain*>, bool> getNonMappingDomainInfo(
} else if (
auto bqop =
dynamic_cast<BlockQuantizationOp*>(consumer_tv->definition())) {
-if (producer_tv == bqop->in()) {
+if (producer_tv == bqop->in() && consumer_tv == bqop->blockScales()) {
auto producer_logical =
TensorDomain::noReductions(producer_tv->getLogicalDomain());
auto last_logical_dim = producer_tv->getLogicalDomain().size() - 1;
28 changes: 27 additions & 1 deletion csrc/scheduler/pointwise.cpp
@@ -312,8 +312,16 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
(int64_t)vectorizable_inputs_outputs_entry.get().size()) >>
2),
(int64_t)1));

auto all_exprs = fusion->exprs();
auto fusion_has_block_quantization =
ir_utils::filterByType<BlockQuantizationOp>(all_exprs).size() > 0;

// Don't vectorize at the cost of getting a full wave on the GPU
-if (n_elems < device_multiprocessor_count * kThreadX && max_vect_factor > 1) {
+// unless there's a block quantization op, which must be vectorized.
+if ((n_elems < device_multiprocessor_count * kThreadX &&
+    max_vect_factor > 1) &&
+    !fusion_has_block_quantization) {
max_vect_factor = std::min(
max_vect_factor,
ceilDiv(n_elems, device_multiprocessor_count * kThreadX));
@@ -806,6 +814,17 @@ bool PointWiseScheduler::canScheduleCompileTime(Fusion* fusion) {
return false;
}

// The block scales output of the BlockQuantizationOp should be a
// segment output, as it is written to global memory.
if (registry_utils::hasNonTerminalBlockQuantizeOp(fusion)) {
scheduler_debug_utils::canScheduleRejectReason(
schedulerType(),
"no support for block quantization where block scales is not a fusion "
"output");
return false;
}

return true;
}

@@ -1234,6 +1253,13 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
}
}
}

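// BlockQuantizationOp outputs must have a vectorized or grouped ID
// (enforced in device_lower/validation.cpp), so mark them for
// vectorization here as well.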
auto bq_ops = ir_utils::getOpsOfType<BlockQuantizationOp>(fusion);
for (auto bq_op : bq_ops) {
vectorized_tvs.emplace_back(bq_op->quantizedOutput()->as<TensorView>());
vectorized_tvs.emplace_back(bq_op->blockScales()->as<TensorView>());
}

if (!vectorized_tvs.empty()) {
// Aggressively mark with vectorized and cleanup later. That way we
// don't have to manually specify parallelization outside the reference.
13 changes: 13 additions & 0 deletions csrc/scheduler/registry_utils.cpp
@@ -561,6 +561,19 @@ PrimDataType getIndexTypeOfKernel(
return PrimDataType::Int32;
}

bool hasNonTerminalBlockQuantizeOp(Fusion* fusion) {
for (auto expr : fusion->exprs()) {
if (expr->isA<BlockQuantizationOp>()) {
auto block_scales =
expr->as<BlockQuantizationOp>()->blockScales()->as<TensorView>();
if (!block_scales->isFusionOutput()) {
return true;
}
}
}
return false;
}

bool SchedulerTopologyChecker::hasNonNormalizePostReductionBCast(
Fusion* fusion) {
auto all_vals = fusion->usedMathVals();
4 changes: 4 additions & 0 deletions csrc/scheduler/registry_utils.h
@@ -36,6 +36,10 @@ bool rejectScheduleForMemoryPromotion(
Fusion* fusion,
SchedulerType scheduler_type);

// Returns true if any BlockQuantizationOp produces a block scales
// output that is not a fusion (segment) output.
bool hasNonTerminalBlockQuantizeOp(Fusion* fusion);

bool isConnectedFusionGraph(Fusion* fusion);

// Returns if a fusion cannot be transformed into a consistent format since we
54 changes: 53 additions & 1 deletion csrc/scheduler/tools/domain_map.cpp
@@ -5,6 +5,7 @@
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <ir/utils.h>
#include <scheduler/tools/domain_map.h>
#include <scheduler/utils.h>

Expand Down Expand Up @@ -377,6 +378,57 @@ IterDomain* DomainMap::anyMapped(
return nullptr;
}

namespace {

// Checks whether tv is the block scales output of a BlockQuantizationOp.
// If not, traverses the consumer->producer chain upward to see whether
// any transitive producer is.
bool isTransitiveBlockScaleOutput(TensorView* tv) {
// Check if current tv is directly a block scale output
if (tv->definition() != nullptr &&
tv->definition()->isA<BlockQuantizationOp>()) {
auto block_quant_op = tv->definition()->as<BlockQuantizationOp>();
if (block_quant_op->blockScales() == tv) {
return true;
}
}

// Traverse up the producer chain
auto current_tv = tv;
std::unordered_set<TensorView*> visited; // To prevent infinite loops

while (current_tv != nullptr && visited.find(current_tv) == visited.end()) {
visited.insert(current_tv);

// Get all producers of current tv
auto producers = ir_utils::producerTvsOf(current_tv);

// Check each producer
for (auto producer : producers) {
// Check if this producer is a block scale output
if (producer->definition() != nullptr &&
producer->definition()->isA<BlockQuantizationOp>()) {
auto block_quant_op = producer->definition()->as<BlockQuantizationOp>();
if (block_quant_op->blockScales() == producer) {
return true;
}
}
}

// Move to the first producer for continued traversal
// If there are multiple producers, we check the first one for simplicity
// This could be extended to check all paths if needed
if (!producers.empty()) {
current_tv = producers[0];
} else {
current_tv = nullptr; // No more producers to check
}
}

return false;
}

} // namespace

// Determine if output TensorView is a valid reference tensor for this fusion.
// The reference tensor must map to all the iterDomains in each input and
// output
@@ -401,7 +453,7 @@ bool DomainMap::isValidReference(TensorView* tv, bool check_inputs) const {
for (auto output_tv :
ir_utils::filterByType<TensorView>(fusion_->outputs())) {
// no need to check for self.
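// Also skip (transitive) block scale outputs: their innermost logical
// ID does not map across the BlockQuantizationOp (see
// logical_domain_map.cpp), so a reference tensor cannot be expected
// to cover it.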
-if (output_tv == tv) {
+if (output_tv == tv || isTransitiveBlockScaleOutput(output_tv)) {
continue;
}
if (!areAllTargetIdsCoveredBy(output_tv, tv)) {