Skip to content

Commit 07c8755

Browse files
committed
Add support for automatic scheduler in direct bindings
* Define HeuristicParams struct * Update manual_execute to take launch_params, compile_params, and scheduler_type from HeuristicParams * Add can_schedule, find_compatible_schedulers, and schedule operations * Create tests for pointwise, reduction, and inner persistent schedulers
1 parent 922d6bd commit 07c8755

File tree

8 files changed

+327
-7
lines changed

8 files changed

+327
-7
lines changed

csrc/runtime/executor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class KernelExecutor : public ExecutorAbstract {
105105
const KernelArgumentHolder& args = {},
106106
const LaunchParams& launch_constraints = LaunchParams(),
107107
CompileParams compile_params = CompileParams(),
108-
SchedulerType sceduler_type = SchedulerType::None);
108+
SchedulerType scheduler_type = SchedulerType::None);
109109

110110
NVF_API KernelArgumentHolder
111111
run(KernelArgumentHolder args,

csrc/scheduler/heuristic.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class HeuristicDataCache;
2121
// Top-level class representing heuristic parameters. Most schedulers
2222
// have their own subclasses to have their specific parameters, except
2323
// for ExprEval schedulers.
24-
class HeuristicParams : public PolymorphicBase {
24+
class NVF_API HeuristicParams : public PolymorphicBase {
2525
public:
2626
std::string tag = "";
2727

csrc/scheduler/registry.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ namespace Schedule {
7979

8080
//! External access for canSchedule utilities through SchedulerEntry
8181
//! to avoid exposing a single function to the namespace
82-
bool canSchedule(
82+
NVF_API bool canSchedule(
8383
SchedulerType sh,
8484
Fusion* fusion,
8585
SchedulerRuntimeInfo& runtime_info,

csrc/scheduler/runtime_info.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class SchedulerRuntimeInfo : public NonCopyable {
4141
//! The index type of forced_index_type is used if given, no matter
4242
//! how large the actual arguments and fusion tensors
4343
//! are. CORRECTNESS IS NOT GUARANTEED.
44-
SchedulerRuntimeInfo(
44+
NVF_API SchedulerRuntimeInfo(
4545
Fusion* complete_fusion,
4646
KernelArgumentHolder args,
4747
PrecomputedValues* precomputed_values = nullptr,

python/nvfuser_direct/__init__.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,9 @@ def execute(
322322
_disable_options=_disable_options,
323323
)
324324

325-
def manual_execute(self, inputs):
325+
def manual_execute(
326+
self, inputs, heuristic_params: Optional[HeuristicParams] = None
327+
):
326328
"""
327329
Execute the fusion with the given inputs.
328330
@@ -344,8 +346,21 @@ def manual_execute(self, inputs):
344346
self.ke = KernelExecutor()
345347

346348
if not self.ke.is_compiled():
347-
self.ke.compile(self.fusion, inputs)
349+
if heuristic_params is not None:
350+
self.ke.compile(
351+
self.fusion,
352+
inputs,
353+
heuristic_params.lparams,
354+
heuristic_params.cparams,
355+
heuristic_params.scheduler_type,
356+
)
357+
else:
358+
self.ke.compile(self.fusion, inputs)
348359

360+
if heuristic_params is not None:
361+
return self.ke.run(
362+
inputs, heuristic_params.lparams, heuristic_params.cparams
363+
)
349364
return self.ke.run(inputs)
350365

351366
def last_repro_script(self) -> str:

python/python_direct/heuristic_params.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,39 @@ void bindHeuristicParams(py::module& nvfuser) {
152152
"include_paths", &CompileParams::include_paths, R"(
153153
The additional include paths to use for the kernel.
154154
)");
155+
156+
py::class_<HeuristicParams> heuristic_parameters(
157+
nvfuser, "HeuristicParams", py::module_local());
158+
heuristic_parameters.def(
159+
"__repr__", [](const HeuristicParams& self) { return self.toString(); });
160+
heuristic_parameters.def("__eq__", &HeuristicParams::sameAs, R"(
161+
Whether the heuristic parameters are the same.
162+
)");
163+
heuristic_parameters.def_readwrite("lparams", &HeuristicParams::lparams, R"(
164+
The launch parameters for the kernel.
165+
)");
166+
heuristic_parameters.def_readwrite("cparams", &HeuristicParams::cparams, R"(
167+
The compile parameters for the kernel.
168+
)");
169+
heuristic_parameters.def_readonly(
170+
"scheduler_type", &HeuristicParams::scheduler_type, R"(
171+
The type of scheduler that generated these parameters.
172+
)");
173+
heuristic_parameters.def("hash", &HeuristicParams::hash, R"(
174+
The hash of the heuristic parameters.
175+
)");
176+
177+
py::class_<PointwiseParams, HeuristicParams> pointwise(
178+
nvfuser, "PointwiseParams", py::module_local());
179+
pointwise.def(py::init());
180+
pointwise.def(
181+
"__repr__", [](const PointwiseParams& self) { return self.toString(); });
182+
183+
py::class_<ReductionParams, HeuristicParams> reduction(
184+
nvfuser, "ReductionParams", py::module_local());
185+
reduction.def(py::init());
186+
reduction.def(
187+
"__repr__", [](const ReductionParams& self) { return self.toString(); });
155188
}
156189

157190
} // namespace nvfuser::python

python/python_direct/schedule.cpp

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66
*/
77
// clang-format on
88
#include <bindings.h>
9+
#include <direct_utils.h>
10+
#include <options.h>
11+
#include <scheduler/registry.h>
12+
#include <scheduler/runtime_info.h>
13+
#include <scheduler/scheduler_types.h>
914
#include <scheduler/tools/inlining.h>
1015
#include <scheduler/utils.h>
1116
#include <transform_replay.h>
@@ -153,6 +158,122 @@ void bindTensorviewScheduleOps(py::module_& schedule) {
153158
None
154159
)",
155160
py::arg("selected_tensors") = std::vector<TensorView*>());
161+
162+
schedule.def(
163+
"can_schedule",
164+
[](Fusion* fusion,
165+
SchedulerType scheduler_type,
166+
const py::iterable& inputs) {
167+
// Enable collection of messages from canScheduleRejectReason
168+
DebugDumpOptionsGuard debug_dump_options_guard;
169+
DebugDumpOptionsGuard::getCurOptions().set(
170+
DebugDumpOption::FusionSegmenterLog);
171+
172+
// Send debug messages to stringstream
173+
std::stringstream ss;
174+
DebugStreamGuard dsg(ss);
175+
176+
// Create runtime info from inputs
177+
auto args = from_pyiterable(inputs);
178+
SchedulerRuntimeInfo runtime_info(fusion, args);
179+
180+
bool can_schedule =
181+
Schedule::canSchedule(scheduler_type, fusion, runtime_info);
182+
return std::make_tuple(can_schedule, ss.str());
183+
},
184+
py::arg("fusion"),
185+
py::arg("scheduler_type"),
186+
py::arg("inputs"),
187+
R"(
188+
Check if a scheduler can schedule the given fusion with the provided inputs.
189+
190+
Parameters
191+
----------
192+
fusion : Fusion
193+
The fusion to check.
194+
scheduler_type : SchedulerType
195+
The type of scheduler to check.
196+
inputs : iterable
197+
The input tensors/values for the fusion.
198+
199+
Returns
200+
-------
201+
tuple of (bool, str)
202+
A tuple containing:
203+
- bool: True if the scheduler can schedule the fusion, False otherwise.
204+
- str: Debug message explaining why the scheduler was accepted or rejected.
205+
)");
206+
207+
schedule.def(
208+
"find_compatible_schedulers",
209+
[](Fusion* fusion, const py::iterable& inputs) {
210+
// Create runtime info from inputs
211+
auto args = from_pyiterable(inputs);
212+
SchedulerRuntimeInfo runtime_info(fusion, args);
213+
214+
std::vector<SchedulerType> compatible_schedulers;
215+
216+
// Check all scheduler types except None
217+
for (const auto& scheduler_type : all_heuristics_in_priority_order) {
218+
if (scheduler_type != SchedulerType::None &&
219+
Schedule::canSchedule(scheduler_type, fusion, runtime_info)) {
220+
compatible_schedulers.push_back(scheduler_type);
221+
}
222+
}
223+
224+
return compatible_schedulers;
225+
},
226+
py::arg("fusion"),
227+
py::arg("inputs"),
228+
R"(
229+
Find all schedulers compatible with the given fusion and inputs.
230+
231+
Parameters
232+
----------
233+
fusion : Fusion
234+
The fusion to check.
235+
inputs : iterable
236+
The input tensors/values for the fusion.
237+
238+
Returns
239+
-------
240+
list of SchedulerType
241+
A list of scheduler types that can schedule the fusion.
242+
)");
243+
244+
schedule.def(
245+
"schedule",
246+
[](Fusion* fusion,
247+
SchedulerType scheduler_type,
248+
const py::iterable& inputs) {
249+
auto args = from_pyiterable(inputs);
250+
return SchedulerEntry::scheduleWith(
251+
fusion, scheduler_type, args, /*validate_scheduler=*/true);
252+
},
253+
py::arg("fusion"),
254+
py::arg("scheduler_type"),
255+
py::arg("inputs"),
256+
R"(
257+
Schedule the fusion with the specified scheduler type.
258+
259+
Parameters
260+
----------
261+
fusion : Fusion
262+
The fusion to schedule.
263+
scheduler_type : SchedulerType
264+
The type of scheduler to use.
265+
inputs : iterable
266+
The input tensors/values for the fusion.
267+
268+
Returns
269+
-------
270+
HeuristicParams
271+
The heuristics for the scheduled fusion.
272+
273+
Notes
274+
-----
275+
This function will raise an error if the scheduler cannot schedule the fusion.
276+
)");
156277
}
157278

158279
} // namespace

0 commit comments

Comments
 (0)