diff --git a/llvm/include/llvm/SYCLPostLink/SpecializationConstants.h b/llvm/include/llvm/SYCLPostLink/SpecializationConstants.h new file mode 100644 index 000000000000..34d3e3bd8b48 --- /dev/null +++ b/llvm/include/llvm/SYCLPostLink/SpecializationConstants.h @@ -0,0 +1,42 @@ +//= SpecializationConstants.h - Processing of SYCL Specialization Constants ==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// Specialization constants processing consists of lowering and generation +// of new module with spec consts replaced by default values. +//===----------------------------------------------------------------------===// + +#ifndef LLVM_SYCL_POST_LINK_SPECIALIZATION_CONSTANTS_H +#define LLVM_SYCL_POST_LINK_SPECIALIZATION_CONSTANTS_H + +#include "llvm/ADT/SmallVector.h" +#include "llvm/SYCLLowerIR/SpecConstants.h" +#include "llvm/SYCLPostLink/ModuleSplitter.h" + +#include + +namespace llvm { +namespace sycl { + +/// Handling consists of SpecConsts's lowering depending on the given +/// \p Mode. If \p Mode is std::nullopt, then no lowering happens. +/// If \p GenerateModuleDescWithDefaultSpecConsts is true, then a generation +/// of new modules with specialization constants replaced by default values +/// happens and the result is written in \p NewModuleDescs. +/// Otherwise, \p NewModuleDescs is expected to be nullptr. +/// +/// Returns boolean value indicating whether the lowering has changed the input +/// modules. +bool handleSpecializationConstants( + llvm::SmallVectorImpl &MDs, + std::optional Mode, + bool GenerateModuleDescWithDefaultSpecConsts = false, + llvm::SmallVectorImpl *NewModuleDescs = nullptr); + +} // namespace sycl +} // namespace llvm + +#endif // LLVM_SYCL_POST_LINK_SPECIALIZATION_CONSTANTS_H diff --git a/llvm/lib/SYCLPostLink/CMakeLists.txt b/llvm/lib/SYCLPostLink/CMakeLists.txt index ffe061acf1b6..d2b1b85c3561 100644 --- a/llvm/lib/SYCLPostLink/CMakeLists.txt +++ b/llvm/lib/SYCLPostLink/CMakeLists.txt @@ -2,6 +2,7 @@ add_llvm_component_library(LLVMSYCLPostLink ComputeModuleRuntimeInfo.cpp ESIMDPostSplitProcessing.cpp ModuleSplitter.cpp + SpecializationConstants.cpp ADDITIONAL_HEADER_DIRS ${LLVM_MAIN_INCLUDE_DIR}/llvm/SYCLPostLink diff --git a/llvm/lib/SYCLPostLink/SpecializationConstants.cpp b/llvm/lib/SYCLPostLink/SpecializationConstants.cpp new file mode 100644 index 000000000000..db2649694fc1 --- /dev/null +++ b/llvm/lib/SYCLPostLink/SpecializationConstants.cpp @@ -0,0 +1,97 @@ +//= SpecializationConstants.h - Processing of SYCL Specialization Constants ==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// See comments in the header. +//===----------------------------------------------------------------------===// + +#include "llvm/SYCLPostLink/SpecializationConstants.h" + +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/SYCLLowerIR/SpecConstants.h" +#include "llvm/SYCLPostLink/ModuleSplitter.h" +#include "llvm/Transforms/IPO/StripDeadPrototypes.h" + +#include + +using namespace llvm; + +namespace { + +bool lowerSpecConstants(module_split::ModuleDesc &MD, + SpecConstantsPass::HandlingMode Mode) { + ModulePassManager RunSpecConst; + ModuleAnalysisManager MAM; + SpecConstantsPass SCP(Mode); + // Register required analysis + MAM.registerPass([&] { return PassInstrumentationAnalysis(); }); + RunSpecConst.addPass(std::move(SCP)); + + // Perform the spec constant intrinsics transformation on resulting module + PreservedAnalyses Res = RunSpecConst.run(MD.getModule(), MAM); + MD.Props.SpecConstsMet = !Res.areAllPreserved(); + return MD.Props.SpecConstsMet; +} + +/// Function generates the copy of the given \p MD where all uses of +/// Specialization Constants are replaced by corresponding default values. +/// If the Module in \p MD doesn't contain specialization constants then +/// std::nullopt is returned. +std::optional +cloneModuleWithSpecConstsReplacedByDefaultValues( + const module_split::ModuleDesc &MD) { + std::optional NewMD; + if (!checkModuleContainsSpecConsts(MD.getModule())) + return NewMD; + + NewMD = MD.clone(); + NewMD->setSpecConstantDefault(true); + + ModulePassManager MPM; + ModuleAnalysisManager MAM; + SpecConstantsPass SCP(SpecConstantsPass::HandlingMode::default_values); + MAM.registerPass([&] { return PassInstrumentationAnalysis(); }); + MPM.addPass(std::move(SCP)); + MPM.addPass(StripDeadPrototypesPass()); + + PreservedAnalyses Res = MPM.run(NewMD->getModule(), MAM); + NewMD->Props.SpecConstsMet = !Res.areAllPreserved(); + assert(NewMD->Props.SpecConstsMet && + "This property should be true since the presence of SpecConsts " + "has been checked before the run of the pass"); + NewMD->rebuildEntryPoints(); + return NewMD; +} + +} // namespace + +bool llvm::sycl::handleSpecializationConstants( + SmallVectorImpl &MDs, + std::optional Mode, + bool GenerateModuleDescWithDefaultSpecConsts, + SmallVectorImpl *NewModuleDescs) { + assert((GenerateModuleDescWithDefaultSpecConsts ^ !NewModuleDescs) && + "NewModuleDescs pointer is nullptr iff " + "GenerateModuleDescWithDefaultSpecConsts is false."); + + bool Modified = false; + for (module_split::ModuleDesc &MD : MDs) { + if (GenerateModuleDescWithDefaultSpecConsts) + if (std::optional NewMD = + cloneModuleWithSpecConstsReplacedByDefaultValues(MD); + NewMD) + NewModuleDescs->push_back(std::move(*NewMD)); + + if (Mode) + Modified |= lowerSpecConstants(MD, *Mode); + } + + return Modified; +} diff --git a/llvm/tools/sycl-post-link/sycl-post-link.cpp b/llvm/tools/sycl-post-link/sycl-post-link.cpp index 5342cd7d5175..75ca2bb5b763 100644 --- a/llvm/tools/sycl-post-link/sycl-post-link.cpp +++ b/llvm/tools/sycl-post-link/sycl-post-link.cpp @@ -40,6 +40,7 @@ #include "llvm/SYCLPostLink/ComputeModuleRuntimeInfo.h" #include "llvm/SYCLPostLink/ESIMDPostSplitProcessing.h" #include "llvm/SYCLPostLink/ModuleSplitter.h" +#include "llvm/SYCLPostLink/SpecializationConstants.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/InitLLVM.h" @@ -48,7 +49,6 @@ #include "llvm/Support/SourceMgr.h" #include "llvm/Support/SystemUtils.h" #include "llvm/Support/WithColor.h" -#include "llvm/Transforms/IPO/StripDeadPrototypes.h" #include #include @@ -448,56 +448,6 @@ module_split::ModuleDesc link(module_split::ModuleDesc &&MD1, return Res; } -bool processSpecConstants(module_split::ModuleDesc &MD) { - MD.Props.SpecConstsMet = false; - - if (SpecConstLower.getNumOccurrences() == 0) - return false; - - ModulePassManager RunSpecConst; - ModuleAnalysisManager MAM; - SpecConstantsPass SCP(SpecConstLower == SC_NATIVE_MODE - ? SpecConstantsPass::HandlingMode::native - : SpecConstantsPass::HandlingMode::emulation); - // Register required analysis - MAM.registerPass([&] { return PassInstrumentationAnalysis(); }); - RunSpecConst.addPass(std::move(SCP)); - - // Perform the spec constant intrinsics transformation on resulting module - PreservedAnalyses Res = RunSpecConst.run(MD.getModule(), MAM); - MD.Props.SpecConstsMet = !Res.areAllPreserved(); - return MD.Props.SpecConstsMet; -} - -/// Function generates the copy of the given ModuleDesc where all uses of -/// Specialization Constants are replaced by corresponding default values. -/// If the Module in MD doesn't contain specialization constants then -/// std::nullopt is returned. -std::optional -processSpecConstantsWithDefaultValues(const module_split::ModuleDesc &MD) { - std::optional NewModuleDesc; - if (!checkModuleContainsSpecConsts(MD.getModule())) - return NewModuleDesc; - - NewModuleDesc = MD.clone(); - NewModuleDesc->setSpecConstantDefault(true); - - ModulePassManager MPM; - ModuleAnalysisManager MAM; - SpecConstantsPass SCP(SpecConstantsPass::HandlingMode::default_values); - MAM.registerPass([&] { return PassInstrumentationAnalysis(); }); - MPM.addPass(std::move(SCP)); - MPM.addPass(StripDeadPrototypesPass()); - - PreservedAnalyses Res = MPM.run(NewModuleDesc->getModule(), MAM); - NewModuleDesc->Props.SpecConstsMet = !Res.areAllPreserved(); - assert(NewModuleDesc->Props.SpecConstsMet && - "This property should be true since the presence of SpecConsts " - "has been checked before the run of the pass"); - NewModuleDesc->rebuildEntryPoints(); - return NewModuleDesc; -} - constexpr int MAX_COLUMNS_IN_FILE_TABLE = 3; void addTableRow(util::SimpleTable &Table, @@ -679,6 +629,12 @@ processInputModule(std::unique_ptr M) { error(toString(std::move(E))); } + std::optional SCMode; + if (SpecConstLower.getNumOccurrences() > 0) + SCMode = SpecConstLower == SC_NATIVE_MODE + ? SpecConstantsPass::HandlingMode::native + : SpecConstantsPass::HandlingMode::emulation; + // It is important that we *DO NOT* preserve all the splits in memory at the // same time, because it leads to a huge RAM consumption by the tool on bigger // inputs. @@ -693,16 +649,10 @@ processInputModule(std::unique_ptr M) { assert(MMs.size() && "at least one module is expected after ESIMD split"); SmallVector MMsWithDefaultSpecConsts; - for (size_t I = 0; I != MMs.size(); ++I) { - if (GenerateDeviceImageWithDefaultSpecConsts) { - std::optional NewMD = - processSpecConstantsWithDefaultValues(MMs[I]); - if (NewMD) - MMsWithDefaultSpecConsts.push_back(std::move(*NewMD)); - } - - Modified |= processSpecConstants(MMs[I]); - } + Modified |= handleSpecializationConstants( + MMs, SCMode, GenerateDeviceImageWithDefaultSpecConsts, + GenerateDeviceImageWithDefaultSpecConsts ? &MMsWithDefaultSpecConsts + : nullptr); if (IROutputOnly) { if (SplitOccurred) {