@@ -32,27 +32,18 @@ class ParamLookupTable {
3232
3333  static  ArrayRef<ParamsType> lookup (StringRef arch, KernelType op,
3434                                     Type dataType) {
35-     static  const  auto  &table = getTable ();
36- 
35+     arch = getArchName (arch);
3736    auto  key = makeKey (arch, op, dataType);
3837    LLVM_DEBUG (llvm::dbgs ()
3938               << " Lookup for tuning parameters with key " " \n " 
4039
40+     static  const  auto  &table = getTable ();
4141    auto  it = table.find (key);
4242    if  (it != table.end ()) {
4343      return  ArrayRef<ParamsType>(it->second .first , it->second .second );
4444    }
4545
46-     auto  archFamily = getArchName (arch).drop_back (2 ).str ();
47- 
48-     std::string fallbackKey;
49-     for  (const  auto  &entry : table) {
50-       if  (entry.first .find (archFamily) != std::string::npos &&
51-           entry.first .find (makeSuffix (op, dataType)) != std::string::npos) {
52-         fallbackKey = std::max (fallbackKey, entry.first );
53-       }
54-     }
55- 
46+     auto  fallbackKey = findFallback (arch, op, dataType);
5647    if  (!fallbackKey.empty ()) {
5748      LLVM_DEBUG (llvm::dbgs () << " Falling back to tuning parameters with key " 
5849                              << fallbackKey << " \n " 
@@ -71,33 +62,82 @@ class ParamLookupTable {
7162      llvm_unreachable (" Invalid architecture string" 
7263    }
7364    auto  remaining = arch.substr (gfxPos);
74-     auto  endPos = remaining.find_first_not_of (" 0123456789a" 3 );
65+     auto  endPos =
66+         remaining.find_if_not ([](char  c) { return  llvm::isAlnum (c); }, 3 );
7567    return  remaining.substr (0 , endPos);
7668  }
7769
7870  static  std::string getDataTypeString (Type dataType) {
7971    std::string dataTypeStr;
80-     llvm::raw_string_ostream os (dataTypeStr);
81-     os << dataType;
82- 
83-     if  (dataTypeStr == " bf16" 
84-       dataTypeStr = " f16" 
85-     } else  if  (dataTypeStr.find (" f8E" 
72+     if  (dataType.getIntOrFloatBitWidth () == 8  && isa<FloatType>(dataType)) {
73+       //  There are several 8-bit float types, but we use "fp8" generically
8674      dataTypeStr = " fp8" 
87-     } else  if  (dataType.isInteger () &&
88-                (dataTypeStr.at (0 ) == ' s' at (0 ) == ' u' 
89-       dataTypeStr = dataTypeStr.substr (1 );
75+     } else  if  (dataType.getIntOrFloatBitWidth () == 16  &&
76+                isa<FloatType>(dataType)) {
77+       //  We use "f16" for bf16 and f16 generically
78+       dataTypeStr = " f16" 
79+     } else  {
80+       llvm::raw_string_ostream os (dataTypeStr);
81+       os << dataType;
82+       if  (dataType.isInteger () &&
83+           (dataTypeStr.at (0 ) == ' s' at (0 ) == ' u' 
84+         //  Integer types can be printed as "sint" or "uint"
85+         dataTypeStr = dataTypeStr.substr (1 );
86+       }
9087    }
91- 
9288    return  dataTypeStr;
9389  }
9490
91+   //  Get all related archs sorted lexicographically
92+   static  std::vector<std::string> getRelatedArchs (StringRef arch, KernelType op,
93+                                                   Type dataType) {
94+     std::vector<std::string> archs;
95+     auto  prefix = arch.take_front (4 );
96+     auto  suffix = makeSuffix (op, dataType);
97+     static  const  auto  &table = getTable ();
98+     for  (const  auto  &entry : table) {
99+       if  (entry.first .find (prefix) != std::string::npos &&
100+           entry.first .rfind (suffix) != std::string::npos) {
101+         archs.push_back (entry.first .substr (0 , entry.first .find (' _' 
102+       }
103+     }
104+     std::sort (archs.begin (), archs.end ());
105+     return  archs;
106+   }
107+ 
108+   //  Search for fallback by truncating arch string
109+   //  e.g., gfx1151 -> gfx115 -> gfx11 -> gfx1
110+   static  std::string findFallback (StringRef arch, KernelType op,
111+                                   Type dataType) {
112+     const  auto  archs = getRelatedArchs (arch, op, dataType);
113+     return  findFallbackRecursive (arch.drop_back (1 ), op, dataType, archs);
114+   }
115+ 
116+   static  std::string
117+   findFallbackRecursive (StringRef arch, KernelType op, Type dataType,
118+                         const  std::vector<std::string> &archs) {
119+     if  (arch == " gfx" 
120+       return  " " 
121+ 
122+     //  Archs is sorted lexicographically, so we can search in reverse order for
123+     //  the latest matching arch
124+     //  e.g., gfx950 matches gfx9 before gfx942
125+     auto  it =
126+         std::find_if (archs.rbegin (), archs.rend (), [&](const  std::string &a) {
127+           return  a.find (arch.str ()) != std::string::npos;
128+         });
129+     if  (it == archs.rend ())
130+       return  findFallbackRecursive (arch.drop_back (1 ), op, dataType, archs);
131+     else 
132+       return  makeKey (*it, op, dataType);
133+   }
134+ 
95135  static  std::string makeSuffix (KernelType op, Type dataType) {
96136    return  stringifyEnum (op).lower () + " _" getDataTypeString (dataType);
97137  }
98138
99139  static  std::string makeKey (StringRef arch, KernelType op, Type dataType) {
100-     return  getArchName ( arch) .str () + " _" makeSuffix (op, dataType);
140+     return  arch.str () + " _" makeSuffix (op, dataType);
101141  }
102142
103143  static  const  std::unordered_map<std::string, ParamArray> &getTable () {
0 commit comments