diff --git a/clang/lib/DPCT/AnalysisInfo.cpp b/clang/lib/DPCT/AnalysisInfo.cpp index 36cd3cfe5405..b691a6b1935e 100644 --- a/clang/lib/DPCT/AnalysisInfo.cpp +++ b/clang/lib/DPCT/AnalysisInfo.cpp @@ -3130,7 +3130,7 @@ MemVarInfo::MemVarInfo(unsigned Offset, auto DS1 = getParentDeclStmt(Var); auto DS2 = getParentDeclStmt(DeclOfVarType); if (DS1 && DS2 && DS1 == DS2) { - IsAnonymousType = true; + IsAnonymousType = !DeclOfVarType->hasNameForLinkage(); DeclStmtOfVarType = DS2; const auto LocInfo = DpctGlobalInfo::getLocInfo( getDefinitionRange(DS2->getBeginLoc(), DS2->getEndLoc()) @@ -3195,7 +3195,9 @@ std::string MemVarInfo::getDeclarationReplacement(const VarDecl *VD) { OS << "auto &" << getName() << " = " << "*" << MapNames::getClNamespace() << "ext::oneapi::group_local_memory_for_overwrite<" - << getType()->getBaseName(); + << ((isAnonymousType() && isShared() && isLocal()) + ? LocalTypeName + : getType()->getBaseName()); for (auto &ArraySize : getType()->getRange()) { OS << "[" << ArraySize.getSize() << "]"; } diff --git a/clang/lib/DPCT/RuleInfra/APINamesTemplateType.inc b/clang/lib/DPCT/RuleInfra/APINamesTemplateType.inc index 45757ac75b06..118a1face6d8 100644 --- a/clang/lib/DPCT/RuleInfra/APINamesTemplateType.inc +++ b/clang/lib/DPCT/RuleInfra/APINamesTemplateType.inc @@ -116,9 +116,28 @@ TYPE_REWRITE_ENTRY( WARNING_FACTORY(Diagnostics::UNSUPPORT_SYCLCOMPAT, TYPESTR), HEADER_INSERTION_FACTORY( HeaderType::HT_DPCT_GROUP_Utils, - TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() + - "group::group_radix_sort"), - TEMPLATE_ARG(0), TEMPLATE_ARG(2))))) + TYPE_CONDITIONAL_FACTORY( + UseGroupLocalMemory(), + TYPE_CONDITIONAL_FACTORY( + CheckTemplateArgCount(9, false, std::less()), + TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() + + "group::group_radix_sort"), + TEMPLATE_ARG(0), TEMPLATE_ARG(2), + TEMPLATE_ARG(4)), + TYPE_CONDITIONAL_FACTORY( + CheckTemplateArgCount(10, false, std::less()), + TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() + + "group::group_radix_sort"), + TEMPLATE_ARG(0), TEMPLATE_ARG(2), + TEMPLATE_ARG(4), TEMPLATE_ARG(8)), + TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() + + "group::group_radix_sort"), + TEMPLATE_ARG(0), TEMPLATE_ARG(2), + TEMPLATE_ARG(4), TEMPLATE_ARG(8), + TEMPLATE_ARG(9)))), + TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() + + "group::group_radix_sort"), + TEMPLATE_ARG(0), TEMPLATE_ARG(2)))))) // cub::BlockExchange TYPE_REWRITE_ENTRY( @@ -128,9 +147,28 @@ TYPE_REWRITE_ENTRY( WARNING_FACTORY(Diagnostics::UNSUPPORT_SYCLCOMPAT, TYPESTR), HEADER_INSERTION_FACTORY( HeaderType::HT_DPCT_GROUP_Utils, - TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() + - "group::exchange"), - TEMPLATE_ARG(0), TEMPLATE_ARG(2))))) + TYPE_CONDITIONAL_FACTORY( + UseGroupLocalMemory(), + TYPE_CONDITIONAL_FACTORY( + CheckTemplateArgCount(5, false, std::less()), + TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() + + "group::exchange"), + TEMPLATE_ARG(0), TEMPLATE_ARG(2), + TEMPLATE_ARG(1)), + TYPE_CONDITIONAL_FACTORY( + CheckTemplateArgCount(6, false, std::less()), + TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() + + "group::exchange"), + TEMPLATE_ARG(0), TEMPLATE_ARG(2), + TEMPLATE_ARG(1), TEMPLATE_ARG(4)), + TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() + + "group::exchange"), + TEMPLATE_ARG(0), TEMPLATE_ARG(2), + TEMPLATE_ARG(1), TEMPLATE_ARG(4), + TEMPLATE_ARG(5)))), + TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() + + "group::exchange"), + TEMPLATE_ARG(0), TEMPLATE_ARG(2)))))) // cub::BlockShuffle TYPE_REWRITE_ENTRY( @@ -165,13 +203,45 @@ TYPE_REWRITE_ENTRY( HEADER_INSERTION_FACTORY( HeaderType::HT_DPCT_GROUP_Utils, TYPE_CONDITIONAL_FACTORY( - CheckTemplateArgCount(4), - TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() + - "group::group_load"), - TEMPLATE_ARG(0), TEMPLATE_ARG(2), TEMPLATE_ARG(3)), - TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() + - "group::group_load"), - TEMPLATE_ARG(0), TEMPLATE_ARG(2)))))) + UseGroupLocalMemory(), + TYPE_CONDITIONAL_FACTORY( + CheckTemplateArgCount(4, false, std::less()), + TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() + + "group::group_load"), + TEMPLATE_ARG(0), TEMPLATE_ARG(2), + STR(MapNames::getDpctNamespace() + + "group::group_load_algorithm::blocked"), + TEMPLATE_ARG(1)), + TYPE_CONDITIONAL_FACTORY( + CheckTemplateArgCount(5, false, std::less()), + TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() + + "group::group_load"), + TEMPLATE_ARG(0), TEMPLATE_ARG(2), + TEMPLATE_ARG(3), TEMPLATE_ARG(1)), + TYPE_CONDITIONAL_FACTORY( + CheckTemplateArgCount(6, false, + std::less()), + TYPE_FACTORY( + STR(MapNames::getLibraryHelperNamespace() + + "group::group_load"), + TEMPLATE_ARG(0), TEMPLATE_ARG(2), + TEMPLATE_ARG(3), TEMPLATE_ARG(1), + TEMPLATE_ARG(4)), + TYPE_FACTORY( + STR(MapNames::getLibraryHelperNamespace() + + "group::group_load"), + TEMPLATE_ARG(0), TEMPLATE_ARG(2), + TEMPLATE_ARG(3), TEMPLATE_ARG(1), + TEMPLATE_ARG(4), TEMPLATE_ARG(5))))), + TYPE_CONDITIONAL_FACTORY( + CheckTemplateArgCount(4), + TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() + + "group::group_load"), + TEMPLATE_ARG(0), TEMPLATE_ARG(2), + TEMPLATE_ARG(3)), + TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() + + "group::group_load"), + TEMPLATE_ARG(0), TEMPLATE_ARG(2))))))) // cub::BlockStore TYPE_REWRITE_ENTRY( "cub::BlockStore", @@ -181,13 +251,45 @@ TYPE_REWRITE_ENTRY( HEADER_INSERTION_FACTORY( HeaderType::HT_DPCT_GROUP_Utils, TYPE_CONDITIONAL_FACTORY( - CheckTemplateArgCount(4), - TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() + - "group::group_store"), - TEMPLATE_ARG(0), TEMPLATE_ARG(2), TEMPLATE_ARG(3)), - TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() + - "group::group_store"), - TEMPLATE_ARG(0), TEMPLATE_ARG(2)))))) + UseGroupLocalMemory(), + TYPE_CONDITIONAL_FACTORY( + CheckTemplateArgCount(4, false, std::less()), + TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() + + "group::group_store"), + TEMPLATE_ARG(0), TEMPLATE_ARG(2), + STR(MapNames::getDpctNamespace() + + "group::group_store_algorithm::blocked"), + TEMPLATE_ARG(1)), + TYPE_CONDITIONAL_FACTORY( + CheckTemplateArgCount(5, false, std::less()), + TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() + + "group::group_store"), + TEMPLATE_ARG(0), TEMPLATE_ARG(2), + TEMPLATE_ARG(3), TEMPLATE_ARG(1)), + TYPE_CONDITIONAL_FACTORY( + CheckTemplateArgCount(6, false, + std::less()), + TYPE_FACTORY( + STR(MapNames::getLibraryHelperNamespace() + + "group::group_store"), + TEMPLATE_ARG(0), TEMPLATE_ARG(2), + TEMPLATE_ARG(3), TEMPLATE_ARG(1), + TEMPLATE_ARG(4)), + TYPE_FACTORY( + STR(MapNames::getLibraryHelperNamespace() + + "group::group_store"), + TEMPLATE_ARG(0), TEMPLATE_ARG(2), + TEMPLATE_ARG(3), TEMPLATE_ARG(1), + TEMPLATE_ARG(4), TEMPLATE_ARG(5))))), + TYPE_CONDITIONAL_FACTORY( + CheckTemplateArgCount(4), + TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() + + "group::group_store"), + TEMPLATE_ARG(0), TEMPLATE_ARG(2), + TEMPLATE_ARG(3)), + TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() + + "group::group_store"), + TEMPLATE_ARG(0), TEMPLATE_ARG(2))))))) FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, diff --git a/clang/lib/DPCT/RuleInfra/ExprAnalysis.cpp b/clang/lib/DPCT/RuleInfra/ExprAnalysis.cpp index 24848fcfe4b7..017213274b8c 100644 --- a/clang/lib/DPCT/RuleInfra/ExprAnalysis.cpp +++ b/clang/lib/DPCT/RuleInfra/ExprAnalysis.cpp @@ -1152,9 +1152,75 @@ void ExprAnalysis::analyzeType(TypeLoc TL, const Expr *CSCE, case TypeLoc::Typedef: case TypeLoc::Builtin: case TypeLoc::Using: + case TypeLoc::DependentName: case TypeLoc::Elaborated: case TypeLoc::Record: { TyName = DpctGlobalInfo::getTypeName(TL.getType()); + if (DpctGlobalInfo::useGroupLocalMemory() && + (TyName.find("TempStorage") != std::string::npos) && + isPreserveCubVar(TL.getType())) { + const RecordDecl *RD = nullptr; + const TemplateDecl *TD = nullptr; + const TypedefNameDecl *TND = nullptr; + if (auto ETL = TL.getAs()) { + if (auto RTL = ETL.getNamedTypeLoc().getAs()) { + RD = RTL.getDecl(); + } + } else if (auto RTL = TL.getAs()) { + RD = RTL.getDecl(); + } else if (auto TTL = TL.getAs()) { + TND = TTL.getTypedefNameDecl(); + } else if (auto DTL = TL.getAs()) { + const DependentNameType *DT = DTL.getTypePtr(); + auto *QNNS = DT->getQualifier(); + if (QNNS->getKind() == NestedNameSpecifier::TypeSpec) { + if (auto *SpecType = + dyn_cast(QNNS->getAsType())) { + TD = SpecType->getTemplateName().getAsTemplateDecl(); + } else if (auto *TT = dyn_cast(QNNS->getAsType())) { + if (auto D = TT->getDecl()) { + if (auto *SpecType = D->getUnderlyingType() + .getCanonicalType() + .getTypePtr() + ->getAs()) { + TD = SpecType->getTemplateName().getAsTemplateDecl(); + } + } + } + } + } + if (RD) { + auto DC = RD->getDeclContext(); + if (DC->getDeclKind() == Decl::Kind::ClassTemplateSpecialization) { + if (auto CTS = dyn_cast(DC)) { + if (dpct::DpctGlobalInfo::isInCudaPath( + CTS->getSpecializedTemplate()->getLocation())) { + addReplacement(TL.getBeginLoc(), TL.getEndLoc(), CSCE, + "TempLocalMemory"); + return; + } + } + } + } + if (TND) { + auto DC = TND->getDeclContext(); + if (DC && DC->isRecord()) { + auto *RD = dyn_cast(DC); + if (dpct::DpctGlobalInfo::isInCudaPath(RD->getLocation())) { + addReplacement(TL.getBeginLoc(), TL.getEndLoc(), CSCE, + "TempLocalMemory"); + return; + } + } + } + if (TD) { + if (dpct::DpctGlobalInfo::isInCudaPath(TD->getLocation())) { + addReplacement(TL.getAs().getNameLoc(), + TL.getEndLoc(), CSCE, "TempLocalMemory"); + return; + } + } + } RewriteType(TyName, TL); break; } diff --git a/clang/lib/DPCT/RuleInfra/TypeLocRewriters.cpp b/clang/lib/DPCT/RuleInfra/TypeLocRewriters.cpp index 65d51c7c5f27..67e40c50399b 100644 --- a/clang/lib/DPCT/RuleInfra/TypeLocRewriters.cpp +++ b/clang/lib/DPCT/RuleInfra/TypeLocRewriters.cpp @@ -16,6 +16,12 @@ inline auto UseSYCLCompat() { return [](const TypeLoc) -> bool { return DpctGlobalInfo::useSYCLCompat(); }; } +inline auto UseGroupLocalMemory() { + return [](const TypeLoc) -> bool { + return DpctGlobalInfo::useGroupLocalMemory(); + }; +} + TemplateArgumentInfo getTemplateArg(const TypeLoc &TL, unsigned Idx) { if (auto TSTL = TL.getAs()) { if (TSTL.getNumArgs() > Idx) { @@ -103,15 +109,18 @@ makeUserDefinedTypeStrCreator(MetaRuleObject &R, class CheckTemplateArgCount { unsigned Count; bool IsIncludeDefault; + std::function CmpFunc; public: - CheckTemplateArgCount(unsigned I, bool D = true) - : Count(I), IsIncludeDefault(D) {} + CheckTemplateArgCount( + unsigned I, bool D = true, + std::function F = std::equal_to()) + : Count(I), IsIncludeDefault(D), CmpFunc(F) {} bool operator()(const TypeLoc TL) { if (auto TSTL = TL.getAs()) { size_t Num = TSTL.getNumArgs(); if (IsIncludeDefault) { - return Num == Count; + return CmpFunc(Num, Count); } size_t NoneDefaultNum = 0; for (size_t i = 0; i < Num; i++) { @@ -119,7 +128,7 @@ class CheckTemplateArgCount { NoneDefaultNum++; } } - return NoneDefaultNum == Count; + return CmpFunc(NoneDefaultNum, Count); } return false; } diff --git a/clang/lib/DPCT/RulesLang/RulesLangNoneAPIAndType.cpp b/clang/lib/DPCT/RulesLang/RulesLangNoneAPIAndType.cpp index 01cb30192c66..cdf72859d353 100644 --- a/clang/lib/DPCT/RulesLang/RulesLangNoneAPIAndType.cpp +++ b/clang/lib/DPCT/RulesLang/RulesLangNoneAPIAndType.cpp @@ -680,6 +680,10 @@ void MemVarMigrationRule::processTypeDeclaredLocal( std::string Ret; llvm::raw_string_ostream OS(Ret); OS << getNL(DS->getEndLoc().isMacroID()) << getIndent(InsertSL, SM); + if (DpctGlobalInfo::useGroupLocalMemory()) { + OS << Info->getDeclarationReplacement(MemVar); + return OS.str(); + } OS << TypeName << ' '; if (IsReference) OS << '&'; @@ -719,8 +723,7 @@ void MemVarMigrationRule::processTypeDeclaredLocal( emplaceTransformation(new InsertText(InsertSL, GenDeclStmt(NewTypeName))); } else if (DS) { // remove var decl - emplaceTransformation(ReplaceVarDecl::getVarDeclReplacement( - MemVar, Info->getDeclarationReplacement(MemVar))); + emplaceTransformation(new ReplaceVarDecl(MemVar, "")); Info->setLocalTypeName(Info->getType()->getBaseName()); emplaceTransformation( @@ -731,7 +734,8 @@ void MemVarMigrationRule::processTypeDeclaredLocal( void MemVarMigrationRule::runRule( const ast_matchers::MatchFinder::MatchResult &Result) { if (auto MemVar = getAssistNodeAsType(Result, "var")) { - if (isCubVar(MemVar) || MemVar->hasAttr()) { + if ((isCubVar(MemVar) && !isPreserveCubVar(MemVar->getType())) || + MemVar->hasAttr()) { return; } std::string CanonicalType = @@ -787,7 +791,7 @@ void MemVarAnalysisRule::registerMatcher(MatchFinder &MF) { void MemVarAnalysisRule::runRule(const MatchFinder::MatchResult &Result) { if (auto MemVar = getAssistNodeAsType(Result, "var")) { - if (isCubVar(MemVar)) { + if (isCubVar(MemVar) && !isPreserveCubVar(MemVar->getType())) { return; } std::string CanonicalType = diff --git a/clang/lib/DPCT/RulesLangLib/CUBAPIMigration.cpp b/clang/lib/DPCT/RulesLangLib/CUBAPIMigration.cpp index 88acc8127fa3..2f2af00a9c99 100644 --- a/clang/lib/DPCT/RulesLangLib/CUBAPIMigration.cpp +++ b/clang/lib/DPCT/RulesLangLib/CUBAPIMigration.cpp @@ -100,7 +100,7 @@ void CubTypeRule::registerMatcher(ast_matchers::MatchFinder &MF) { "cub::ArgIndexInputIterator", "cub::DiscardOutputIterator", "cub::DoubleBuffer", "cub::NullType", "cub::ArgMax", "cub::ArgMin", "cub::BlockRadixSort", "cub::BlockExchange", "cub::BlockLoad", - "cub::BlockStore", "cub::BlockShuffle"); + "cub::BlockStore", "cub::BlockShuffle", "TempStorage"); }; MF.addMatcher( @@ -258,6 +258,9 @@ void CubMemberCallRule::runRule( bool isBlockLoadStore = Name == "Load" || Name == "Store"; if (isBlockRadixSort || isBlockExchange || isBlockShuffle || isBlockLoadStore) { + if (DpctGlobalInfo::useGroupLocalMemory()) { + return; + } std::string HelpFuncName; if (isBlockRadixSort) HelpFuncName = "group_radix_sort"; @@ -802,6 +805,10 @@ void CubRule::registerMatcher(ast_matchers::MatchFinder &MF) { .bind("DeclStmt"), this); + MF.addMatcher(fieldDecl(hasType(hasCanonicalType(qualType(isTempStorage)))) + .bind("FieldTempStorage"), + this); + MF.addMatcher(cxxMemberCallExpr(has(memberExpr(member(hasAnyName( "InclusiveSum", "ExclusiveSum", "InclusiveScan", "ExclusiveScan", @@ -897,6 +904,19 @@ std::string CubRule::getOpRepl(const Expr *Operator) { } return OpRepl; } + +void CubRule::processFiledDecl(const FieldDecl *FD) { + if (!isPreserveCubVar(FD->getType())) { + auto P = FD->getParent(); + if (P->isUnion()) { + emplaceTransformation(new ReplaceText(FD->getSourceRange().getBegin(), + FD->getLocation(), "void *")); + } else { + emplaceTransformation(new ReplaceDecl(FD, "")); + } + } +} + void CubRule::processCubDeclStmt(const DeclStmt *DS) { std::string Repl; for (auto Decl : DS->decls()) { @@ -926,8 +946,9 @@ void CubRule::processCubDeclStmt(const DeclStmt *DS) { emplaceTransformation(new ReplaceDecl(RD, "")); } - // always remove TempStorage variable declaration - emplaceTransformation(new ReplaceStmt(DS, "")); + if (!isPreserveCubVar(VDecl->getType())) { + emplaceTransformation(new ReplaceStmt(DS, "")); + } // process TempStorage used in class constructor auto TempVarMatcher = compoundStmt(forEachDescendant( @@ -1718,8 +1739,11 @@ void CubRule::processDependentMemberCall( int CubRule::PlaceholderIndex = 1; void CubRule::runRule(const ast_matchers::MatchFinder::MatchResult &Result) { - if (const CXXMemberCallExpr *MC = - getNodeAsType(Result, "MemberCall")) { + if (const FieldDecl *FD = + getNodeAsType(Result, "FieldTempStorage")) { + processFiledDecl(FD); + } else if (const CXXMemberCallExpr *MC = + getNodeAsType(Result, "MemberCall")) { processCubMemberCall(MC); } else if (const CXXDependentScopeMemberExpr *DMC = getNodeAsType( diff --git a/clang/lib/DPCT/RulesLangLib/CUBAPIMigration.h b/clang/lib/DPCT/RulesLangLib/CUBAPIMigration.h index 513f1f680a86..50caa8095e34 100644 --- a/clang/lib/DPCT/RulesLangLib/CUBAPIMigration.h +++ b/clang/lib/DPCT/RulesLangLib/CUBAPIMigration.h @@ -113,7 +113,7 @@ class CubRule : public NamedMigrationRule { void processCubFuncCall(const CallExpr *CE, bool FuncCallUsed = false); void processCubMemberCall(const CXXMemberCallExpr *MC); void processTypeLoc(const TypeLoc *TL); - + void processFiledDecl(const FieldDecl *FD); void processThreadLevelFuncCall(const CallExpr *CE, bool FuncCallUsed); void processWarpLevelFuncCall(const CallExpr *CE, bool FuncCallUsed); void processBlockLevelMemberCall(const CXXMemberCallExpr *MC); diff --git a/clang/lib/DPCT/Utility.cpp b/clang/lib/DPCT/Utility.cpp index e9b2e30b6e17..0825c8d11d84 100644 --- a/clang/lib/DPCT/Utility.cpp +++ b/clang/lib/DPCT/Utility.cpp @@ -4047,6 +4047,24 @@ static bool isCubTempStorageType(const clang::Type *T) { return false; const clang::Type *DeclContextType = nullptr; + const clang::TypedefType *TT = dyn_cast(T); + if (!TT) { + if (auto ET = dyn_cast(T)) { + TT = dyn_cast(ET->desugar().getTypePtr()); + } + } + if (TT) { + auto D = TT->getDecl(); + if (D && (D->getNameAsString() == "TempStorage")) { + auto DC = D->getDeclContext(); + if (DC && DC->isRecord()) { + auto *RD = dyn_cast(DC); + DeclContextType = RD->getTypeForDecl(); + } else { + return false; + } + } + } // cub::{BlockReduce, BlockScan, WarpScan, ...}::TempStorage; if (auto *RT = dyn_cast(T)) { if (RT->getDecl()->getName() != "TempStorage") @@ -4077,7 +4095,8 @@ static bool isCubTempStorageType(const clang::Type *T) { bool isCubTempStorageType(QualType T) { if (T.isNull()) return false; - return isCubTempStorageType(T.getCanonicalType().getTypePtrOrNull()); + return isCubTempStorageType(T.getTypePtrOrNull()) || + isCubTempStorageType(T.getCanonicalType().getTypePtrOrNull()); } bool isCubCollectiveRecordType(QualType T) { @@ -4086,6 +4105,74 @@ bool isCubCollectiveRecordType(QualType T) { return isCubCollectiveRecordType(T.getCanonicalType().getTypePtrOrNull()); } +bool isPreserveCubVar(QualType T) { + auto isPreserve = [&](QualType QT) { + std::string ObjectName; + if (auto TypePtr = QT.getCanonicalType().getTypePtrOrNull()) { + if (auto *RT = dyn_cast(TypePtr)) { + auto *DC = RT->getDecl()->getDeclContext(); + if (DC && DC->isRecord()) { + auto *RD = dyn_cast(DC); + ObjectName = RD->getNameAsString(); + } + } else if (auto *DNT = dyn_cast(TypePtr)) { + auto *QNNS = DNT->getQualifier(); + if (QNNS->getKind() == NestedNameSpecifier::TypeSpec) { + if (auto *SpecType = + dyn_cast(QNNS->getAsType())) { + ObjectName = SpecType->getTemplateName() + .getAsTemplateDecl() + ->getNameAsString(); + } + } + } + } + if (ObjectName.empty()) { + const clang::TypedefType *TT = dyn_cast(T.getTypePtr()); + if (!TT) { + if (auto ET = dyn_cast_or_null(QT.getTypePtrOrNull())) { + TT = dyn_cast(ET->desugar().getTypePtr()); + } + } + if (TT) { + if (auto D = TT->getDecl()) { + auto DC = D->getDeclContext(); + if (DC && DC->isRecord()) { + auto *RD = dyn_cast(DC); + ObjectName = RD->getNameAsString(); + } + } + } + } + + if ((ObjectName.find("BlockLoad") != std::string::npos) || + (ObjectName.find("BlockStore") != std::string::npos) || + (ObjectName.find("BlockExchange") != std::string::npos) || + (ObjectName.find("BlockRadixSort") != std::string::npos)) { + return true; + } + return false; + }; + if (DpctGlobalInfo::useGroupLocalMemory()) { + if (isCubTempStorageType(T)) { + if (isPreserve(T)) { + return true; + } + } + if (T->isUnionType()) { + const TagDecl *RD = T->getAsUnionType()->getDecl()->getCanonicalDecl(); + for (const auto *D : RD->decls()) { + if (const auto *FD = dyn_cast(D)) { + auto QT = FD->getType().getCanonicalType(); + if (isCubTempStorageType(QT.getTypePtrOrNull()) && isPreserve(QT)) + return true; + } + } + } + } + return false; +} + bool isCubVar(const VarDecl *VD) { QualType CanType = VD->getType().getCanonicalType(); std::string CanonicalTypeStr = CanType.getAsString(); diff --git a/clang/lib/DPCT/Utility.h b/clang/lib/DPCT/Utility.h index d8f35cbf018c..7c671f62fc70 100644 --- a/clang/lib/DPCT/Utility.h +++ b/clang/lib/DPCT/Utility.h @@ -524,6 +524,7 @@ bool isDefaultStream(const clang::Expr *StreamArg); bool isRedeclInCUDAHeader(const clang::TypedefType *T); bool isTypeInAnalysisScope(const clang::Type *TypePtr); bool isCubVar(const clang::VarDecl *VD); +bool isPreserveCubVar(QualType T); bool isCubTempStorageType(QualType T); bool isCubCollectiveRecordType(QualType T); bool isExprUsed(const clang::Expr *E, bool &Result); diff --git a/clang/runtime/dpct-rt/include/dpct/detail/group_utils_detail.hpp b/clang/runtime/dpct-rt/include/dpct/detail/group_utils_detail.hpp index db53890fc77f..6063210cf66a 100644 --- a/clang/runtime/dpct-rt/include/dpct/detail/group_utils_detail.hpp +++ b/clang/runtime/dpct-rt/include/dpct/detail/group_utils_detail.hpp @@ -28,14 +28,29 @@ template struct log2 { enum { VALUE = (1 << (COUNT - 1) < N) ? COUNT : COUNT - 1 }; }; -template class radix_rank { +template +class radix_rank { + static constexpr int PACKING_RATIO = + sizeof(packed_counter_type) / sizeof(digit_counter_type); + static constexpr int LOG_PACKING_RATIO = log2::VALUE; + static constexpr int LOG_COUNTER_LANES = RADIX_BITS - LOG_PACKING_RATIO; + static constexpr int COUNTER_LANES = 1 << LOG_COUNTER_LANES; + static constexpr int PADDED_COUNTER_LANES = COUNTER_LANES + 1; + public: + struct TempLocalMemory { + static constexpr int group_threads = + group_dim_0 * group_dim_1 * group_dim_2; + uint8_t data[group_threads * PADDED_COUNTER_LANES * + sizeof(packed_counter_type)]; + }; static size_t get_local_memory_size(size_t group_threads) { return group_threads * PADDED_COUNTER_LANES * sizeof(packed_counter_type); } radix_rank(uint8_t *local_memory) : _local_memory(local_memory) {} - + radix_rank(TempLocalMemory &temp) { _local_memory = &(temp.data[0]); } template __dpct_inline__ void rank_keys(const Item &item, KT (&keys)[VALUES_PER_THREAD], @@ -160,13 +175,6 @@ template class radix_rank { } private: - static constexpr int PACKING_RATIO = - sizeof(packed_counter_type) / sizeof(digit_counter_type); - static constexpr int LOG_PACKING_RATIO = log2::VALUE; - static constexpr int LOG_COUNTER_LANES = RADIX_BITS - LOG_PACKING_RATIO; - static constexpr int COUNTER_LANES = 1 << LOG_COUNTER_LANES; - static constexpr int PADDED_COUNTER_LANES = COUNTER_LANES + 1; - packed_counter_type cached_segment[PADDED_COUNTER_LANES]; uint8_t *_local_memory; }; diff --git a/clang/runtime/dpct-rt/include/dpct/group_utils.hpp b/clang/runtime/dpct-rt/include/dpct/group_utils.hpp index 804ebfc3cac2..7452ff78ea81 100644 --- a/clang/runtime/dpct-rt/include/dpct/group_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/group_utils.hpp @@ -25,8 +25,29 @@ namespace group { /// \tparam T The type of the data elements. /// \tparam ElementsPerWorkItem The number of data elements assigned to a /// work-item. -template class exchange { +/// \tparam group_dim_0 The first dimension size of the work-group. +/// \tparam group_dim_1 The second dimension size of the work-group. +/// \tparam group_dim_2 The third dimension size of the work-group. +template +class exchange { + static constexpr int LOG_LOCAL_MEMORY_BANKS = 4; + static constexpr bool INSERT_PADDING = + (ElementsPerWorkItem > 4) && + (detail::power_of_two::VALUE); + public: + struct TempLocalMemory { + static constexpr int group_threads = + group_dim_0 * group_dim_1 * group_dim_2; + static constexpr int padding_values = + INSERT_PADDING + ? ((group_threads * ElementsPerWorkItem) >> LOG_LOCAL_MEMORY_BANKS) + : 0; + static constexpr int total_elements = + group_threads * ElementsPerWorkItem + padding_values; + uint8_t data[total_elements * sizeof(T)]; + }; static size_t get_local_memory_size(size_t group_threads) { size_t padding_values = (INSERT_PADDING) @@ -36,7 +57,7 @@ template class exchange { } exchange(uint8_t *local_memory) : _local_memory(local_memory) {} - + exchange(TempLocalMemory &temp) { _local_memory = &(temp.data[0]); } // TODO: Investigate if padding is required for performance, // and if specializations are required for specific target hardware. static size_t adjust_by_padding(size_t offset) { @@ -329,11 +350,6 @@ template class exchange { } } - static constexpr int LOG_LOCAL_MEMORY_BANKS = 4; - static constexpr bool INSERT_PADDING = - (ElementsPerWorkItem > 4) && - (detail::power_of_two::VALUE); - uint8_t *_local_memory; }; @@ -344,13 +360,28 @@ template class exchange { /// \tparam ElementsPerWorkItem The number of data elements assigned to /// a work-item. /// \tparam RADIX_BITS The number of radix bits per digit place. -template +/// \tparam group_dim_0 The first dimension size of the work-group. +/// \tparam group_dim_1 The second dimension size of the work-group. +/// \tparam group_dim_2 The third dimension size of the work-group. +template class group_radix_sort { uint8_t *_local_memory; public: + struct TempLocalMemory { + static constexpr size_t radix_bytes = + sizeof(typename detail::radix_rank::TempLocalMemory); + static constexpr size_t exchange_bytes = + sizeof(typename exchange::TempLocalMemory); + static constexpr size_t max_bytes = + (radix_bytes > exchange_bytes) ? radix_bytes : exchange_bytes; + uint8_t data[max_bytes]; + }; group_radix_sort(uint8_t *local_memory) : _local_memory(local_memory) {} - + group_radix_sort(TempLocalMemory &temp) { _local_memory = &(temp.data[0]); } static size_t get_local_memory_size(size_t group_threads) { size_t ranks_size = detail::radix_rank::get_local_memory_size(group_threads); @@ -1107,21 +1138,39 @@ enum class group_load_algorithm { /// \tparam ElementsPerWorkItem The number of data elements assigned to a /// work-item. /// \tparam LoadAlgorithm The data movement strategy, default is blocked. +/// \tparam group_dim_0 The first dimension size of the work-group. +/// \tparam group_dim_1 The second dimension size of the work-group. +/// \tparam group_dim_2 The third dimension size of the work-group. template + group_load_algorithm LoadAlgorithm = group_load_algorithm::blocked, + int group_dim_0 = 1, int group_dim_1 = 1, int group_dim_2 = 1> class group_load { + struct _TempLocalMemory + : exchange::TempLocalMemory {}; + struct _NullTempLocalMemory {}; + static constexpr bool need_temp = + (LoadAlgorithm == group_load_algorithm::transpose) || + (LoadAlgorithm == group_load_algorithm::sub_group_transpose); + public: + using TempLocalMemory = typename std::conditional::type; static size_t get_local_memory_size(size_t work_group_size) { - if constexpr ((LoadAlgorithm == group_load_algorithm::transpose) || - (LoadAlgorithm == - group_load_algorithm::sub_group_transpose)) { + if constexpr (need_temp) { return dpct::group::exchange< T, ElementsPerWorkItem>::get_local_memory_size(work_group_size); } return 0; } group_load(uint8_t *local_memory) : _local_memory(local_memory) {} - + group_load(TempLocalMemory &temp) { + if constexpr (need_temp) { + _local_memory = &(temp.data[0]); + } else { + _local_memory = nullptr; + } + } /// Load a linear segment of items from memory. /// /// Suppose 512 integer data elements partitioned across 128 work-items, where @@ -1296,21 +1345,39 @@ enum class group_store_algorithm { /// \tparam ElementsPerWorkItem The number of data elements assigned to a /// work-item. /// \tparam StoreAlgorithm The data movement strategy, default is blocked. +/// \tparam group_dim_0 The first dimension size of the work-group. +/// \tparam group_dim_1 The second dimension size of the work-group. +/// \tparam group_dim_2 The third dimension size of the work-group. template + group_store_algorithm StoreAlgorithm = group_store_algorithm::blocked, + int group_dim_0 = 1, int group_dim_1 = 1, int group_dim_2 = 1> class group_store { + struct _TempLocalMemory + : exchange::TempLocalMemory {}; + struct _NullTempLocalMemory {}; + static constexpr bool need_temp = + (StoreAlgorithm == group_store_algorithm::transpose) || + (StoreAlgorithm == group_store_algorithm::sub_group_transpose); + public: + using TempLocalMemory = typename std::conditional::type; static size_t get_local_memory_size(size_t work_group_size) { - if constexpr ((StoreAlgorithm == group_store_algorithm::transpose) || - (StoreAlgorithm == - group_store_algorithm::sub_group_transpose)) { + if constexpr (need_temp) { return dpct::group::exchange< T, ElementsPerWorkItem>::get_local_memory_size(work_group_size); } return 0; } group_store(uint8_t *local_memory) : _local_memory(local_memory) {} - + group_store(TempLocalMemory &temp) { + if constexpr (need_temp) { + _local_memory = &(temp.data[0]); + } else { + _local_memory = nullptr; + } + } /// Store items into a linear segment of memory. /// /// Suppose 512 integer data elements partitioned across 128 work-items, where @@ -1415,18 +1482,22 @@ class group_store { /// \tparam group_dim_0 The first dimension size of the work-group. /// \tparam group_dim_1 The second dimension size of the work-group. /// \tparam group_dim_2 The third dimension size of the work-group. -template +template class group_shuffle { T *_local_memory = nullptr; static constexpr size_t group_work_items = group_dim_0 * group_dim_1 * group_dim_2; public: + struct TempLocalMemory { + uint8_t data[sizeof(T) * group_work_items]; + }; static constexpr size_t get_local_memory_size(size_t work_group_size) { return sizeof(T) * work_group_size; } group_shuffle(uint8_t *local_memory) : _local_memory((T *)local_memory) {} - + group_shuffle(TempLocalMemory &temp) { _local_memory = &(temp.data[0]); } /// Selects a value from a work-item at a given distance in the work-group /// and stores the value in the output. /// diff --git a/clang/test/dpct/cub/cub_with_local_memory_kernel_allocation.cu b/clang/test/dpct/cub/cub_with_local_memory_kernel_allocation.cu new file mode 100644 index 000000000000..9f3fd0442f66 --- /dev/null +++ b/clang/test/dpct/cub/cub_with_local_memory_kernel_allocation.cu @@ -0,0 +1,85 @@ + +// UNSUPPORTED: cuda-8.0, cuda-9.0, cuda-9.1, cuda-9.2, cuda-10.0, cuda-10.1, cuda-10.2 +// UNSUPPORTED: v8.0, v9.0, v9.1, v9.2, v10.0, v10.1, v10.2 +// RUN: dpct --use-experimental-features=local-memory-kernel-scope-allocation -format-range=none -in-root %S -out-root %T/cub_with_local_memory_kernel_allocation %S/cub_with_local_memory_kernel_allocation.cu --cuda-include-path="%cuda-path/include" -- -std=c++14 -x cuda --cuda-host-only +// RUN: FileCheck --input-file %T/cub_with_local_memory_kernel_allocation/cub_with_local_memory_kernel_allocation.dp.cpp --match-full-lines %s +// RUN: %if build_lit %{icpx -c -fsycl %T/cub_with_local_memory_kernel_allocation/cub_with_local_memory_kernel_allocation.dp.cpp -o %T/cub_with_local_memory_kernel_allocation/cub_with_local_memory_kernel_allocation.dp.o %} + +#include +#include + +// CHECK: template +// CHECK: void kernel(T *A) { +// CHECK: auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); +// CHECK: typedef dpct::group::group_load LoadFloat; +// CHECK: union type_ct1{ +// CHECK: typename LoadFloat::TempLocalMemory loadf; +// CHECK: void *reducef; +// CHECK: }; +// CHECK: auto &temp_storage = *sycl::ext::oneapi::group_local_memory_for_overwrite(sycl::ext::oneapi::this_work_item::get_work_group<3>()); +// CHECK: T vals[4]; +// CHECK: LoadFloat(temp_storage.loadf).load(item_ct1, &(A[0]), vals, 10); +// CHECK: auto &load = *sycl::ext::oneapi::group_local_memory_for_overwrite(sycl::ext::oneapi::this_work_item::get_work_group<3>()); +// CHECK: LoadFloat(load).load(item_ct1, &(A[0]), vals, 10); +// CHECK: } + +template +__global__ void kernel(T *A) { + typedef cub::BlockLoad LoadFloat; + typedef cub::BlockReduce BlockReduce; + + __shared__ union { + typename LoadFloat::TempStorage loadf; + typename BlockReduce::TempStorage reducef; + } temp_storage; + + T vals[4]; + + LoadFloat(temp_storage.loadf).Load(&(A[0]), vals, 10); + + __shared__ typename LoadFloat::TempStorage load; + + LoadFloat(load).Load(&(A[0]), vals, 10); +} + +// CHECK: void foo() { +// CHECK: typedef dpct::group::group_load LoadFloat; +// CHECK: union type_ct2{ +// CHECK: typename LoadFloat::TempLocalMemory loadf; +// CHECK: }; +// CHECK: auto &temp_storage = *sycl::ext::oneapi::group_local_memory_for_overwrite(sycl::ext::oneapi::this_work_item::get_work_group<3>()); +// CHECK: int vals[4]; +// CHECK: LoadFloat(temp_storage.loadf).load(sycl::ext::oneapi::this_work_item::get_nd_item<3>(), vals, vals, 10); +// CHECK: auto &loadf2 = *sycl::ext::oneapi::group_local_memory_for_overwrite(sycl::ext::oneapi::this_work_item::get_work_group<3>()); +// CHECK: } +__global__ void foo() { + typedef cub::BlockLoad LoadFloat; + __shared__ union { + typename LoadFloat::TempStorage loadf; + } temp_storage; + int vals[4]; + LoadFloat(temp_storage.loadf).Load(vals, vals, 10); + + __shared__ typename LoadFloat::TempStorage loadf2; + +} +// CHECK: int main() { +// CHECK: sycl::device dev_ct1; +// CHECK: sycl::queue q_ct1(dev_ct1, sycl::property_list{sycl::property::queue::in_order()}); +// CHECK: q_ct1.parallel_for( +// CHECK: sycl::nd_range<3>(sycl::range<3>(1, 1, 1), sycl::range<3>(1, 1, 1)), +// CHECK: [=](sycl::nd_item<3> item_ct1) { +// CHECK: foo(); +// CHECK: }); +// CHECK: q_ct1.parallel_for( +// CHECK: sycl::nd_range<3>(sycl::range<3>(1, 1, 1), sycl::range<3>(1, 1, 1)), +// CHECK: [=](sycl::nd_item<3> item_ct1) { +// CHECK: kernel(0); +// CHECK: }); +// CHECK: return 0; +// CHECK: } +int main() { + foo<<<1, 1>>>(); + kernel<<<1,1>>>(0); + return 0; +}