1313#include  < ir/iostream.h> 
1414#include  < ir/utils.h> 
1515#include  < iter_visitor.h> 
16+ #include  < linked_hash_map.h> 
1617#include  < ops/arith.h> 
1718
1819namespace  nvfuser  {
@@ -55,40 +56,26 @@ namespace {
5556//  in this replay.
5657class  ReplayRFactor  : public  ReplayTransformations  {
5758 private: 
58-   //  Perform the update of the logical domain by replacing "replace0" with
59-   //  "with0" and if not nullptr "with1", also removes "replace1" if not nullptr.
60-   void  updateRFactorDomain (
61-       IterDomain* replace0,
62-       IterDomain* replace1,
63-       IterDomain* with0,
64-       IterDomain* with1) {
65-     NVF_ERROR (
66-         with0 != nullptr ,
67-         " The first provided IterDomain should be a real pointer,"  ,
68-         "  the second iter domain provided can be a nullptr."  );
69-     auto  pos =
70-         std::find (logical_domain_.begin (), logical_domain_.end (), replace0);
71-     NVF_ERROR (
72-         pos != logical_domain_.end (),
73-         " Could not find iter domain: "  ,
74-         replace0->toString (),
75-         "  in the logical domain to replace."  );
76-     logical_domain_.insert (pos, with0);
77-     if  (with1 != nullptr ) {
78-       pos = std::find (logical_domain_.begin (), logical_domain_.end (), replace0);
79-       logical_domain_.insert (pos, with1);
59+   void  splitId (Split* split) {
60+     auto  it = logical_domain_.erase (split->in ()).second ;
61+     logical_domain_.insert (it, split->outer (), std::monostate ());
62+     logical_domain_.insert (it, split->inner (), std::monostate ());
63+   }
64+ 
65+   void  mergeId (Merge* merge) {
66+     auto  outer_it = logical_domain_.erase (merge->outer ()).second ;
67+     logical_domain_.insert (outer_it, merge->out (), std::monostate ());
68+     logical_domain_.erase (merge->inner ());
69+   }
70+ 
71+   void  updateRFactorDomain (Expr* expr) {
72+     if  (expr->isA <Split>()) {
73+       splitId (expr->as <Split>());
8074    }
81-     pos = std::find (logical_domain_.begin (), logical_domain_.end (), replace0);
82-     logical_domain_.erase (pos);
83-     if  (replace1 != nullptr ) {
84-       pos = std::find (logical_domain_.begin (), logical_domain_.end (), replace1);
85-       NVF_ERROR (
86-           pos != logical_domain_.end (),
87-           " Wanted to replace "  ,
88-           replace1->toString (),
89-           "  but it's not in the logical domain."  );
90-       logical_domain_.erase (pos);
75+     if  (expr->isA <Merge>()) {
76+       mergeId (expr->as <Merge>());
9177    }
78+     NVF_ERROR (" Unrecognized expression: "  , expr->toString ());
9279  }
9380
9481  //  Took a good bit of this from ReplayTransformations::handle(Split...)
@@ -153,7 +140,7 @@ class ReplayRFactor : public ReplayTransformations {
153140    id_map_[s->inner ()] = idi;
154141
155142    if  (static_logical_ids_.count (s->in ())) {
156-       updateRFactorDomain (s-> in (),  nullptr , s-> outer (), s-> inner () );
143+       updateRFactorDomain (s);
157144    }
158145  }
159146
@@ -211,7 +198,7 @@ class ReplayRFactor : public ReplayTransformations {
211198              static_logical_ids_.count (m->outer ()),
212199          " If one input to a merge is a static logical id, the other must be " 
213200          " as well."  );
214-       updateRFactorDomain (m-> outer (), m-> inner (), m-> out (),  nullptr );
201+       updateRFactorDomain (m);
215202    }
216203  }
217204
@@ -241,7 +228,7 @@ class ReplayRFactor : public ReplayTransformations {
241228  //  The updated domain matching the producer's logical domain. This rfactor
242229  //  domain is relative to the iter domains in the origianl_domain and must be
243230  //  updated to grab the mapped id's later.
244-   std::vector <IterDomain*> logical_domain_;
231+   LinkedHashMap <IterDomain*, std::monostate > logical_domain_;
245232
246233  ReplayRFactor (
247234      //  Original domain the rfactor is in reference to.
@@ -257,8 +244,7 @@ class ReplayRFactor : public ReplayTransformations {
257244      std::unordered_set<IterDomain*> static_logical_ids)
258245      : ReplayTransformations(original_domain->loop (), std::move(id_map)),
259246        rfactor_axes_(std::move(rfactor_axes)),
260-         static_logical_ids_(std::move(static_logical_ids)),
261-         logical_domain_(original_domain->logical ()) {
247+         static_logical_ids_(std::move(static_logical_ids)) {
262248    const  auto  all_dep_vals = DependencyCheck::getAllValsBetween (
263249        {original_domain->maybeRoot ().begin (),
264250         original_domain->maybeRoot ().end ()},
@@ -267,8 +253,19 @@ class ReplayRFactor : public ReplayTransformations {
267253    auto  all_dep_ids = ir_utils::filterByType<IterDomain>(all_dep_vals);
268254    rfactor_dep_ids_.insert (all_dep_ids.begin (), all_dep_ids.end ());
269255
256+     for  (IterDomain* id : original_domain->logical ()) {
257+       logical_domain_.pushBack (id, std::monostate ());
258+     }
259+ 
270260    setErrorOnFailure (false );
271261  }
262+ 
263+   std::vector<IterDomain*> logical () const  {
264+     auto  logical_ids = std::views::keys (logical_domain_);
265+     std::vector<IterDomain*> transformed_logical (
266+         logical_ids.begin (), logical_ids.end ());
267+     return  transformed_logical;
268+   }
272269};
273270
274271} //  namespace
@@ -421,10 +418,11 @@ std::pair<TensorDomain*, TensorDomain*> TransformRFactor::runReplay(
421418  //  Specify the logical domain of the producer which will match the consumer
422419  //  root domain.
423420  std::vector<IterDomain*> new_producer_logical_domain;
424-   new_producer_logical_domain.reserve (replay_rfactor.logical_domain_ .size ());
421+   std::vector<IterDomain*> transformed_original_logical =
422+       replay_rfactor.logical ();
425423  std::transform (
426-       replay_rfactor. logical_domain_ .begin (),
427-       replay_rfactor. logical_domain_ .end (),
424+       transformed_original_logical .begin (),
425+       transformed_original_logical .end (),
428426      std::back_inserter (new_producer_logical_domain),
429427      [&](IterDomain* id) {
430428        auto  replayed_id_it = original_to_producer_id_map.find (id);
0 commit comments