csrc/transform_rfactor.cpp (107 changes: 61 additions & 46 deletions)
@@ -268,6 +268,48 @@ class ReplayRFactor : public ReplayTransformations {
}
};

// Use `replay_to_target_map` to replay `replay_domain`.
// `ignore_ids` is the set of ids that should not be replayed. If
// `propagate_padding` is true, padding to a multiple of the warp size is
// applied to the replayed ids. If `propagate_parallelization` is true, each
// replayed id is parallelized to the original id's parallel type. Device and
// stream parallel types are always preserved in replay. An illustrative usage
// sketch follows the function definition.
std::vector<IterDomain*> replayDomain(
const std::vector<IterDomain*>& replay_domain,
const std::unordered_map<IterDomain*, IterDomain*>& replay_to_target_map,
const std::unordered_set<IterDomain*>& ignore_ids = {},
bool propagate_padding = false,
bool propagate_parallelization = false) {
std::vector<IterDomain*> target_domain;
target_domain.reserve(replay_domain.size());
for (const auto& replay_id :
replay_domain | std::views::filter([&ignore_ids](IterDomain* replay_id) {
return !ignore_ids.contains(replay_id);
})) {
auto target_id_it = replay_to_target_map.find(replay_id);
NVF_ERROR(
target_id_it != replay_to_target_map.end(),
"Error during rfactor replay, missing an axis.",
replay_id->toString());
IterDomain* target_id = target_id_it->second;

// Device and stream parallel types should always be preserved in replay.
// Other parallel types are only relevant to replay of the loop domain.
if (propagate_parallelization || replay_id->isDeviceDim() ||
replay_id->isStream()) {
target_id->parallelize(replay_id->getParallelType());
}

if (propagate_padding) {
if (replay_id->hasPaddingToMultipleOfWarp()) {
target_id->padToMultipleOfWarp(replay_id->getMaybeSizeAfterPadding());
}
}
target_domain.push_back(target_id);
}
return target_domain;
}
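
// Illustrative sketch (not part of this change): a hypothetical call that
// skips a set of ids while propagating parallelization and warp padding.
// `some_loop_domain`, `some_replay_map`, and `ids_to_skip` are placeholders.
//
//   std::vector<IterDomain*> replayed = replayDomain(
//       some_loop_domain,
//       some_replay_map,
//       /*ignore_ids=*/ids_to_skip,
//       /*propagate_padding=*/true,
//       /*propagate_parallelization=*/true);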

} // namespace

std::pair<TensorDomain*, TensorDomain*> TransformRFactor::runReplay(
@@ -371,7 +413,7 @@ std::pair<TensorDomain*, TensorDomain*> TransformRFactor::runReplay(
} else {
new_producer_root[i] = id->cloneWithoutRFactor();
}
original_to_producer_root_map[id] = new_producer_root[i++];
original_to_producer_root_map[id] = new_producer_root[i];
}
}

@@ -398,39 +440,23 @@ std::pair<TensorDomain*, TensorDomain*> TransformRFactor::runReplay(
std::unordered_map<IterDomain*, IterDomain*> original_to_producer_id_map =
replay_rfactor.getReplay();

std::vector<IterDomain*> new_producer_domain(original_td->nDims(), nullptr);
{
for (auto i : arange(original_td->nDims())) {
auto orig_id = original_td->axis(i);
auto replayed_id_it = original_to_producer_id_map.find(orig_id);
NVF_ERROR(
replayed_id_it != original_to_producer_id_map.end(),
"Error during rfactor replay, missing an axis.");
auto replayed_id = replayed_id_it->second;
replayed_id->parallelize(orig_id->getParallelType());
if (orig_id->hasPaddingToMultipleOfWarp()) {
replayed_id->padToMultipleOfWarp(orig_id->getMaybeSizeAfterPadding());
}
new_producer_domain[i++] = replayed_id;
}
}
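  // Build the producer loop domain by replaying the original loop domain,
  // carrying over parallelization and warp padding.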
std::vector<IterDomain*> new_producer_domain = replayDomain(
original_td->loop(),
original_to_producer_id_map,
/*ignore_ids=*/{},
/*propagate_padding=*/true,
/*propagate_parallelization=*/true);

// Specify the logical domain of the producer which will match the consumer
// root domain.
std::vector<IterDomain*> new_producer_logical_domain;
std::vector<IterDomain*> transformed_original_logical =
replay_rfactor.logical();
std::transform(
transformed_original_logical.begin(),
transformed_original_logical.end(),
std::back_inserter(new_producer_logical_domain),
[&](IterDomain* id) {
auto replayed_id_it = original_to_producer_id_map.find(id);
NVF_ERROR(
replayed_id_it != original_to_producer_id_map.end(),
"Error during rfactor replay, missing an axis.");
return replayed_id_it->second;
});
std::vector<IterDomain*> new_producer_logical_domain = replayDomain(
transformed_original_logical,
original_to_producer_id_map,
/*ignore_ids=*/{},
/*propagate_padding=*/false,
/*propagate_parallelization=*/false);

auto* producer_domain = IrBuilder::createInContainer<TensorDomain>(
original_td->container(),
Expand Down Expand Up @@ -475,23 +501,12 @@ std::pair<TensorDomain*, TensorDomain*> TransformRFactor::runReplay(

auto original_to_consumer_map = consumer_replay.getReplay();

std::vector<IterDomain*> new_consumer_domain;

{
// Construct the new consumer domain
for (auto i : arange(original_td->nDims())) {
auto orig_id = original_td->axis(i);
auto replayed_id_it = original_to_consumer_map.find(orig_id);
if (replayed_id_it != original_to_consumer_map.end()) {
auto replayed_id = replayed_id_it->second;
new_consumer_domain.push_back(replayed_id);
replayed_id->parallelize(orig_id->getParallelType());
if (orig_id->hasPaddingToMultipleOfWarp()) {
replayed_id->padToMultipleOfWarp(orig_id->getMaybeSizeAfterPadding());
}
}
}
}
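  // Build the consumer loop domain the same way, but skip the rfactor axes,
  // which are not replayed into the consumer.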
std::vector<IterDomain*> new_consumer_domain = replayDomain(
original_td->loop(),
original_to_consumer_map,
/*ignore_ids=*/rfactor_axes,
/*propagate_padding=*/true,
/*propagate_parallelization=*/true);

auto consumer_domain = IrBuilder::createInContainer<TensorDomain>(
original_td->container(),