Skip to content

Commit 30bcca8

Browse files
committed
add SchedulerTopologyChecker::hasIncompatibleReshape
1 parent 0eec90c commit 30bcca8

File tree

7 files changed

+438
-5
lines changed

7 files changed

+438
-5
lines changed

csrc/scheduler/pointwise.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -879,6 +879,9 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
879879

880880
if (!ir_utils::getReshapeOps(fusion).empty()) {
881881
ComputeAtMap ca_map(fusion);
882+
std::cout << "fusion before propagateReshapeTransforms" << std::endl;
883+
fusion->print();
884+
std::cout << ca_map.toString() << std::endl;
882885
// Propagate reshape transforms through the graph, expecially the reference.
883886
scheduler_utils::propagateReshapeTransforms(fusion, ca_map);
884887

csrc/scheduler/registry.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,13 @@ bool checkCanSchedule(Fusion* fusion, SchedulerType scheduler_type) {
9797
return false;
9898
}
9999

100+
if (registry_utils::SchedulerTopologyChecker::hasIncompatibleReshape(
101+
fusion)) {
102+
scheduler_debug_utils::canScheduleRejectReason(
103+
scheduler_type, "Fusion has cyclic reshapes.");
104+
return false;
105+
}
106+
100107
if (registry_utils::SchedulerTopologyChecker::
101108
rejectScheduleFusionGlobalBufferRequirement(fusion, scheduler_type)) {
102109
scheduler_debug_utils::canScheduleRejectReason(

csrc/scheduler/registry_utils.cpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
* SPDX-License-Identifier: BSD-3-Clause
66
*/
77
// clang-format on
8+
#include <ir/builder.h>
89
#include <ir/utils.h>
910
#include <logical_domain_map.h>
1011
#include <runtime/executor_kernel_arg.h>
@@ -1170,6 +1171,69 @@ bool SchedulerTopologyChecker::hasCyclicReshape(Fusion* fusion) {
11701171
return false;
11711172
}
11721173

1174+
namespace {
1175+
1176+
// Two sequences of constant outer-split factors are considered compatible
1177+
// if they share the same leading prefix up to the length of the shorter one.
1178+
// That is, either they are identical (same length) or the shorter sequence
1179+
// matches the first N elements of the longer one.
1180+
bool areConstExtentSequencesCompatible(
1181+
const std::vector<int64_t>& a,
1182+
const std::vector<int64_t>& b) {
1183+
if (a.size() <= b.size()) {
1184+
return a == std::vector<int64_t>(b.begin(), b.begin() + a.size());
1185+
} else {
1186+
return b == std::vector<int64_t>(a.begin(), a.begin() + b.size());
1187+
}
1188+
}
1189+
1190+
} // namespace
1191+
1192+
// Reshape is expressed as merges and outer splits by constant factors.
1193+
// Merges are replay-safe and do not introduce conflicts during transform
1194+
// propagation. Compatibility therefore hinges on the constant outer-split
1195+
// factors: if the leading prefix of factors matches, the splits can be
1196+
// replayed across reshapes.
1197+
bool SchedulerTopologyChecker::hasIncompatibleReshape(Fusion* fusion) {
1198+
const auto reshape_ops = ir_utils::getOpsOfType<ReshapeOp>(fusion);
1199+
if (reshape_ops.size() < 2) {
1200+
return false;
1201+
}
1202+
1203+
// Collect constant outer-split factors of the reshape output, ordered
1204+
// from outer-most to inner-most.
1205+
auto collect_extent_prefix = [](ReshapeOp* reshape) {
1206+
auto reshape_out = reshape->out()->as<TensorView>();
1207+
std::vector<int64_t> outer_split_factors;
1208+
for (auto* id : reshape_out->getLogicalDomain()) {
1209+
if (auto* def = id->definition()) {
1210+
if (auto* split = dynamic_cast<Split*>(def)) {
1211+
if (!split->innerSplit() && split->outer() == id &&
1212+
split->factor()->isConstInt()) {
1213+
outer_split_factors.push_back(
1214+
split->factor()->evaluate().as<int64_t>());
1215+
}
1216+
}
1217+
}
1218+
}
1219+
1220+
return outer_split_factors;
1221+
};
1222+
1223+
for (const auto i : arange(std::ssize(reshape_ops) - 1)) {
1224+
const auto& out_ids_i_extents = collect_extent_prefix(reshape_ops[i]);
1225+
for (const auto j : arange(i + 1, std::ssize(reshape_ops))) {
1226+
const auto& out_ids_j_extents = collect_extent_prefix(reshape_ops[j]);
1227+
if (!areConstExtentSequencesCompatible(
1228+
out_ids_i_extents, out_ids_j_extents)) {
1229+
return true;
1230+
}
1231+
}
1232+
}
1233+
1234+
return false;
1235+
}
1236+
11731237
} // namespace registry_utils
11741238

11751239
} // namespace nvfuser

csrc/scheduler/registry_utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ class SchedulerTopologyChecker {
123123
// propagateReshapeTransforms won't work as it won't find any
124124
// terminating reshape IDs.
125125
static bool hasCyclicReshape(Fusion* fusion);
126+
127+
static bool hasIncompatibleReshape(Fusion* fusion);
126128
};
127129

128130
} // namespace registry_utils

