From cea43e76ebeff792c84d8805b7e58dc330283756 Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Mon, 1 Sep 2025 12:32:17 -0700 Subject: [PATCH 1/3] separate replayDomain to allow reuse --- csrc/transform_rfactor.cpp | 114 ++++++++++++++++++++++--------------- 1 file changed, 67 insertions(+), 47 deletions(-) diff --git a/csrc/transform_rfactor.cpp b/csrc/transform_rfactor.cpp index 3aeee42a053..2512ab65e60 100644 --- a/csrc/transform_rfactor.cpp +++ b/csrc/transform_rfactor.cpp @@ -271,6 +271,54 @@ class ReplayRFactor : public ReplayTransformations { } }; +// Empty set for replayDomain calls that don't need to ignore any IDs +static const std::unordered_set kEmptyIgnoreIds{}; + +// Use the `replay_to_target_map` to replay the `replay_domain`. +// `ignore_rfactor_ids` is true for consumers where the replay will not have +// these ids since they are already reduced. If `propagate_padding = true`, +// padding to multiple of warp is applied to the replayed ids. If +// `propagate_parallelization = true`, replayed id is parallelized to the +// original id's parallel type. Device and stream parallel types are always be +// preserved in replay. +std::vector replayDomain( + const std::vector& replay_domain, + std::unordered_map& replay_to_target_map, + const std::unordered_set& ignore_ids = kEmptyIgnoreIds, + bool propagate_padding = false, + bool propagate_parallelization = false) { + std::vector target_domain; + target_domain.reserve(replay_domain.size()); + for (const auto& replay_id : replay_domain) { + auto target_id_it = replay_to_target_map.find(replay_id); + + if (ignore_ids.count(replay_id)) { + continue; + } + + NVF_ERROR( + target_id_it != replay_to_target_map.end(), + "Error during rfactor replay, missing an axis.", + replay_id->toString()); + IterDomain* target_id = target_id_it->second; + + // Device and stream parallel types should always be preserved in replay. + // Other parallel types are only relevant to replay of the loop domain. + if (propagate_parallelization || replay_id->isDeviceDim() || + replay_id->isStream()) { + target_id->parallelize(replay_id->getParallelType()); + } + + if (propagate_padding) { + if (replay_id->hasPaddingToMultipleOfWarp()) { + target_id->padToMultipleOfWarp(replay_id->getMaybeSizeAfterPadding()); + } + } + target_domain.push_back(target_id); + } + return target_domain; +} + } // namespace std::pair TransformRFactor::runReplay( @@ -374,7 +422,7 @@ std::pair TransformRFactor::runReplay( } else { new_producer_root[i] = id->cloneWithoutRFactor(); } - original_to_producer_root_map[id] = new_producer_root[i++]; + original_to_producer_root_map[id] = new_producer_root[i]; } } @@ -401,38 +449,21 @@ std::pair TransformRFactor::runReplay( std::unordered_map original_to_producer_id_map = replay_rfactor.getReplay(); - std::vector new_producer_domain(original_td->nDims(), nullptr); - { - for (auto i : arange(original_td->nDims())) { - auto orig_id = original_td->axis(i); - auto replayed_id_it = original_to_producer_id_map.find(orig_id); - NVF_ERROR( - replayed_id_it != original_to_producer_id_map.end(), - "Error during rfactor replay, missing an axis."); - auto replayed_id = replayed_id_it->second; - replayed_id->parallelize(orig_id->getParallelType()); - if (orig_id->hasPaddingToMultipleOfWarp()) { - replayed_id->padToMultipleOfWarp(orig_id->getMaybeSizeAfterPadding()); - } - new_producer_domain[i++] = replayed_id; - } - } + std::vector new_producer_domain = replayDomain( + original_td->loop(), + original_to_producer_id_map, + /*ignore_ids=*/kEmptyIgnoreIds, + /*propagate_padding=*/true, + /*propagate_parallelization=*/true); // Specify the logical domain of the producer which will match the consumer // root domain. - std::vector new_producer_logical_domain; - new_producer_logical_domain.reserve(replay_rfactor.logical_domain_.size()); - std::transform( - replay_rfactor.logical_domain_.begin(), - replay_rfactor.logical_domain_.end(), - std::back_inserter(new_producer_logical_domain), - [&](IterDomain* id) { - auto replayed_id_it = original_to_producer_id_map.find(id); - NVF_ERROR( - replayed_id_it != original_to_producer_id_map.end(), - "Error during rfactor replay, missing an axis."); - return replayed_id_it->second; - }); + std::vector new_producer_logical_domain = replayDomain( + replay_rfactor.logical_domain_, + original_to_producer_id_map, + /*ignore_ids=*/kEmptyIgnoreIds, + /*propagate_padding=*/false, + /*propagate_parallelization=*/false); auto* producer_domain = IrBuilder::createInContainer( original_td->container(), @@ -477,23 +508,12 @@ std::pair TransformRFactor::runReplay( auto original_to_consumer_map = consumer_replay.getReplay(); - std::vector new_consumer_domain; - - { - // Construct the new consumer domain - for (auto i : arange(original_td->nDims())) { - auto orig_id = original_td->axis(i); - auto replayed_id_it = original_to_consumer_map.find(orig_id); - if (replayed_id_it != original_to_consumer_map.end()) { - auto replayed_id = replayed_id_it->second; - new_consumer_domain.push_back(replayed_id); - replayed_id->parallelize(orig_id->getParallelType()); - if (orig_id->hasPaddingToMultipleOfWarp()) { - replayed_id->padToMultipleOfWarp(orig_id->getMaybeSizeAfterPadding()); - } - } - } - } + std::vector new_consumer_domain = replayDomain( + original_td->loop(), + original_to_consumer_map, + /*ignore_ids=*/rfactor_axes, + /*propagate_padding=*/true, + /*propagate_parallelization=*/true); auto consumer_domain = IrBuilder::createInContainer( original_td->container(), From db06013b9af4e11e7625095206871898aace7afa Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Tue, 2 Sep 2025 16:29:55 -0700 Subject: [PATCH 2/3] remove empty set decl --- csrc/transform_rfactor.cpp | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/csrc/transform_rfactor.cpp b/csrc/transform_rfactor.cpp index 2512ab65e60..4983407192b 100644 --- a/csrc/transform_rfactor.cpp +++ b/csrc/transform_rfactor.cpp @@ -270,10 +270,6 @@ class ReplayRFactor : public ReplayTransformations { setErrorOnFailure(false); } }; - -// Empty set for replayDomain calls that don't need to ignore any IDs -static const std::unordered_set kEmptyIgnoreIds{}; - // Use the `replay_to_target_map` to replay the `replay_domain`. // `ignore_rfactor_ids` is true for consumers where the replay will not have // these ids since they are already reduced. If `propagate_padding = true`, @@ -284,7 +280,7 @@ static const std::unordered_set kEmptyIgnoreIds{}; std::vector replayDomain( const std::vector& replay_domain, std::unordered_map& replay_to_target_map, - const std::unordered_set& ignore_ids = kEmptyIgnoreIds, + const std::unordered_set& ignore_ids = {}, bool propagate_padding = false, bool propagate_parallelization = false) { std::vector target_domain; @@ -452,7 +448,7 @@ std::pair TransformRFactor::runReplay( std::vector new_producer_domain = replayDomain( original_td->loop(), original_to_producer_id_map, - /*ignore_ids=*/kEmptyIgnoreIds, + /*ignore_ids=*/{}, /*propagate_padding=*/true, /*propagate_parallelization=*/true); @@ -461,7 +457,7 @@ std::pair TransformRFactor::runReplay( std::vector new_producer_logical_domain = replayDomain( replay_rfactor.logical_domain_, original_to_producer_id_map, - /*ignore_ids=*/kEmptyIgnoreIds, + /*ignore_ids=*/{}, /*propagate_padding=*/false, /*propagate_parallelization=*/false); From e086d0f99eb4c9141459dd5b433d3fe29480fed9 Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Wed, 3 Sep 2025 13:02:44 -0700 Subject: [PATCH 3/3] review comments --- csrc/transform_rfactor.cpp | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/csrc/transform_rfactor.cpp b/csrc/transform_rfactor.cpp index b10f7c92ce5..92b98b11510 100644 --- a/csrc/transform_rfactor.cpp +++ b/csrc/transform_rfactor.cpp @@ -267,28 +267,26 @@ class ReplayRFactor : public ReplayTransformations { return transformed_logical; } }; + // Use the `replay_to_target_map` to replay the `replay_domain`. -// `ignore_rfactor_ids` is true for consumers where the replay will not have -// these ids since they are already reduced. If `propagate_padding = true`, -// padding to multiple of warp is applied to the replayed ids. If -// `propagate_parallelization = true`, replayed id is parallelized to the -// original id's parallel type. Device and stream parallel types are always be -// preserved in replay. +// `ignore_ids` is the set of ids that should not be replayed. If +// `propagate_padding = true`, padding to multiple of warp is applied to the +// replayed ids. If `propagate_parallelization = true`, replayed id is +// parallelized to the original id's parallel type. Device and stream parallel +// types are always be preserved in replay. std::vector replayDomain( const std::vector& replay_domain, - std::unordered_map& replay_to_target_map, + const std::unordered_map& replay_to_target_map, const std::unordered_set& ignore_ids = {}, bool propagate_padding = false, bool propagate_parallelization = false) { std::vector target_domain; target_domain.reserve(replay_domain.size()); - for (const auto& replay_id : replay_domain) { + for (const auto& replay_id : + replay_domain | std::views::filter([&ignore_ids](IterDomain* replay_id) { + return !ignore_ids.contains(replay_id); + })) { auto target_id_it = replay_to_target_map.find(replay_id); - - if (ignore_ids.count(replay_id)) { - continue; - } - NVF_ERROR( target_id_it != replay_to_target_map.end(), "Error during rfactor replay, missing an axis.", @@ -451,8 +449,10 @@ std::pair TransformRFactor::runReplay( // Specify the logical domain of the producer which will match the consumer // root domain. + std::vector transformed_original_logical = + replay_rfactor.logical(); std::vector new_producer_logical_domain = replayDomain( - replay_rfactor.logical_domain_, + transformed_original_logical, original_to_producer_id_map, /*ignore_ids=*/{}, /*propagate_padding=*/false,