Skip to content

Commit 3bbfa38

Browse files
Improve fallback lookup and address code review comments.
1 parent 19b357f commit 3bbfa38

File tree

2 files changed

+71
-31
lines changed

2 files changed

+71
-31
lines changed

mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp

Lines changed: 64 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -32,27 +32,18 @@ class ParamLookupTable {
3232

3333
static ArrayRef<ParamsType> lookup(StringRef arch, KernelType op,
3434
Type dataType) {
35-
static const auto &table = getTable();
36-
35+
arch = getArchName(arch);
3736
auto key = makeKey(arch, op, dataType);
3837
LLVM_DEBUG(llvm::dbgs()
3938
<< "Lookup for tuning parameters with key " << key << "\n");
4039

40+
static const auto &table = getTable();
4141
auto it = table.find(key);
4242
if (it != table.end()) {
4343
return ArrayRef<ParamsType>(it->second.first, it->second.second);
4444
}
4545

46-
auto archFamily = getArchName(arch).drop_back(2).str();
47-
48-
std::string fallbackKey;
49-
for (const auto &entry : table) {
50-
if (entry.first.find(archFamily) != std::string::npos &&
51-
entry.first.find(makeSuffix(op, dataType)) != std::string::npos) {
52-
fallbackKey = std::max(fallbackKey, entry.first);
53-
}
54-
}
55-
46+
auto fallbackKey = findFallback(arch, op, dataType);
5647
if (!fallbackKey.empty()) {
5748
LLVM_DEBUG(llvm::dbgs() << "Falling back to tuning parameters with key "
5849
<< fallbackKey << "\n");
@@ -71,33 +62,82 @@ class ParamLookupTable {
7162
llvm_unreachable("Invalid architecture string");
7263
}
7364
auto remaining = arch.substr(gfxPos);
74-
auto endPos = remaining.find_first_not_of("0123456789a", 3);
65+
auto endPos =
66+
remaining.find_if_not([](char c) { return llvm::isAlnum(c); }, 3);
7567
return remaining.substr(0, endPos);
7668
}
7769

7870
static std::string getDataTypeString(Type dataType) {
7971
std::string dataTypeStr;
80-
llvm::raw_string_ostream os(dataTypeStr);
81-
os << dataType;
82-
83-
if (dataTypeStr == "bf16") {
84-
dataTypeStr = "f16";
85-
} else if (dataTypeStr.find("f8E") != std::string::npos) {
72+
if (dataType.getIntOrFloatBitWidth() == 8 && isa<FloatType>(dataType)) {
73+
// There are several 8-bit float types, but we use "fp8" generically
8674
dataTypeStr = "fp8";
87-
} else if (dataType.isInteger() &&
88-
(dataTypeStr.at(0) == 's' || dataTypeStr.at(0) == 'u')) {
89-
dataTypeStr = dataTypeStr.substr(1);
75+
} else if (dataType.getIntOrFloatBitWidth() == 16 &&
76+
isa<FloatType>(dataType)) {
77+
// We use "f16" for bf16 and f16 generically
78+
dataTypeStr = "f16";
79+
} else {
80+
llvm::raw_string_ostream os(dataTypeStr);
81+
os << dataType;
82+
if (dataType.isInteger() &&
83+
(dataTypeStr.at(0) == 's' || dataTypeStr.at(0) == 'u')) {
84+
// Integer types can be printed as "sint" or "uint"
85+
dataTypeStr = dataTypeStr.substr(1);
86+
}
9087
}
91-
9288
return dataTypeStr;
9389
}
9490

91+
// Get all related archs sorted lexicographically
92+
static std::vector<std::string> getRelatedArchs(StringRef arch, KernelType op,
93+
Type dataType) {
94+
std::vector<std::string> archs;
95+
auto prefix = arch.take_front(4);
96+
auto suffix = makeSuffix(op, dataType);
97+
static const auto &table = getTable();
98+
for (const auto &entry : table) {
99+
if (entry.first.find(prefix) != std::string::npos &&
100+
entry.first.rfind(suffix) != std::string::npos) {
101+
archs.push_back(entry.first.substr(0, entry.first.find('_')));
102+
}
103+
}
104+
std::sort(archs.begin(), archs.end());
105+
return archs;
106+
}
107+
108+
// Search for fallback by truncating arch string
109+
// e.g., gfx1151 -> gfx115 -> gfx11 -> gfx1
110+
static std::string findFallback(StringRef arch, KernelType op,
111+
Type dataType) {
112+
const auto archs = getRelatedArchs(arch, op, dataType);
113+
return findFallbackRecursive(arch.drop_back(1), op, dataType, archs);
114+
}
115+
116+
static std::string
117+
findFallbackRecursive(StringRef arch, KernelType op, Type dataType,
118+
const std::vector<std::string> &archs) {
119+
if (arch == "gfx")
120+
return "";
121+
122+
// Archs is sorted lexicographically, so we can search in reverse order for
123+
// the latest matching arch
124+
// e.g., gfx950 matches gfx9 before gfx942
125+
auto it =
126+
std::find_if(archs.rbegin(), archs.rend(), [&](const std::string &a) {
127+
return a.find(arch.str()) != std::string::npos;
128+
});
129+
if (it == archs.rend())
130+
return findFallbackRecursive(arch.drop_back(1), op, dataType, archs);
131+
else
132+
return makeKey(*it, op, dataType);
133+
}
134+
95135
static std::string makeSuffix(KernelType op, Type dataType) {
96136
return stringifyEnum(op).lower() + "_" + getDataTypeString(dataType);
97137
}
98138

99139
static std::string makeKey(StringRef arch, KernelType op, Type dataType) {
100-
return getArchName(arch).str() + "_" + makeSuffix(op, dataType);
140+
return arch.str() + "_" + makeSuffix(op, dataType);
101141
}
102142

103143
static const std::unordered_map<std::string, ParamArray> &getTable() {

mlir/utils/performance/analysis/quickTuningGen.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def find(self):
338338
return result
339339

340340

341-
def combine_data(input_dir, no_split_k):
341+
def combine_data(input_dir, no_splitk):
342342
"""
343343
Combine all *.debug tuning data into a single file.
344344
"""
@@ -354,7 +354,7 @@ def combine_data(input_dir, no_split_k):
354354
new_df = pd.concat(dfs, ignore_index=True)
355355

356356
# Remove splitK from tuning data
357-
if no_split_k:
357+
if no_splitk:
358358
df_filtered = new_df[new_df['PerfConfig'].str.split(',').str[6] == '1']
359359
new_df = df_filtered
360360

@@ -373,8 +373,8 @@ def print_results(result):
373373

374374
def main(args=None):
375375
"""
376-
usage: quickTunerGen.py [-h] --input-dir INPUT_DIR --op {gemm,conv} [--th TH] --arch ARCH [--update] [--no-splitK]
377-
usage exsample: python3 quickTuningGen.py --input-dir tunedData --op conv --arch gfx90a --update --no-splitK
376+
usage: quickTunerGen.py [-h] --input-dir INPUT_DIR --op {gemm,conv} [--th TH] --arch ARCH [--update] [--no-splitk]
377+
usage exsample: python3 quickTuningGen.py --input-dir tunedData --op conv --arch gfx90a --update --no-splitk
378378
"""
379379
if args is None:
380380
args = sys.argv[1:]
@@ -391,14 +391,14 @@ def main(args=None):
391391

392392
parser.add_argument("--update", required=False, default=False, action='store_true')
393393

394-
parser.add_argument('--no-splitK',
394+
parser.add_argument('--no-splitk',
395395
default=False,
396396
action='store_true',
397-
help='Removing the spliK factor from the generated list')
397+
help='Removing the Split-K factor from the generated list')
398398

399399
pargs = parser.parse_args()
400400

401-
combined_data = combine_data(pargs.input_dir, pargs.no_split_k)
401+
combined_data = combine_data(pargs.input_dir, pargs.no_splitk)
402402

403403
finder = PerfConfigsFinder(combined_data, pargs)
404404
result = finder.find()

0 commit comments

Comments
 (0)