csrc/scheduler/utils.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2366,6 +2366,12 @@ void propagateReshapeTransforms(Fusion* fusion, const ComputeAtMap& ca_map) {
23662366
}
23672367
}
23682368
}
2369+
std::cout << "transformed_disjoint_sets: " << transformed_disjoint_sets.size()
2370+
<< std::endl;
2371+
for (auto set : transformed_disjoint_sets) {
2372+
std::cout << " transformed_disjoint_sets: " << set->toString()
2373+
<< std::endl;
2374+
}
23692375

23702376
std::unordered_set<IterDomain*> terminating_reshape_dims;
23712377
for (const auto& disjoint_set_shared_ptr :
@@ -2390,7 +2396,11 @@ void propagateReshapeTransforms(Fusion* fusion, const ComputeAtMap& ca_map) {
23902396
terminating_reshape_dims.emplace(id);
23912397
}
23922398
}
2393-
2399+
std::cout << "terminating_reshape_dims: " << terminating_reshape_dims.size()
2400+
<< std::endl;
2401+
for (auto set : terminating_reshape_dims) {
2402+
std::cout << " terminating_reshape_dims: " << set->toString() << std::endl;
2403+
}
23942404
// If iter domains are involved in any transformation from root domains to
23952405
// logical domains they should be considered "contaminated".
23962406
for (auto tv : fusion->allTvs()) {
@@ -2455,8 +2465,11 @@ void propagateReshapeTransforms(Fusion* fusion, const ComputeAtMap& ca_map) {
24552465
}
24562466

24572467
// Propagate the view transformations
2468+
std::cout << "before reorder " << tv->toString() << std::endl;
24582469
tv->reorder(old2new);
24592470
//! Propagate current transformations on from_tv to all graphs
2471+
std::cout << "transformPropagateToAllFrom " << tv->toString() << " @ "
2472+
<< (int64_t)old2new.size() << std::endl;
24602473
transformPropagateToAllFrom(tv, (int64_t)old2new.size());
24612474

24622475
// Propgating the transforms will not replay the DIDx parallelization, so we

csrc/transform_replay.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,6 +1201,8 @@ void TransformPropagator::propagateP2C(TensorView* from, TensorView* to) {
12011201
debug() << "TransformPropagator::propagateP2C" << std::endl;
12021202
debug() << " from: " << from << " @ " << pos << std::endl;
12031203
debug() << " to: " << to << std::endl;
1204+
std::cout << " from transforms: " << std::endl;
1205+
from->printTransforms();
12041206
}
12051207
if (new_pos < 0) {
12061208
auto replay = TransformReplay::replayCasP(
@@ -1255,6 +1257,9 @@ void TransformPropagator::propagateSibling(TensorView* from, TensorView* to) {
12551257

12561258
TransformPropagator::TransformPropagator(TensorView* from, int64_t pos) {
12571259
replayed_pos_[from] = wrapDim(pos, from->nDims() + 1);
1260+
std::cout << "TransformPropagator::TransformPropagator" << std::endl;
1261+
std::cout << " from: " << from->toString() << " @ " << replayed_pos_[from]
1262+
<< std::endl;
12581263
}
12591264

12601265
void MostInlinedTransformPropagator::propagateC2P(

0 commit comments

Comments
 (0)