
Commit 856f291

PR4: Adding automatic scheduler
Squashed WIP history: added test case; prevent cacheAndForkOutputs and disable cacheInputs for the offsets TVs; change domain handling in the reference TV; revert unused changes (WIP: something isn't working right yet).
1 parent f79ec33 · commit 856f291

File tree: 5 files changed (+56, −3 lines)


csrc/ir/utils.cpp

Lines changed: 6 additions & 0 deletions
@@ -777,6 +777,12 @@ bool isIndexSelectLookupTv(const TensorView* tv) {
         return true;
       }
     }
+    if (expr->isA<PreprocessGroupedMatmulInputSf>()) {
+      auto layout = expr->as<PreprocessGroupedMatmulInputSf>();
+      if (tv == layout->inputOffsets() || tv == layout->outputOffsets()) {
+        return true;
+      }
+    }
   }
   return false;
 }

csrc/scheduler/registry.cpp

Lines changed: 3 additions & 2 deletions
@@ -49,15 +49,15 @@ bool checkCanSchedule(Fusion* fusion, SchedulerType scheduler_type) {
           GroupedMmaOp,
           ScaledMmaOp,
           CutlassNvfp4GroupedMmaOp,
-          // TODO: remove this once we have a scheduler for it
-          PreprocessGroupedMatmulInputSf,
           TopKOp,
           ScanOp>(fusion)) {
     scheduler_debug_utils::canScheduleRejectReason(
         scheduler_type, "Has unsupported ops");
     return false;
   }
 
+  // TODO: check PreprocessGroupedMatmulInputSf's output is in global memory / fusion output
+
   // Fusions with `MatmulOp, LinearOp, MmaOp` can only be accepted by Matmul
   // scheduler.
   if (scheduler_type != SchedulerType::Matmul &&
@@ -72,6 +72,7 @@ bool checkCanSchedule(Fusion* fusion, SchedulerType scheduler_type) {
         scheduler_type, "Connected fusion graph check failed!");
     return false;
   }
+
   if (IterDomainGraph(fusion, /*allow_self_mapping=*/true).hasSelfMapping()) {
     scheduler_debug_utils::canScheduleRejectReason(
         scheduler_type, "Iter domain graph check failed!");

csrc/scheduler/tools/domain_map.cpp

Lines changed: 4 additions & 0 deletions
@@ -58,6 +58,10 @@ bool canIgnoreIndexedInputDomainID(
              ->isBroadcast()) {
         return false;
       }
+    } else if (auto layout = dynamic_cast<PreprocessGroupedMatmulInputSf*>(use)) {
+      if (input_tv == layout->inputOffsets() || input_tv == layout->outputOffsets()) {
+        continue;
+      }
     } else {
       // If the input TV is used by any other ops
       return false;

csrc/scheduler/utils.cpp

Lines changed: 2 additions & 1 deletion
@@ -1341,7 +1341,8 @@ std::vector<std::pair<TensorView*, TensorView*>> cacheAndForkOutputs(
     if (output->definition() == nullptr ||
         // the output of ScatterOp must on the global memory due to the random
         // or atomic access.
-        output->definition()->isA<ScatterOp>()) {
+        output->definition()->isA<ScatterOp>() ||
+        output->definition()->isA<PreprocessGroupedMatmulInputSf>()) {
       continue;
     }
     if (!output->uses().empty()) {

tests/cpp/test_layout_op.cpp

Lines changed: 41 additions & 0 deletions
@@ -134,4 +134,45 @@ TEST_F(LayoutOpTest, ManualKernel) {
       t2));
 }
 
+TEST_F(LayoutOpTest, SchedulerKernel) {
+  auto fusion_ptr = std::make_unique<Fusion>();
+  Fusion& fusion = *fusion_ptr.get();
+  FusionGuard fg(&fusion);
+
+  auto inp = makeSymbolicTensor(2);
+  auto offsets = makeSymbolicTensor(1, DataType::Int32);
+  auto rounded_offsets = makeSymbolicTensor(1, DataType::Int32);
+  fusion.addInput(inp);
+  fusion.addInput(offsets);
+  fusion.addInput(rounded_offsets);
+
+  auto inp_tv = set(inp);
+  auto out_tv = preprocessGroupedMatmulInputSf(
+      inp_tv, offsets, rounded_offsets, BlockScalingFactorLayout::Block128x4);
+  // NOTE: output of preprocessGroupedMatmulInputSf needs to be on global
+  // memory, because we do indexing on output inside the runtime function.
+  fusion.addOutput(out_tv);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  int m = 512;
+  int k = 9; // note: padded column size would be 12
+  auto t0 = at::randn({m, k}, options);
+  // tokens per group are [100, 150, 262] respectively, so each group would be
+  // padded to multiple of 128. Hence the total output row span would cover a
+  // length of 128 + 256 + 384 = 768.
+  auto t1 = at::tensor({0, 100, 250, 512}, options.dtype(at::kInt));
+  auto t2 = at::tensor({0, 128, 384, 768}, options.dtype(at::kInt));
+
+  // naive scheduling.
+  FusionExecutorCache executor_cache(std::move(fusion_ptr));
+  auto outputs = executor_cache.runFusionWithInputs({t0, t1, t2});
+
+  ASSERT_TRUE(validateGroupedLayout(
+      BlockScalingFactorLayout::Block128x4,
+      outputs[0].as<at::Tensor>(),
+      t0,
+      t1,
+      t2));
+}
+
 } // namespace nvfuser
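
Worked check of the expected offsets in the test above, as a standalone sketch (not part of the commit; roundedOffsets is a hypothetical helper): for the Block128x4 layout each group's row count is rounded up to a multiple of 128, which turns the input offsets {0, 100, 250, 512} into the output offsets {0, 128, 384, 768}.

// Standalone sketch (hypothetical helper, not in the PR): derive the padded
// output offsets used as t2 from the input group offsets used as t1.
#include <cstdint>
#include <vector>

std::vector<int32_t> roundedOffsets(
    const std::vector<int32_t>& offsets,
    int32_t row_multiple = 128) {
  std::vector<int32_t> rounded = {0};
  for (size_t g = 1; g < offsets.size(); ++g) {
    // Group sizes for {0, 100, 250, 512} are 100, 150, 262 ...
    int32_t group_rows = offsets[g] - offsets[g - 1];
    // ... each padded up to a multiple of 128: 128, 256, 384.
    int32_t padded =
        (group_rows + row_multiple - 1) / row_multiple * row_multiple;
    rounded.push_back(rounded.back() + padded);
  }
  return rounded;
}

// roundedOffsets({0, 100, 250, 512}) == {0, 128, 384, 768}, matching t2.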
