Skip to content

[SYCL][NFC] Extract specialization constant's processing from sycl-post-link to SYCLPostLink library. #19022

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: sycl
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions llvm/include/llvm/SYCLPostLink/SpecializationConstants.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//= SpecializationConstants.h - Processing of SYCL Specialization Constants ==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Specialization constants processing consists of lowering and generation
// of new module with spec consts replaced by default values.
//===----------------------------------------------------------------------===//

#ifndef LLVM_SYCL_POST_LINK_SPECIALIZATION_CONSTANTS_H
#define LLVM_SYCL_POST_LINK_SPECIALIZATION_CONSTANTS_H

#include "llvm/ADT/SmallVector.h"
#include "llvm/SYCLLowerIR/SpecConstants.h"
#include "llvm/SYCLPostLink/ModuleSplitter.h"

#include <optional>

namespace llvm {
namespace sycl {

/// Handling consists of SpecConsts's lowering depending on the given
/// \p Mode. If \p Mode is std::nullopt, then no lowering happens.
/// If \p GenerateModuleDescWithDefaultSpecConsts is true, then a generation
/// of new modules with specialization constants replaced by default values
/// happens and the result is written in \p NewModuleDescs.
/// Otherwise, \p NewModuleDescs is expected to be nullptr.
///
/// Returns boolean value indicating whether the lowering has changed the input
/// modules.
bool handleSpecializationConstants(
llvm::SmallVectorImpl<module_split::ModuleDesc> &MDs,
std::optional<SpecConstantsPass::HandlingMode> Mode,
bool GenerateModuleDescWithDefaultSpecConsts = false,
llvm::SmallVectorImpl<module_split::ModuleDesc> *NewModuleDescs = nullptr);

} // namespace sycl
} // namespace llvm

