
Commit 2c992ab

block layout op runtime function added
1 parent 3ccfbde commit 2c992ab

5 files changed: +115 −2 lines

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -1330,6 +1330,7 @@ list(APPEND NVFUSER_RUNTIME_FILES
   ${NVFUSER_ROOT}/runtime/block_sync_atomic.cu
   ${NVFUSER_ROOT}/runtime/block_sync_default.cu
   ${NVFUSER_ROOT}/runtime/block_welford_outer.cu
+  ${NVFUSER_ROOT}/runtime/block_layout.cu
   ${NVFUSER_ROOT}/runtime/broadcast.cu
   ${NVFUSER_ROOT}/runtime/casts.cu
   ${NVFUSER_ROOT}/runtime/cluster.cu

csrc/kernel.cpp

Lines changed: 4 additions & 0 deletions
@@ -272,6 +272,10 @@ class KernelIrScanner : private IrVisitor {
     summary_.has_argsort = true;
   }
 
+  void handle(GroupedBlockScalingFactorLayoutOp* aop) final {
+    summary_.has_grouped_block_sf_layout = true;
+  }
+
   void handle(TopKOp* top) final {
     summary_.has_topk = true;
   }

csrc/kernel.h

Lines changed: 3 additions & 0 deletions
@@ -142,6 +142,9 @@ struct KernelSummary {
   //! Do we have any argsort op?
   bool has_argsort = false;
 
+  //! Do we have any grouped_block_sf_layout op?
+  bool has_grouped_block_sf_layout = false;
+
   //! Do we have any topk op?
   bool has_topk = false;

csrc/runtime/compiled_kernel.cpp

Lines changed: 8 additions & 2 deletions
@@ -58,6 +58,7 @@
 #include <nvfuser_resources/basic_type_traits.h>
 #include <nvfuser_resources/bf16_support.h>
 #include <nvfuser_resources/bit.h>
+#include <nvfuser_resources/block_layout.h>
 #include <nvfuser_resources/block_reduction.h>
 #include <nvfuser_resources/block_sync_atomic.h>
 #include <nvfuser_resources/block_sync_default.h>
@@ -1158,7 +1159,8 @@ std::string _getStructuredCode(
     std::string kernel_name,
     bool has_argsort = false,
     bool has_topk = false,
-    bool has_scan = false) {
+    bool has_scan = false,
+    bool has_block_layout = false) {
   // generating cuda code;
   std::string code = "";
 
@@ -1194,6 +1196,9 @@ std::string _getStructuredCode(
   if (has_topk) {
     code += nvfuser_resources::topk_cu;
   }
+  if (has_block_layout) {
+    code += nvfuser_resources::block_layout_cu;
+  }
 
   code += "\nnamespace " + CompiledKernel::kernelNamespace() + " {\n\n";
   code += kernel_str;
@@ -1439,7 +1444,8 @@ std::string CompiledKernel::getStructuredCode() const {
       kernelName(),
       kernel()->summary().has_argsort,
       kernel()->summary().has_topk,
-      kernel()->summary().has_scan);
+      kernel()->summary().has_scan,
+      kernel()->summary().has_grouped_block_sf_layout);
 }
 
 std::string CompiledKernel::disassembledKernelSASS() const {
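For context, the sketch below illustrates the conditional-resource pattern these hunks extend: a runtime helper source is only spliced into the generated kernel string when the matching KernelSummary flag is set. It is not part of the commit; `fake_resources` and `buildCode` are made-up stand-ins for the generated `nvfuser_resources` string constants and `_getStructuredCode`.

#include <string>

// Illustrative stand-in for a generated nvfuser_resources string constant.
namespace fake_resources {
constexpr const char* block_layout_cu =
    "/* contents of runtime/block_layout.cu */\n";
}

// Mirrors the idea in _getStructuredCode: the helper source is prepended to
// the kernel string only when the kernel actually uses the op.
std::string buildCode(const std::string& kernel_str, bool has_block_layout) {
  std::string code;
  if (has_block_layout) {
    code += fake_resources::block_layout_cu;
  }
  code += "\nnamespace my_kernel_ns {\n\n";
  code += kernel_str;
  code += "\n} // namespace my_kernel_ns\n";
  return code;
}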

runtime/block_layout.cu

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+
+namespace nvf::block_layout {
+
+namespace {
+
+// TODO: simplify this maybe?!
+template <int BLOCK_ROW_OUTER, int BLOCK_ROW_INNER, int BLOCK_COL>
+__device__ nvfuser_index_t offsetAfterSwizzlePadding(
+    const nvfuser_index_t row_idx,
+    const nvfuser_index_t col_idx,
+    const nvfuser_index_t padded_col_size) {
+  constexpr nvfuser_index_t BLOCK_ROW_SIZE = BLOCK_ROW_OUTER * BLOCK_ROW_INNER;
+
+  /* logical dimension of matrix [row_size, col_size]
+   *
+   * while the layout is decomposed as
+   *   [ (row_tile * BLOCK_ROW_INNER * BLOCK_ROW_OUTER), (col_tile * BLOCK_COL) ]
+   * where
+   *   row_tile = row_size / (BLOCK_ROW_OUTER * BLOCK_ROW_INNER)
+   *   col_tile = col_size / BLOCK_COL
+   */
+  nvfuser_index_t row_tile_idx = row_idx / BLOCK_ROW_SIZE;
+
+  nvfuser_index_t row_block_idx = row_idx % BLOCK_ROW_SIZE;
+  nvfuser_index_t row_block_inner_idx = row_block_idx / BLOCK_ROW_OUTER;
+  nvfuser_index_t row_block_outer_idx = row_block_idx % BLOCK_ROW_OUTER;
+  nvfuser_index_t col_tile_idx = col_idx / BLOCK_COL;
+  nvfuser_index_t col_block_idx = col_idx % BLOCK_COL;
+
+  /* layout for matrix [row_size, col_size]
+   * it is viewed as
+   *   [row_tile, BLOCK_ROW_INNER, BLOCK_ROW_OUTER, col_tile, BLOCK_COL]
+   * then transposed with axes (1, 3)
+   *   [row_tile, col_tile, BLOCK_ROW_OUTER, BLOCK_ROW_INNER, BLOCK_COL]
+   * and then made contiguous
+   */
+  constexpr nvfuser_index_t COL_TILE_STRIDE = BLOCK_ROW_SIZE * BLOCK_COL;
+  constexpr nvfuser_index_t BLOCK_ROW_OUTER_STRIDE =
+      BLOCK_ROW_INNER * BLOCK_COL;
+  constexpr nvfuser_index_t BLOCK_ROW_INNER_STRIDE = BLOCK_COL;
+
+  return row_tile_idx * padded_col_size * BLOCK_ROW_SIZE +
+      col_tile_idx * COL_TILE_STRIDE +
+      row_block_outer_idx * BLOCK_ROW_OUTER_STRIDE +
+      row_block_inner_idx * BLOCK_ROW_INNER_STRIDE + col_block_idx;
+}
+
+} // namespace
+
+// TODO: I think we can actually not have this handled as an opaque function.
+template <
+    typename T,
+    typename Index_T,
+    int BLOCK_ROW_OUTER,
+    int BLOCK_ROW_INNER,
+    int BLOCK_COL,
+    int UNROLL_FACTOR>
+__device__ void groupedBlockLayout(
+    T* output,
+    const T* input,
+    const nvfuser_index_t row_idx,
+    const nvfuser_index_t col_idx,
+    const Index_T* expert_offsets,
+    const Index_T* output_offsets,
+    const nvfuser_index_t row_size,
+    const nvfuser_index_t col_size,
+    const nvfuser_index_t group_size) {
+  // find the corresponding expert_id
+  int expert_id = 0;
+  for (int i = 0; i < group_size; ++i) {
+    if (row_idx < expert_offsets[i + 1]) {
+      expert_id = i;
+      break;
+    }
+  }
+
+  // row idx within the current matmul group
+  nvfuser_index_t c_row_idx = row_idx - expert_offsets[expert_id];
+  nvfuser_index_t padded_col_size =
+      (col_size + BLOCK_COL - 1) / BLOCK_COL * BLOCK_COL;
+  T* out_group_offset = output + output_offsets[expert_id] * padded_col_size;
+
+  // TODO: vectorized load/store; the logic could be simplified afterwards.
+  for (int i = 0; i < UNROLL_FACTOR && col_idx + i < col_size; ++i) {
+    nvfuser_index_t index =
+        offsetAfterSwizzlePadding<BLOCK_ROW_OUTER, BLOCK_ROW_INNER, BLOCK_COL>(
+            c_row_idx, col_idx + i, padded_col_size);
+    out_group_offset[index] = input[i];
+  }
+}
+
+} // namespace nvf::block_layout
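To build intuition for `offsetAfterSwizzlePadding`, here is a small host-side sketch (not part of the commit) that re-implements the same index arithmetic with toy block sizes and prints where each `(row, col)` element of a 4x4 matrix lands after the `[row_tile, col_tile, BLOCK_ROW_OUTER, BLOCK_ROW_INNER, BLOCK_COL]` reordering. The types and sizes are illustrative only; the device version uses `nvfuser_index_t` and the real block-scaling-factor layout constants.

#include <cstdio>

// Host-side copy of the swizzle math, for visualization only.
template <int BLOCK_ROW_OUTER, int BLOCK_ROW_INNER, int BLOCK_COL>
long swizzledOffset(long row, long col, long padded_col_size) {
  constexpr long BLOCK_ROW_SIZE = BLOCK_ROW_OUTER * BLOCK_ROW_INNER;
  long row_tile = row / BLOCK_ROW_SIZE;
  long row_in_block = row % BLOCK_ROW_SIZE;
  long row_inner = row_in_block / BLOCK_ROW_OUTER;
  long row_outer = row_in_block % BLOCK_ROW_OUTER;
  long col_tile = col / BLOCK_COL;
  long col_in_block = col % BLOCK_COL;
  return row_tile * padded_col_size * BLOCK_ROW_SIZE +
      col_tile * BLOCK_ROW_SIZE * BLOCK_COL +
      row_outer * BLOCK_ROW_INNER * BLOCK_COL +
      row_inner * BLOCK_COL + col_in_block;
}

int main() {
  // Toy configuration: BLOCK_ROW_OUTER=2, BLOCK_ROW_INNER=2, BLOCK_COL=2,
  // padded_col_size=4. Each printed number is the linear offset that the
  // logical element (row, col) is written to in the swizzled output.
  for (long row = 0; row < 4; ++row) {
    for (long col = 0; col < 4; ++col) {
      printf("%4ld", swizzledOffset<2, 2, 2>(row, col, 4));
    }
    printf("\n");
  }
  return 0;
}

Within each row-block of BLOCK_ROW_OUTER * BLOCK_ROW_INNER rows, elements are stored outer-major then inner-major, matching the transposed view described in the comments above, and columns beyond `col_size` simply fall into the padded region of each tile.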
