|
6 | 6 | */ |
7 | 7 | // clang-format on |
8 | 8 | #include <bindings.h> |
| 9 | +#include <direct_utils.h> |
| 10 | +#include <options.h> |
| 11 | +#include <scheduler/registry.h> |
| 12 | +#include <scheduler/runtime_info.h> |
| 13 | +#include <scheduler/scheduler_types.h> |
9 | 14 | #include <scheduler/tools/inlining.h> |
10 | 15 | #include <scheduler/utils.h> |
11 | 16 | #include <transform_replay.h> |
@@ -153,6 +158,122 @@ void bindTensorviewScheduleOps(py::module_& schedule) { |
153 | 158 | None |
154 | 159 | )", |
155 | 160 | py::arg("selected_tensors") = std::vector<TensorView*>()); |
| 161 | + |
| 162 | + schedule.def( |
| 163 | + "can_schedule", |
| 164 | + [](Fusion* fusion, |
| 165 | + SchedulerType scheduler_type, |
| 166 | + const py::iterable& inputs) { |
| 167 | + // Enable collection of messages from canScheduleRejectReason |
| 168 | + DebugDumpOptionsGuard debug_dump_options_guard; |
| 169 | + DebugDumpOptionsGuard::getCurOptions().set( |
| 170 | + DebugDumpOption::FusionSegmenterLog); |
| 171 | + |
| 172 | + // Send debug messages to stringstream |
| 173 | + std::stringstream ss; |
| 174 | + DebugStreamGuard dsg(ss); |
| 175 | + |
| 176 | + // Create runtime info from inputs |
| 177 | + auto args = from_pyiterable(inputs); |
| 178 | + SchedulerRuntimeInfo runtime_info(fusion, args); |
| 179 | + |
| 180 | + bool can_schedule = |
| 181 | + Schedule::canSchedule(scheduler_type, fusion, runtime_info); |
| 182 | + return std::make_tuple(can_schedule, ss.str()); |
| 183 | + }, |
| 184 | + py::arg("fusion"), |
| 185 | + py::arg("scheduler_type"), |
| 186 | + py::arg("inputs"), |
| 187 | + R"( |
| 188 | + Check if a scheduler can schedule the given fusion with the provided inputs. |
| 189 | +
|
| 190 | + Parameters |
| 191 | + ---------- |
| 192 | + fusion : Fusion |
| 193 | + The fusion to check. |
| 194 | + scheduler_type : SchedulerType |
| 195 | + The type of scheduler to check. |
| 196 | + inputs : iterable |
| 197 | + The input tensors/values for the fusion. |
| 198 | +
|
| 199 | + Returns |
| 200 | + ------- |
| 201 | + tuple of (bool, str) |
| 202 | + A tuple containing: |
| 203 | + - bool: True if the scheduler can schedule the fusion, False otherwise. |
| 204 | + - str: Debug message explaining why the scheduler was accepted or rejected. |
| 205 | + )"); |
| 206 | + |
| 207 | + schedule.def( |
| 208 | + "find_compatible_schedulers", |
| 209 | + [](Fusion* fusion, const py::iterable& inputs) { |
| 210 | + // Create runtime info from inputs |
| 211 | + auto args = from_pyiterable(inputs); |
| 212 | + SchedulerRuntimeInfo runtime_info(fusion, args); |
| 213 | + |
| 214 | + std::vector<SchedulerType> compatible_schedulers; |
| 215 | + |
| 216 | + // Check all scheduler types except None |
| 217 | + for (const auto& scheduler_type : all_heuristics_in_priority_order) { |
| 218 | + if (scheduler_type != SchedulerType::None && |
| 219 | + Schedule::canSchedule(scheduler_type, fusion, runtime_info)) { |
| 220 | + compatible_schedulers.push_back(scheduler_type); |
| 221 | + } |
| 222 | + } |
| 223 | + |
| 224 | + return compatible_schedulers; |
| 225 | + }, |
| 226 | + py::arg("fusion"), |
| 227 | + py::arg("inputs"), |
| 228 | + R"( |
| 229 | + Find all schedulers compatible with the given fusion and inputs. |
| 230 | +
|
| 231 | + Parameters |
| 232 | + ---------- |
| 233 | + fusion : Fusion |
| 234 | + The fusion to check. |
| 235 | + inputs : iterable |
| 236 | + The input tensors/values for the fusion. |
| 237 | +
|
| 238 | + Returns |
| 239 | + ------- |
| 240 | + list of SchedulerType |
| 241 | + A list of scheduler types that can schedule the fusion. |
| 242 | + )"); |
| 243 | + |
| 244 | + schedule.def( |
| 245 | + "schedule", |
| 246 | + [](Fusion* fusion, |
| 247 | + SchedulerType scheduler_type, |
| 248 | + const py::iterable& inputs) { |
| 249 | + auto args = from_pyiterable(inputs); |
| 250 | + return SchedulerEntry::scheduleWith( |
| 251 | + fusion, scheduler_type, args, /*validate_scheduler=*/true); |
| 252 | + }, |
| 253 | + py::arg("fusion"), |
| 254 | + py::arg("scheduler_type"), |
| 255 | + py::arg("inputs"), |
| 256 | + R"( |
| 257 | + Schedule the fusion with the specified scheduler type. |
| 258 | +
|
| 259 | + Parameters |
| 260 | + ---------- |
| 261 | + fusion : Fusion |
| 262 | + The fusion to schedule. |
| 263 | + scheduler_type : SchedulerType |
| 264 | + The type of scheduler to use. |
| 265 | + inputs : iterable |
| 266 | + The input tensors/values for the fusion. |
| 267 | +
|
| 268 | + Returns |
| 269 | + ------- |
| 270 | + HeuristicParams |
| 271 | + The heuristics for the scheduled fusion. |
| 272 | +
|
| 273 | + Notes |
| 274 | + ----- |
| 275 | + This function will raise an error if the scheduler cannot schedule the fusion. |
| 276 | + )"); |
156 | 277 | } |
157 | 278 |
|
158 | 279 | } // namespace |
|
0 commit comments