#endif // LLVM_SYCL_POST_LINK_SPECIALIZATION_CONSTANTS_H
1 change: 1 addition & 0 deletions llvm/lib/SYCLPostLink/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_llvm_component_library(LLVMSYCLPostLink
ComputeModuleRuntimeInfo.cpp
ESIMDPostSplitProcessing.cpp
ModuleSplitter.cpp
SpecializationConstants.cpp

ADDITIONAL_HEADER_DIRS
${LLVM_MAIN_INCLUDE_DIR}/llvm/SYCLPostLink
Expand Down
97 changes: 97 additions & 0 deletions llvm/lib/SYCLPostLink/SpecializationConstants.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
//= SpecializationConstants.h - Processing of SYCL Specialization Constants ==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// See comments in the header.
//===----------------------------------------------------------------------===//

#include "llvm/SYCLPostLink/SpecializationConstants.h"

#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/SYCLLowerIR/SpecConstants.h"
#include "llvm/SYCLPostLink/ModuleSplitter.h"
#include "llvm/Transforms/IPO/StripDeadPrototypes.h"

#include <optional>

using namespace llvm;

namespace {

bool lowerSpecConstants(module_split::ModuleDesc &MD,
SpecConstantsPass::HandlingMode Mode) {
ModulePassManager RunSpecConst;
ModuleAnalysisManager MAM;
SpecConstantsPass SCP(Mode);
// Register required analysis
MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
RunSpecConst.addPass(std::move(SCP));

// Perform the spec constant intrinsics transformation on resulting module
PreservedAnalyses Res = RunSpecConst.run(MD.getModule(), MAM);
MD.Props.SpecConstsMet = !Res.areAllPreserved();
return MD.Props.SpecConstsMet;
}

/// Function generates the copy of the given \p MD where all uses of
/// Specialization Constants are replaced by corresponding default values.
/// If the Module in \p MD doesn't contain specialization constants then
/// std::nullopt is returned.
std::optional<module_split::ModuleDesc>
cloneModuleWithSpecConstsReplacedByDefaultValues(
const module_split::ModuleDesc &MD) {
std::optional<module_split::ModuleDesc> NewMD;
if (!checkModuleContainsSpecConsts(MD.getModule()))
return NewMD;

NewMD = MD.clone();
NewMD->setSpecConstantDefault(true);

ModulePassManager MPM;
ModuleAnalysisManager MAM;
SpecConstantsPass SCP(SpecConstantsPass::HandlingMode::default_values);
MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
MPM.addPass(std::move(SCP));
MPM.addPass(StripDeadPrototypesPass());

PreservedAnalyses Res = MPM.run(NewMD->getModule(), MAM);
NewMD->Props.SpecConstsMet = !Res.areAllPreserved();
assert(NewMD->Props.SpecConstsMet &&
"This property should be true since the presence of SpecConsts "
"has been checked before the run of the pass");
NewMD->rebuildEntryPoints();
return NewMD;
}

} // namespace

bool llvm::sycl::handleSpecializationConstants(
SmallVectorImpl<module_split::ModuleDesc> &MDs,
std::optional<SpecConstantsPass::HandlingMode> Mode,
bool GenerateModuleDescWithDefaultSpecConsts,
SmallVectorImpl<module_split::ModuleDesc> *NewModuleDescs) {
assert((GenerateModuleDescWithDefaultSpecConsts ^ !NewModuleDescs) &&
"NewModuleDescs pointer is nullptr iff "
"GenerateModuleDescWithDefaultSpecConsts is false.");

bool Modified = false;
for (module_split::ModuleDesc &MD : MDs) {
if (GenerateModuleDescWithDefaultSpecConsts)
if (std::optional<module_split::ModuleDesc> NewMD =
cloneModuleWithSpecConstsReplacedByDefaultValues(MD);
NewMD)
NewModuleDescs->push_back(std::move(*NewMD));

if (Mode)
Modified |= lowerSpecConstants(MD, *Mode);
}

return Modified;
}
72 changes: 11 additions & 61 deletions llvm/tools/sycl-post-link/sycl-post-link.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include "llvm/SYCLPostLink/ComputeModuleRuntimeInfo.h"
#include "llvm/SYCLPostLink/ESIMDPostSplitProcessing.h"
#include "llvm/SYCLPostLink/ModuleSplitter.h"
#include "llvm/SYCLPostLink/SpecializationConstants.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/InitLLVM.h"
Expand All @@ -48,7 +49,6 @@
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/SystemUtils.h"
#include "llvm/Support/WithColor.h"
#include "llvm/Transforms/IPO/StripDeadPrototypes.h"

#include <algorithm>
#include <memory>
Expand Down Expand Up @@ -448,56 +448,6 @@ module_split::ModuleDesc link(module_split::ModuleDesc &&MD1,
return Res;
}

bool processSpecConstants(module_split::ModuleDesc &MD) {
MD.Props.SpecConstsMet = false;

if (SpecConstLower.getNumOccurrences() == 0)
return false;

ModulePassManager RunSpecConst;
ModuleAnalysisManager MAM;
SpecConstantsPass SCP(SpecConstLower == SC_NATIVE_MODE
? SpecConstantsPass::HandlingMode::native
: SpecConstantsPass::HandlingMode::emulation);
// Register required analysis
MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
RunSpecConst.addPass(std::move(SCP));

// Perform the spec constant intrinsics transformation on resulting module
PreservedAnalyses Res = RunSpecConst.run(MD.getModule(), MAM);
MD.Props.SpecConstsMet = !Res.areAllPreserved();
return MD.Props.SpecConstsMet;
}

/// Function generates the copy of the given ModuleDesc where all uses of
/// Specialization Constants are replaced by corresponding default values.
/// If the Module in MD doesn't contain specialization constants then
/// std::nullopt is returned.
std::optional<module_split::ModuleDesc>
processSpecConstantsWithDefaultValues(const module_split::ModuleDesc &MD) {
std::optional<module_split::ModuleDesc> NewModuleDesc;
if (!checkModuleContainsSpecConsts(MD.getModule()))
return NewModuleDesc;

NewModuleDesc = MD.clone();
NewModuleDesc->setSpecConstantDefault(true);

ModulePassManager MPM;
ModuleAnalysisManager MAM;
SpecConstantsPass SCP(SpecConstantsPass::HandlingMode::default_values);
MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
MPM.addPass(std::move(SCP));
MPM.addPass(StripDeadPrototypesPass());

PreservedAnalyses Res = MPM.run(NewModuleDesc->getModule(), MAM);
NewModuleDesc->Props.SpecConstsMet = !Res.areAllPreserved();
assert(NewModuleDesc->Props.SpecConstsMet &&
"This property should be true since the presence of SpecConsts "
"has been checked before the run of the pass");
NewModuleDesc->rebuildEntryPoints();
return NewModuleDesc;
}

constexpr int MAX_COLUMNS_IN_FILE_TABLE = 3;

void addTableRow(util::SimpleTable &Table,
Expand Down Expand Up @@ -679,6 +629,12 @@ processInputModule(std::unique_ptr<Module> M) {
error(toString(std::move(E)));
}

std::optional<SpecConstantsPass::HandlingMode> SCMode;
if (SpecConstLower.getNumOccurrences() > 0)
SCMode = SpecConstLower == SC_NATIVE_MODE
? SpecConstantsPass::HandlingMode::native
: SpecConstantsPass::HandlingMode::emulation;

// It is important that we *DO NOT* preserve all the splits in memory at the
// same time, because it leads to a huge RAM consumption by the tool on bigger
// inputs.
Expand All @@ -693,16 +649,10 @@ processInputModule(std::unique_ptr<Module> M) {
assert(MMs.size() && "at least one module is expected after ESIMD split");

SmallVector<module_split::ModuleDesc, 2> MMsWithDefaultSpecConsts;
for (size_t I = 0; I != MMs.size(); ++I) {
if (GenerateDeviceImageWithDefaultSpecConsts) {
std::optional<module_split::ModuleDesc> NewMD =
processSpecConstantsWithDefaultValues(MMs[I]);
if (NewMD)
MMsWithDefaultSpecConsts.push_back(std::move(*NewMD));
}

Modified |= processSpecConstants(MMs[I]);
}
Modified |= handleSpecializationConstants(
MMs, SCMode, GenerateDeviceImageWithDefaultSpecConsts,
GenerateDeviceImageWithDefaultSpecConsts ? &MMsWithDefaultSpecConsts
: nullptr);

if (IROutputOnly) {
if (SplitOccurred) {
Expand Down
Loading