Skip to content

Commit 0927880

Browse files
authored
Use linked hash map to update logical domain in rfactor (#5100)
This approach should be faster and simplifies the update logic. For PR #5090
1 parent 9bc115a commit 0927880

File tree

1 file changed

+38
-40
lines changed

1 file changed

+38
-40
lines changed

csrc/transform_rfactor.cpp

Lines changed: 38 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <ir/iostream.h>
1414
#include <ir/utils.h>
1515
#include <iter_visitor.h>
16+
#include <linked_hash_map.h>
1617
#include <ops/arith.h>
1718

1819
namespace nvfuser {
@@ -55,40 +56,26 @@ namespace {
5556
// in this replay.
5657
class ReplayRFactor : public ReplayTransformations {
5758
private:
58-
// Perform the update of the logical domain by replacing "replace0" with
59-
// "with0" and if not nullptr "with1", also removes "replace1" if not nullptr.
60-
void updateRFactorDomain(
61-
IterDomain* replace0,
62-
IterDomain* replace1,
63-
IterDomain* with0,
64-
IterDomain* with1) {
65-
NVF_ERROR(
66-
with0 != nullptr,
67-
"The first provided IterDomain should be a real pointer,",
68-
" the second iter domain provided can be a nullptr.");
69-
auto pos =
70-
std::find(logical_domain_.begin(), logical_domain_.end(), replace0);
71-
NVF_ERROR(
72-
pos != logical_domain_.end(),
73-
"Could not find iter domain: ",
74-
replace0->toString(),
75-
" in the logical domain to replace.");
76-
logical_domain_.insert(pos, with0);
77-
if (with1 != nullptr) {
78-
pos = std::find(logical_domain_.begin(), logical_domain_.end(), replace0);
79-
logical_domain_.insert(pos, with1);
59+
// Updates logical_domain_ for a replayed Split: the split's input is removed
// and its outer and inner outputs are inserted (outer first, then inner) at
// the position handed back by the erase call.
void splitId(Split* split) {
  // erase() yields a pair; its second member is the position used for the
  // insertions below.
  auto erased_in = logical_domain_.erase(split->in());
  logical_domain_.insert(erased_in.second, split->outer(), std::monostate());
  logical_domain_.insert(erased_in.second, split->inner(), std::monostate());
}
64+
65+
// Updates logical_domain_ for a replayed Merge: the merge output takes the
// place of the outer input, and the inner input is dropped.
void mergeId(Merge* merge) {
  auto erased_outer = logical_domain_.erase(merge->outer());
  logical_domain_.insert(erased_outer.second, merge->out(), std::monostate());
  // NOTE(review): the inner input is erased only after the insertion above,
  // presumably because the saved position may refer to the inner entry --
  // keep this order unless LinkedHashMap's iterator rules say otherwise.
  logical_domain_.erase(merge->inner());
}
70+
71+
// Applies a replayed transformation to logical_domain_:
//   - Split: handled by splitId().
//   - Merge: handled by mergeId().
// Any other expression type is reported as an error.
void updateRFactorDomain(Expr* expr) {
  if (expr->isA<Split>()) {
    splitId(expr->as<Split>());
    return;
  }
  if (expr->isA<Merge>()) {
    mergeId(expr->as<Merge>());
    return;
  }
  // NVF_ERROR takes the condition as its first argument. Pass `false`
  // explicitly: a string literal in that slot is always truthy, so the error
  // would never be raised.
  NVF_ERROR(false, "Unrecognized expression: ", expr->toString());
}
9380

9481
// Took a good bit of this from ReplayTransformations::handle(Split...)
@@ -153,7 +140,7 @@ class ReplayRFactor : public ReplayTransformations {
153140
id_map_[s->inner()] = idi;
154141

155142
if (static_logical_ids_.count(s->in())) {
156-
updateRFactorDomain(s->in(), nullptr, s->outer(), s->inner());
143+
updateRFactorDomain(s);
157144
}
158145
}
159146

@@ -211,7 +198,7 @@ class ReplayRFactor : public ReplayTransformations {
211198
static_logical_ids_.count(m->outer()),
212199
"If one input to a merge is a static logical id, the other must be "
213200
"as well.");
214-
updateRFactorDomain(m->outer(), m->inner(), m->out(), nullptr);
201+
updateRFactorDomain(m);
215202
}
216203
}
217204

@@ -241,7 +228,7 @@ class ReplayRFactor : public ReplayTransformations {
241228
// The updated domain matching the producer's logical domain. This rfactor
242229
// domain is relative to the iter domains in the original_domain and must be
243230
// updated to grab the mapped id's later.
244-
std::vector<IterDomain*> logical_domain_;
231+
LinkedHashMap<IterDomain*, std::monostate> logical_domain_;
245232

246233
ReplayRFactor(
247234
// Original domain the rfactor is in reference to.
@@ -257,8 +244,7 @@ class ReplayRFactor : public ReplayTransformations {
257244
std::unordered_set<IterDomain*> static_logical_ids)
258245
: ReplayTransformations(original_domain->loop(), std::move(id_map)),
259246
rfactor_axes_(std::move(rfactor_axes)),
260-
static_logical_ids_(std::move(static_logical_ids)),
261-
logical_domain_(original_domain->logical()) {
247+
static_logical_ids_(std::move(static_logical_ids)) {
262248
const auto all_dep_vals = DependencyCheck::getAllValsBetween(
263249
{original_domain->maybeRoot().begin(),
264250
original_domain->maybeRoot().end()},
@@ -267,8 +253,19 @@ class ReplayRFactor : public ReplayTransformations {
267253
auto all_dep_ids = ir_utils::filterByType<IterDomain>(all_dep_vals);
268254
rfactor_dep_ids_.insert(all_dep_ids.begin(), all_dep_ids.end());
269255

256+
for (IterDomain* id : original_domain->logical()) {
257+
logical_domain_.pushBack(id, std::monostate());
258+
}
259+
270260
setErrorOnFailure(false);
271261
}
262+
263+
std::vector<IterDomain*> logical() const {
264+
auto logical_ids = std::views::keys(logical_domain_);
265+
std::vector<IterDomain*> transformed_logical(
266+
logical_ids.begin(), logical_ids.end());
267+
return transformed_logical;
268+
}
272269
};
273270

274271
} // namespace
@@ -421,10 +418,11 @@ std::pair<TensorDomain*, TensorDomain*> TransformRFactor::runReplay(
421418
// Specify the logical domain of the producer which will match the consumer
422419
// root domain.
423420
std::vector<IterDomain*> new_producer_logical_domain;
424-
new_producer_logical_domain.reserve(replay_rfactor.logical_domain_.size());
421+
std::vector<IterDomain*> transformed_original_logical =
422+
replay_rfactor.logical();
425423
std::transform(
426-
replay_rfactor.logical_domain_.begin(),
427-
replay_rfactor.logical_domain_.end(),
424+
transformed_original_logical.begin(),
425+
transformed_original_logical.end(),
428426
std::back_inserter(new_producer_logical_domain),
429427
[&](IterDomain* id) {
430428
auto replayed_id_it = original_to_producer_id_map.find(id);

0 commit comments

Comments
 (0)