Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csrc/runtime/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class KernelExecutor : public ExecutorAbstract {
const KernelArgumentHolder& args = {},
const LaunchParams& launch_constraints = LaunchParams(),
CompileParams compile_params = CompileParams(),
SchedulerType sceduler_type = SchedulerType::None);
SchedulerType scheduler_type = SchedulerType::None);

NVF_API KernelArgumentHolder
run(KernelArgumentHolder args,
Expand Down
2 changes: 1 addition & 1 deletion csrc/scheduler/heuristic.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class HeuristicDataCache;
// Top-level class representing heuristic parameters. Most schedulers
// have their own subclasses to have their specific parameters, except
// for ExprEval schedulers.
class HeuristicParams : public PolymorphicBase {
class NVF_API HeuristicParams : public PolymorphicBase {
public:
std::string tag = "";

Expand Down
2 changes: 1 addition & 1 deletion csrc/scheduler/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ namespace Schedule {

//! External access for canSchedule utilities through SchedulerEntry
//! to avoid exposing a single function to the namespace
bool canSchedule(
NVF_API bool canSchedule(
SchedulerType sh,
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
Expand Down
2 changes: 1 addition & 1 deletion csrc/scheduler/runtime_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class SchedulerRuntimeInfo : public NonCopyable {
//! The index type of forced_index_type is used if given, no matter
//! how large the actual arguments and fusion tensors
//! are. CORRECTNESS IS NOT GUARANTEED.
SchedulerRuntimeInfo(
NVF_API SchedulerRuntimeInfo(
Fusion* complete_fusion,
KernelArgumentHolder args,
PrecomputedValues* precomputed_values = nullptr,
Expand Down
19 changes: 17 additions & 2 deletions python/nvfuser_direct/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,9 @@ def execute(
_disable_options=_disable_options,
)

def manual_execute(self, inputs):
def manual_execute(
self, inputs, heuristic_params: Optional[HeuristicParams] = None
):
"""
Execute the fusion with the given inputs.

Expand All @@ -344,8 +346,21 @@ def manual_execute(self, inputs):
self.ke = KernelExecutor()

if not self.ke.is_compiled():
self.ke.compile(self.fusion, inputs)
if heuristic_params is not None:
self.ke.compile(
self.fusion,
inputs,
heuristic_params.lparams,
heuristic_params.cparams,
heuristic_params.scheduler_type,
)
else:
self.ke.compile(self.fusion, inputs)

if heuristic_params is not None:
return self.ke.run(
inputs, heuristic_params.lparams, heuristic_params.cparams
)
return self.ke.run(inputs)

def last_repro_script(self) -> str:
Expand Down
33 changes: 33 additions & 0 deletions python/python_direct/heuristic_params.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,39 @@ void bindHeuristicParams(py::module& nvfuser) {
"include_paths", &CompileParams::include_paths, R"(
The additional include paths to use for the kernel.
)");

py::class_<HeuristicParams> heuristic_parameters(
nvfuser, "HeuristicParams", py::module_local());
heuristic_parameters.def(
"__repr__", [](const HeuristicParams& self) { return self.toString(); });
heuristic_parameters.def("__eq__", &HeuristicParams::sameAs, R"(
Whether the heuristic parameters are the same.
)");
heuristic_parameters.def_readwrite("lparams", &HeuristicParams::lparams, R"(
The launch parameters for the kernel.
)");
heuristic_parameters.def_readwrite("cparams", &HeuristicParams::cparams, R"(
The compile parameters for the kernel.
)");
heuristic_parameters.def_readonly(
"scheduler_type", &HeuristicParams::scheduler_type, R"(
The type of scheduler that generated these parameters.
)");
heuristic_parameters.def("hash", &HeuristicParams::hash, R"(
The hash of the heuristic parameters.
)");

py::class_<PointwiseParams, HeuristicParams> pointwise(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will added fields to PointwiseParams and ReductionParams later but the class definition was necessary to pass tests.

nvfuser, "PointwiseParams", py::module_local());
pointwise.def(py::init());
pointwise.def(
"__repr__", [](const PointwiseParams& self) { return self.toString(); });

py::class_<ReductionParams, HeuristicParams> reduction(
nvfuser, "ReductionParams", py::module_local());
reduction.def(py::init());
reduction.def(
"__repr__", [](const ReductionParams& self) { return self.toString(); });
}

} // namespace nvfuser::python
121 changes: 121 additions & 0 deletions python/python_direct/schedule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
*/
// clang-format on
#include <bindings.h>
#include <direct_utils.h>
#include <options.h>
#include <scheduler/registry.h>
#include <scheduler/runtime_info.h>
#include <scheduler/scheduler_types.h>
#include <scheduler/tools/inlining.h>
#include <scheduler/utils.h>
#include <transform_replay.h>
Expand Down Expand Up @@ -153,6 +158,122 @@ void bindTensorviewScheduleOps(py::module_& schedule) {
None
)",
py::arg("selected_tensors") = std::vector<TensorView*>());

schedule.def(
"can_schedule",
[](Fusion* fusion,
SchedulerType scheduler_type,
const py::iterable& inputs) {
// Enable collection of messages from canScheduleRejectReason
DebugDumpOptionsGuard debug_dump_options_guard;
DebugDumpOptionsGuard::getCurOptions().set(
DebugDumpOption::FusionSegmenterLog);

// Send debug messages to stringstream
std::stringstream ss;
DebugStreamGuard dsg(ss);

// Create runtime info from inputs
auto args = from_pyiterable(inputs);
SchedulerRuntimeInfo runtime_info(fusion, args);

bool can_schedule =
Schedule::canSchedule(scheduler_type, fusion, runtime_info);
return std::make_tuple(can_schedule, ss.str());
},
py::arg("fusion"),
py::arg("scheduler_type"),
py::arg("inputs"),
R"(
Check if a scheduler can schedule the given fusion with the provided inputs.

Parameters
----------
fusion : Fusion
The fusion to check.
scheduler_type : SchedulerType
The type of scheduler to check.
inputs : iterable
The input tensors/values for the fusion.

Returns
-------
tuple of (bool, str)
A tuple containing:
- bool: True if the scheduler can schedule the fusion, False otherwise.
- str: Debug message explaining why the scheduler was accepted or rejected.
)");

schedule.def(
"find_compatible_schedulers",
[](Fusion* fusion, const py::iterable& inputs) {
// Create runtime info from inputs
auto args = from_pyiterable(inputs);
SchedulerRuntimeInfo runtime_info(fusion, args);

std::vector<SchedulerType> compatible_schedulers;

// Check all scheduler types except None
for (const auto& scheduler_type : all_heuristics_in_priority_order) {
if (scheduler_type != SchedulerType::None &&
Schedule::canSchedule(scheduler_type, fusion, runtime_info)) {
compatible_schedulers.push_back(scheduler_type);
}
}

return compatible_schedulers;
},
py::arg("fusion"),
py::arg("inputs"),
R"(
Find all schedulers compatible with the given fusion and inputs.

Parameters
----------
fusion : Fusion
The fusion to check.
inputs : iterable
The input tensors/values for the fusion.

Returns
-------
list of SchedulerType
A list of scheduler types that can schedule the fusion.
)");

schedule.def(
"schedule",
[](Fusion* fusion,
SchedulerType scheduler_type,
const py::iterable& inputs) {
auto args = from_pyiterable(inputs);
return SchedulerEntry::scheduleWith(
fusion, scheduler_type, args, /*validate_scheduler=*/true);
},
py::arg("fusion"),
py::arg("scheduler_type"),
py::arg("inputs"),
R"(
Schedule the fusion with the specified scheduler type.

Parameters
----------
fusion : Fusion
The fusion to schedule.
scheduler_type : SchedulerType
The type of scheduler to use.
inputs : iterable
The input tensors/values for the fusion.

Returns
-------
HeuristicParams
The heuristics for the scheduled fusion.

Notes
-----
This function will raise an error if the scheduler cannot schedule the fusion.
)");
}

} // namespace
Expand Down
Loading
Loading