diff --git a/csrc/runtime/executor.h b/csrc/runtime/executor.h
index e8e33dd8160..d4491f80e4f 100644
--- a/csrc/runtime/executor.h
+++ b/csrc/runtime/executor.h
@@ -105,7 +105,7 @@ class KernelExecutor : public ExecutorAbstract {
       const KernelArgumentHolder& args = {},
       const LaunchParams& launch_constraints = LaunchParams(),
       CompileParams compile_params = CompileParams(),
-      SchedulerType sceduler_type = SchedulerType::None);
+      SchedulerType scheduler_type = SchedulerType::None);
 
   NVF_API KernelArgumentHolder run(
       KernelArgumentHolder args,
diff --git a/csrc/scheduler/heuristic.h b/csrc/scheduler/heuristic.h
index 8eb44041fc4..12b360bed9e 100644
--- a/csrc/scheduler/heuristic.h
+++ b/csrc/scheduler/heuristic.h
@@ -21,7 +21,7 @@ class HeuristicDataCache;
 // Top-level class representing heuristic parameters. Most schedulers
 // have their own subclasses to have their specific parameters, except
 // for ExprEval schedulers.
-class HeuristicParams : public PolymorphicBase {
+class NVF_API HeuristicParams : public PolymorphicBase {
  public:
   std::string tag = "";
 
diff --git a/csrc/scheduler/registry.h b/csrc/scheduler/registry.h
index 7ceb628247e..c5765b95f3f 100644
--- a/csrc/scheduler/registry.h
+++ b/csrc/scheduler/registry.h
@@ -96,7 +96,7 @@ namespace Schedule {
 
 //! External access for canSchedule utilities through SchedulerEntry
 //! to avoid exposing a single function to the namespace
-bool canSchedule(
+NVF_API bool canSchedule(
     SchedulerType sh,
     Fusion* fusion,
     SchedulerRuntimeInfo& runtime_info,
diff --git a/csrc/scheduler/runtime_info.h b/csrc/scheduler/runtime_info.h
index 91aa4618727..28cde3816a0 100644
--- a/csrc/scheduler/runtime_info.h
+++ b/csrc/scheduler/runtime_info.h
@@ -41,7 +41,7 @@ class SchedulerRuntimeInfo : public NonCopyable {
   //! The index type of forced_index_type is used if given, no matter
   //! how large the actual arguments and fusion tensors
   //! are. CORRECTNESS IS NOT GUARANTEED.
-  SchedulerRuntimeInfo(
+  NVF_API SchedulerRuntimeInfo(
       Fusion* complete_fusion,
       KernelArgumentHolder args,
       PrecomputedValues* precomputed_values = nullptr,
diff --git a/python/nvfuser_direct/__init__.py b/python/nvfuser_direct/__init__.py
index 9f6c3b2222a..f23804fc84c 100644
--- a/python/nvfuser_direct/__init__.py
+++ b/python/nvfuser_direct/__init__.py
@@ -322,7 +322,9 @@ def execute(
             _disable_options=_disable_options,
         )
 
-    def manual_execute(self, inputs):
+    def manual_execute(
+        self, inputs, heuristic_params: Optional[HeuristicParams] = None
+    ):
         """
         Execute the fusion with the given inputs.
 
@@ -344,8 +346,21 @@ def manual_execute(self, inputs):
             self.ke = KernelExecutor()
 
         if not self.ke.is_compiled():
-            self.ke.compile(self.fusion, inputs)
+            if heuristic_params is not None:
+                self.ke.compile(
+                    self.fusion,
+                    inputs,
+                    heuristic_params.lparams,
+                    heuristic_params.cparams,
+                    heuristic_params.scheduler_type,
+                )
+            else:
+                self.ke.compile(self.fusion, inputs)
+        if heuristic_params is not None:
+            return self.ke.run(
+                inputs, heuristic_params.lparams, heuristic_params.cparams
+            )
         return self.ke.run(inputs)
 
     def last_repro_script(self) -> str:
diff --git a/python/python_direct/heuristic_params.cpp b/python/python_direct/heuristic_params.cpp
index 11e1d8b0115..4de432384f2 100644
--- a/python/python_direct/heuristic_params.cpp
+++ b/python/python_direct/heuristic_params.cpp
@@ -152,6 +152,39 @@ void bindHeuristicParams(py::module& nvfuser) {
       "include_paths",
       &CompileParams::include_paths,
       R"(
 The additional include paths to use for the kernel.
 )");
+
+  py::class_<HeuristicParams> heuristic_parameters(
+      nvfuser, "HeuristicParams", py::module_local());
+  heuristic_parameters.def(
+      "__repr__", [](const HeuristicParams& self) { return self.toString(); });
+  heuristic_parameters.def("__eq__", &HeuristicParams::sameAs, R"(
+Whether the heuristic parameters are the same.
+)");
+  heuristic_parameters.def_readwrite("lparams", &HeuristicParams::lparams, R"(
+The launch parameters for the kernel.
+)");
+  heuristic_parameters.def_readwrite("cparams", &HeuristicParams::cparams, R"(
+The compile parameters for the kernel.
+)");
+  heuristic_parameters.def_readonly(
+      "scheduler_type", &HeuristicParams::scheduler_type, R"(
+The type of scheduler that generated these parameters.
+)");
+  heuristic_parameters.def("hash", &HeuristicParams::hash, R"(
+The hash of the heuristic parameters.
+)");
+
+  py::class_<PointwiseParams, HeuristicParams> pointwise(
+      nvfuser, "PointwiseParams", py::module_local());
+  pointwise.def(py::init<>());
+  pointwise.def(
+      "__repr__", [](const PointwiseParams& self) { return self.toString(); });
+
+  py::class_<ReductionParams, HeuristicParams> reduction(
+      nvfuser, "ReductionParams", py::module_local());
+  reduction.def(py::init<>());
+  reduction.def(
+      "__repr__", [](const ReductionParams& self) { return self.toString(); });
 }
 
 } // namespace nvfuser::python
diff --git a/python/python_direct/schedule.cpp b/python/python_direct/schedule.cpp
index b1ea10be2d4..3f6bf6fe702 100644
--- a/python/python_direct/schedule.cpp
+++ b/python/python_direct/schedule.cpp
@@ -6,6 +6,11 @@
  */
 // clang-format on
 #include
+#include
+#include
+#include
+#include
+#include
 #include
 #include
 #include
@@ -153,6 +158,122 @@ void bindTensorviewScheduleOps(py::module_& schedule) {
     None
 )",
     py::arg("selected_tensors") = std::vector<TensorView*>());
+
+  schedule.def(
+      "can_schedule",
+      [](Fusion* fusion,
+         SchedulerType scheduler_type,
+         const py::iterable& inputs) {
+        // Enable collection of messages from canScheduleRejectReason
+        DebugDumpOptionsGuard debug_dump_options_guard;
+        DebugDumpOptionsGuard::getCurOptions().set(
+            DebugDumpOption::FusionSegmenterLog);
+
+        // Send debug messages to stringstream
+        std::stringstream ss;
+        DebugStreamGuard dsg(ss);
+
+        // Create runtime info from inputs
+        auto args = from_pyiterable(inputs);
+        SchedulerRuntimeInfo runtime_info(fusion, args);
+
+        bool can_schedule =
+            Schedule::canSchedule(scheduler_type, fusion, runtime_info);
+        return std::make_tuple(can_schedule, ss.str());
+      },
+      py::arg("fusion"),
+      py::arg("scheduler_type"),
+      py::arg("inputs"),
+      R"(
+Check if a scheduler can schedule the given fusion with the provided inputs.
+
+Parameters
+----------
+fusion : Fusion
+    The fusion to check.
+scheduler_type : SchedulerType
+    The type of scheduler to check.
+inputs : iterable
+    The input tensors/values for the fusion.
+
+Returns
+-------
+tuple of (bool, str)
+    A tuple containing:
+    - bool: True if the scheduler can schedule the fusion, False otherwise.
+    - str: Debug message explaining why the scheduler was accepted or rejected.
+)");
+
+  schedule.def(
+      "find_compatible_schedulers",
+      [](Fusion* fusion, const py::iterable& inputs) {
+        // Create runtime info from inputs
+        auto args = from_pyiterable(inputs);
+        SchedulerRuntimeInfo runtime_info(fusion, args);
+
+        std::vector<SchedulerType> compatible_schedulers;
+
+        // Check all scheduler types except None
+        for (const auto& scheduler_type : all_heuristics_in_priority_order) {
+          if (scheduler_type != SchedulerType::None &&
+              Schedule::canSchedule(scheduler_type, fusion, runtime_info)) {
+            compatible_schedulers.push_back(scheduler_type);
+          }
+        }
+
+        return compatible_schedulers;
+      },
+      py::arg("fusion"),
+      py::arg("inputs"),
+      R"(
+Find all schedulers compatible with the given fusion and inputs.
+
+Parameters
+----------
+fusion : Fusion
+    The fusion to check.
+inputs : iterable
+    The input tensors/values for the fusion.
+
+Returns
+-------
+list of SchedulerType
+    A list of scheduler types that can schedule the fusion.
+)");
+
+  schedule.def(
+      "schedule",
+      [](Fusion* fusion,
+         SchedulerType scheduler_type,
+         const py::iterable& inputs) {
+        auto args = from_pyiterable(inputs);
+        return SchedulerEntry::scheduleWith(
+            fusion, scheduler_type, args, /*validate_scheduler=*/true);
+      },
+      py::arg("fusion"),
+      py::arg("scheduler_type"),
+      py::arg("inputs"),
+      R"(
+Schedule the fusion with the specified scheduler type.
+
+Parameters
+----------
+fusion : Fusion
+    The fusion to schedule.
+scheduler_type : SchedulerType
+    The type of scheduler to use.
+inputs : iterable
+    The input tensors/values for the fusion.
+
+Returns
+-------
+HeuristicParams
+    The heuristics for the scheduled fusion.
+
+Notes
+-----
+This function will raise an error if the scheduler cannot schedule the fusion.
+)");
 }
 
 } // namespace
diff --git a/tests/python/direct/test_tutorial.py b/tests/python/direct/test_tutorial.py
index b6b49ea12fc..bc7e7d62f31 100644
--- a/tests/python/direct/test_tutorial.py
+++ b/tests/python/direct/test_tutorial.py
@@ -20,8 +20,9 @@
     DataType,
     CompileParams,
     KernelExecutor,
+    SchedulerType,
 )
-from nvfuser_direct import idm
+from nvfuser_direct import idm, schedule
 
 from python.direct_utils import (
     is_pre_hopper,
@@ -30,6 +31,56 @@
 verbose_ = True
 
 
+# A helper function to test heuristic schedulers with automatic scheduling
+def check_auto_schedule(schedule_fn):
+    """
+    A decorator to validate a schedule_fn before applying it to a fusion.
+
+    Args:
+        schedule_fn: The function to apply the scheduler
+    """
+    # List of all scheduler heuristics for testing
+    # NOTE We cannot iterate pybind11 enum directly, so we extract the entries here.
+    all_scheduler_heuristics = [
+        heuristic
+        for heuristic, _ in SchedulerType.__entries.values()
+        if heuristic != SchedulerType.none
+    ]
+
+    def inner_fn(fusion, selected_heuristic, inputs):
+        """
+        Helper function to validate a schedule_fn.
+
+        Args:
+            fusion: The Fusion object to schedule
+            selected_heuristic: The SchedulerType expected to work
+            inputs: Input tensors for the fusion
+        """
+        available_heuristics = schedule.find_compatible_schedulers(fusion, inputs)
+
+        # Assume that only a single heuristic is available for fusion
+        assert len(available_heuristics) == 1
+
+        # Check that only selected heuristic is available as a scheduler
+        assert set(available_heuristics) == set([selected_heuristic])
+
+        # Double-check with can_schedule
+        status, _ = schedule.can_schedule(fusion, selected_heuristic, inputs)
+        assert status
+
+        # Check that the other schedulers are not compatible with this fusion
+        assert all(
+            [
+                not schedule.can_schedule(fusion, h, inputs)[0]
+                for h in all_scheduler_heuristics
+                if h is not selected_heuristic
+            ]
+        )
+        return schedule_fn(fusion, selected_heuristic, inputs)
+
+    return inner_fn
+
+
 def test_tutorial_memcpy():
     # First, we define a fusion. A common pattern is:
     # - Declare a Fusion, which works as a container of expressions using
@@ -1434,3 +1485,103 @@ def test_tutorial_tma_bank_conflict_free_transpose(nvfuser_direct_test):
     ke.compile(fd.fusion, [t0], compile_params=index32bit)
     outputs = ke.run([t0])
     assert outputs[0].equal(t0.t())
+
+
+def test_tutorial_pointwise_auto_scheduler():
+    """
+    Implement a simple pointwise kernel with automatic scheduling.
+    Uses nvfuser's PointwiseScheduler.
+    """
+    inputs = [
+        torch.randn(4, 4, device="cuda"),
+        torch.randn(4, 4, device="cuda"),
+    ]
+
+    with FusionDefinition() as fd:
+        t0 = fd.from_pytorch(inputs[0])
+        t1 = fd.from_pytorch(inputs[1])
+        t2 = fd.ops.add(t0, t1)
+        t3 = fd.ops.exp(t2)
+        fd.add_output(t3)
+
+    # Apply selected scheduler
+    heuristic_params = check_auto_schedule(schedule.schedule)(
+        fd.fusion, SchedulerType.pointwise, inputs
+    )
+
+    nvf_out = fd.manual_execute(inputs, heuristic_params)
+    eager_out = torch.exp(inputs[0] + inputs[1])
+    torch.testing.assert_close(eager_out, nvf_out[0])
+
+
+def test_tutorial_reduction_auto_scheduler():
+    """
+    Implement a simple reduction kernel with automatic scheduling.
+    - Expects failure with PointwiseScheduler
+    - Uses nvfuser's ReductionScheduler
+    """
+    inputs = [
+        torch.randn(4, 4, device="cuda"),
+    ]
+
+    with FusionDefinition() as fd:
+        t0 = fd.from_pytorch(inputs[0])
+        t1 = fd.ops.sum(t0, dims=[1])
+        t2 = fd.ops.exp(t1)
+        fd.add_output(t2)
+
+    # Test error msg for can_schedule
+    pointwise_status, error_msg = schedule.can_schedule(
+        fd.fusion, SchedulerType.pointwise, inputs
+    )
+    assert not pointwise_status
+    assert (
+        error_msg.strip()
+        == "Scheduler _pointwise_ ***rejected*** because : cannot find reference tensor"
+    )
+
+    # Apply selected scheduler
+    heuristic_params = check_auto_schedule(schedule.schedule)(
+        fd.fusion, SchedulerType.reduction, inputs
+    )
+
+    nvf_out = fd.manual_execute(inputs, heuristic_params)
+    eager_out = torch.exp(inputs[0].sum(1))
+    torch.testing.assert_close(eager_out, nvf_out[0])
+
+
+def test_tutorial_inner_persistent_auto_scheduler():
+    """
+    Implement a simple normalization kernel with automatic scheduling.
+    Uses nvfuser's InnerPersistentScheduler.
+    """
+    tensor_size = 4
+    inputs = [torch.randn(tensor_size, tensor_size, device="cuda")]
+
+    with FusionDefinition() as fd:
+        t0 = fd.from_pytorch(inputs[0])
+        s0 = fd.define_scalar(1e-6, dtype=DataType.Double)
+        norm_const = fd.define_scalar(tensor_size, dtype=DataType.Int)
+
+        bcast_sum0 = fd.ops.sum(t0, dims=[-1], keepdim=True)
+        mean = fd.ops.div(bcast_sum0, norm_const)
+
+        diff = fd.ops.sub(t0, mean)
+        diff_sq = fd.ops.mul(diff, diff)
+        bcast_sum1 = fd.ops.sum(diff_sq, dims=[-1], keepdim=True)
+        var = fd.ops.div(bcast_sum1, norm_const)
+
+        t0_diff = fd.ops.sub(t0, mean)
+        var_eps = fd.ops.sqrt(fd.ops.add(var, s0))
+        t0_norm = fd.ops.div(t0_diff, var_eps)
+        fd.add_output(t0_norm)
+
+    # Apply selected scheduler
+    heuristic_params = check_auto_schedule(schedule.schedule)(
+        fd.fusion, SchedulerType.inner_persistent, inputs
+    )
+
+    nvf_out = fd.manual_execute(inputs, heuristic_params)
+    var, mean = torch.var_mean(inputs[0], dim=-1, correction=0, keepdim=True)
+    eager_out = (inputs[0] - mean) / torch.sqrt(var + 1e-6)
+    torch.testing.assert_close(eager_out, nvf_out[0])
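Taken together, these changes expose nvFuser's scheduler heuristics to Python: schedule.can_schedule reports whether a scheduler accepts a fusion (and why not), schedule.find_compatible_schedulers enumerates all accepting schedulers in priority order, schedule.schedule schedules the fusion and returns its HeuristicParams, and FusionDefinition.manual_execute reuses those parameters when compiling and running the kernel. A minimal sketch of the intended workflow, mirroring the pointwise tutorial test above and assuming the same imports as the tutorial tests and an available CUDA device:

    import torch
    from nvfuser_direct import FusionDefinition, SchedulerType, schedule

    inputs = [
        torch.randn(4, 4, device="cuda"),
        torch.randn(4, 4, device="cuda"),
    ]

    with FusionDefinition() as fd:
        t0 = fd.from_pytorch(inputs[0])
        t1 = fd.from_pytorch(inputs[1])
        fd.add_output(fd.ops.exp(fd.ops.add(t0, t1)))

    # Ask a specific scheduler whether it accepts this fusion; the second
    # element carries the acceptance/rejection message.
    ok, msg = schedule.can_schedule(fd.fusion, SchedulerType.pointwise, inputs)
    assert ok, msg

    # Or enumerate every compatible scheduler in priority order.
    print(schedule.find_compatible_schedulers(fd.fusion, inputs))

    # Schedule the fusion and inspect the resulting heuristic parameters.
    params = schedule.schedule(fd.fusion, SchedulerType.pointwise, inputs)
    print(params)                 # __repr__ maps to HeuristicParams::toString()
    print(params.scheduler_type)  # read-only scheduler type
    print(params.hash())

    # Reuse the launch/compile parameters when compiling and running the kernel.
    outputs = fd.manual_execute(inputs, params)
    torch.testing.assert_close(outputs[0], torch.exp(inputs[0] + inputs[1]))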