|
5 | 5 | * SPDX-License-Identifier: BSD-3-Clause |
6 | 6 | */ |
7 | 7 | // clang-format on |
| 8 | +#include <ir/builder.h> |
8 | 9 | #include <ir/utils.h> |
9 | 10 | #include <logical_domain_map.h> |
10 | 11 | #include <runtime/executor_kernel_arg.h> |
@@ -1170,6 +1171,69 @@ bool SchedulerTopologyChecker::hasCyclicReshape(Fusion* fusion) { |
1170 | 1171 | return false; |
1171 | 1172 | } |
1172 | 1173 |
|
| 1174 | +namespace { |
| 1175 | + |
| 1176 | +// Two sequences of constant outer-split factors are considered compatible |
| 1177 | +// if they share the same leading prefix up to the length of the shorter one. |
| 1178 | +// That is, either they are identical (same length) or the shorter sequence |
| 1179 | +// matches the first N elements of the longer one. |
| 1180 | +bool areConstExtentSequencesCompatible( |
| 1181 | + const std::vector<int64_t>& a, |
| 1182 | + const std::vector<int64_t>& b) { |
| 1183 | + if (a.size() <= b.size()) { |
| 1184 | + return a == std::vector<int64_t>(b.begin(), b.begin() + a.size()); |
| 1185 | + } else { |
| 1186 | + return b == std::vector<int64_t>(a.begin(), a.begin() + b.size()); |
| 1187 | + } |
| 1188 | +} |
| 1189 | + |
| 1190 | +} // namespace |
| 1191 | + |
| 1192 | +// Reshape is expressed as merges and outer splits by constant factors. |
| 1193 | +// Merges are replay-safe and do not introduce conflicts during transform |
| 1194 | +// propagation. Compatibility therefore hinges on the constant outer-split |
| 1195 | +// factors: if the leading prefix of factors matches, the splits can be |
| 1196 | +// replayed across reshapes. |
| 1197 | +bool SchedulerTopologyChecker::hasIncompatibleReshape(Fusion* fusion) { |
| 1198 | + const auto reshape_ops = ir_utils::getOpsOfType<ReshapeOp>(fusion); |
| 1199 | + if (reshape_ops.size() < 2) { |
| 1200 | + return false; |
| 1201 | + } |
| 1202 | + |
| 1203 | + // Collect constant outer-split factors of the reshape output, ordered |
| 1204 | + // from outer-most to inner-most. |
| 1205 | + auto collect_extent_prefix = [](ReshapeOp* reshape) { |
| 1206 | + auto reshape_out = reshape->out()->as<TensorView>(); |
| 1207 | + std::vector<int64_t> outer_split_factors; |
| 1208 | + for (auto* id : reshape_out->getLogicalDomain()) { |
| 1209 | + if (auto* def = id->definition()) { |
| 1210 | + if (auto* split = dynamic_cast<Split*>(def)) { |
| 1211 | + if (!split->innerSplit() && split->outer() == id && |
| 1212 | + split->factor()->isConstInt()) { |
| 1213 | + outer_split_factors.push_back( |
| 1214 | + split->factor()->evaluate().as<int64_t>()); |
| 1215 | + } |
| 1216 | + } |
| 1217 | + } |
| 1218 | + } |
| 1219 | + |
| 1220 | + return outer_split_factors; |
| 1221 | + }; |
| 1222 | + |
| 1223 | + for (const auto i : arange(std::ssize(reshape_ops) - 1)) { |
| 1224 | + const auto& out_ids_i_extents = collect_extent_prefix(reshape_ops[i]); |
| 1225 | + for (const auto j : arange(i + 1, std::ssize(reshape_ops))) { |
| 1226 | + const auto& out_ids_j_extents = collect_extent_prefix(reshape_ops[j]); |
| 1227 | + if (!areConstExtentSequencesCompatible( |
| 1228 | + out_ids_i_extents, out_ids_j_extents)) { |
| 1229 | + return true; |
| 1230 | + } |
| 1231 | + } |
| 1232 | + } |
| 1233 | + |
| 1234 | + return false; |
| 1235 | +} |
| 1236 | + |
1173 | 1237 | } // namespace registry_utils |
1174 | 1238 |
|
1175 | 1239 | } // namespace nvfuser |
0 commit comments