Skip to content

Commit 80e2206

Browse files
committed
only validate new_loop updated in replay
1 parent fb7f4fa commit 80e2206

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

csrc/transform_replay.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -867,13 +867,20 @@ std::pair<TensorDomain*, int64_t> TransformReplay::replayCasP(
867867
}
868868

869869
if (!opt.replay_allocation) {
870+
// Duplicate TensorDomain to avoid validation run on existing TensorDomain.
870871
TensorDomain* replayed = IrBuilder::createInContainer<TensorDomain>(
871872
consumer->container(),
872873
consumer->getRootDomain(),
873874
consumer->getLogicalDomain(),
874875
consumer->getAllocationDomain(),
875-
new_loop,
876-
consumer->domain()->contiguity());
876+
consumer->getLoopDomain(),
877+
/*alternate_loop_domain=*/std::vector<IterDomain*>(),
878+
consumer->domain()->contiguity(),
879+
/*additiona_ids=*/std::vector<IterDomain*>(),
880+
/*skip_validation=*/true);
881+
882+
// update loop domain, this ensures we run validation on new_loop.
883+
replayed->setLoopDomain(new_loop);
877884

878885
return {replayed, consumer_pos};
879886
}

0 commit comments

Comments
 (0)