
Commit c4e65d3
review comment
1 parent daa7c9d

11 files changed: +110 additions, -98 deletions

csrc/device_lower/utils.cpp

Lines changed: 1 addition & 1 deletion

@@ -125,7 +125,7 @@ bool isTvOp(const Expr* expr) {
           SliceOp,
           CatOp,
           ScanOp,
-          GroupedBlockScalingFactorLayoutOp,
+          PreprocessGroupedMatmulInputSf,
           kir::AllocTMem,
           kir::GridReduction,
           kir::GroupedGridReduction,

csrc/dispatch.h

Lines changed: 54 additions & 54 deletions

@@ -68,60 +68,60 @@ class Val;
 #define DISPATCH_FOR_ALL_KIR_VALS(f) f(Predicate) f(TensorIndex)
 #define DISPATCH_FOR_ALL_HIR_VALS(f) f(Stream)
 
-#define DISPATCH_FOR_ALL_EXPRS(f) \
-  f(FullOp); \
-  f(IotaOp); \
-  f(EyeOp); \
-  f(UnaryOp); \
-  f(BinaryOp); \
-  f(TernaryOp); \
-  f(ArrayConstruct); \
-  f(StructConstruct); \
-  f(GetAttr); \
-  f(GetItem); \
-  f(ReverseArray); \
-  f(GetMetaData); \
-  f(TensorConstruct); \
-  f(SelectOp); \
-  f(IndexSelectOp); \
-  f(IndexPutAccumulateOp); \
-  f(GatherOp); \
-  f(ScatterOp); \
-  f(RNGOp); \
-  f(ReductionOp); \
-  f(GroupedReductionOp); \
-  f(WelfordOp); \
-  f(GroupedWelfordOp); \
-  f(LoadStoreOp); \
-  f(MmaOp); \
-  f(BroadcastOp); \
-  f(SqueezeOp); \
-  f(ExpandOp); \
-  f(RepeatOp); \
-  f(ViewAsScalar); \
-  f(ReshapeOp); \
-  f(CatOp); \
-  f(PadOp); \
-  f(SliceOp); \
-  f(Split); \
-  f(ArgsortOp); \
-  f(GroupedMmaOp); \
-  f(ScaledMmaOp); \
-  f(CutlassNvfp4GroupedMmaOp); \
-  f(GroupedBlockScalingFactorLayoutOp); \
-  f(TopKOp); \
-  f(ScanOp); \
-  f(Merge); \
-  f(Swizzle); \
-  f(Swizzle2D); \
-  f(Resize); \
-  f(MatmulOp); \
-  f(LinearOp); \
-  f(SdpaFwdOp); \
-  f(SdpaBwdOp); \
-  f(EmbeddingFwdOp); \
-  f(Communication); \
-  f(ForLoop); \
+#define DISPATCH_FOR_ALL_EXPRS(f) \
+  f(FullOp); \
+  f(IotaOp); \
+  f(EyeOp); \
+  f(UnaryOp); \
+  f(BinaryOp); \
+  f(TernaryOp); \
+  f(ArrayConstruct); \
+  f(StructConstruct); \
+  f(GetAttr); \
+  f(GetItem); \
+  f(ReverseArray); \
+  f(GetMetaData); \
+  f(TensorConstruct); \
+  f(SelectOp); \
+  f(IndexSelectOp); \
+  f(IndexPutAccumulateOp); \
+  f(GatherOp); \
+  f(ScatterOp); \
+  f(RNGOp); \
+  f(ReductionOp); \
+  f(GroupedReductionOp); \
+  f(WelfordOp); \
+  f(GroupedWelfordOp); \
+  f(LoadStoreOp); \
+  f(MmaOp); \
+  f(BroadcastOp); \
+  f(SqueezeOp); \
+  f(ExpandOp); \
+  f(RepeatOp); \
+  f(ViewAsScalar); \
+  f(ReshapeOp); \
+  f(CatOp); \
+  f(PadOp); \
+  f(SliceOp); \
+  f(Split); \
+  f(ArgsortOp); \
+  f(GroupedMmaOp); \
+  f(ScaledMmaOp); \
+  f(CutlassNvfp4GroupedMmaOp); \
+  f(PreprocessGroupedMatmulInputSf); \
+  f(TopKOp); \
+  f(ScanOp); \
+  f(Merge); \
+  f(Swizzle); \
+  f(Swizzle2D); \
+  f(Resize); \
+  f(MatmulOp); \
+  f(LinearOp); \
+  f(SdpaFwdOp); \
+  f(SdpaBwdOp); \
+  f(EmbeddingFwdOp); \
+  f(Communication); \
+  f(ForLoop); \
   f(P2PCommunication);
 #define DISPATCH_FOR_ALL_KIR_EXPRS(f) \
   f(Allocate); \
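Context for the hunk above: DISPATCH_FOR_ALL_EXPRS is an X-macro list, so renaming an IR node only requires touching its single f(...) entry, and every dispatcher stamped out from the list picks up the new type. A minimal sketch of the pattern with hypothetical node names (not the actual nvFuser dispatch machinery):

#include <iostream>
#include <string>

// A tiny node list in the same style as DISPATCH_FOR_ALL_EXPRS (hypothetical types).
#define FOR_ALL_NODES(f) \
  f(FullOp);             \
  f(IotaOp);             \
  f(UnaryOp);

// Stamp out one handler function per listed node type.
#define DEFINE_HANDLER(T)                       \
  void handle_##T(const std::string& payload) { \
    std::cout << #T << ": " << payload << "\n"; \
  }

FOR_ALL_NODES(DEFINE_HANDLER)

int main() {
  handle_FullOp("fill");
  handle_IotaOp("arange");
  handle_UnaryOp("neg");
  return 0;
}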

csrc/ir/internal_nodes.h

Lines changed: 20 additions & 12 deletions

@@ -3442,38 +3442,46 @@ class CutlassNvfp4GroupedMmaOp : public Expr {
   }
 };
 
-//! NOTE -- [ GroupedBlockScalingFactorLayoutOp ]
+//! NOTE -- [ PreprocessGroupedMatmulInputSf ]
 //!
 //! This operation performs a layout change on the input, it's currently used
 //! for block scaling factor accompanying narrow precision inputs.
 //!
-//! 1. This can be viewed as a point-wise operation, where output loop domain
-//! would match the input logical domain.
+//! PreprocessGroupedMatmulInputSf(TensorView* output, TensorView* input, ...)
+//!
+//! input:  logical domain: (i0, i1)
+//! output: root domain:    (i0, i1)
+//!         logical domain: (i2, i3)
+//!         loop domain:    (i0, i1)
+//!
+//! 1. This can be viewed as a point-wise operation, since output loop domain
+//! matches the input logical domain.
+//!
 //! 2. Because of the potential padding/swizzle, the logical domain of the
 //! output does not map to input. We don't rely on codegen for indexing, so we
 //! don't care about mapping the logical/allocation of output to anything else.
-//! Indexing will be done in runtime function, utilizing `expert_offsets` and
-//! `sf_offsets`.
-//! 3. Output has a root domain, which is identical to its loop domain. We add
-//! this so we can map it to input.
-class GroupedBlockScalingFactorLayoutOp : public Expr {
+//! Indexing will be done in runtime function, utilizing `input_offsets` and
+//! `output_offsets`.
+//!
+//! 3. Output has a root domain that matches the logical domain of the input.
+class PreprocessGroupedMatmulInputSf : public Expr {
  public:
   using Expr::Expr;
 
-  GroupedBlockScalingFactorLayoutOp(
+  PreprocessGroupedMatmulInputSf(
       IrBuilderPasskey,
       Val* output,
       Val* input,
-      Val* expert_offsets,
-      Val* sf_offsets,
+      Val* input_offsets,
+      Val* output_offsets,
       BlockScalingFactorLayout layout,
       Val* k,
       Val* g);
 
   NVFUSER_DECLARE_CLONE_AND_CREATE
 
   const char* getOpString() const override {
-    return "GroupedBlockScalingFactorLayoutOp";
+    return "PreprocessGroupedMatmulInputSf";
   }
 
   std::string toString(int indent_size = 0) const override;

csrc/ir/nodes.cpp

Lines changed: 8 additions & 9 deletions

@@ -6555,7 +6555,7 @@ std::vector<PolymorphicValue> CutlassNvfp4GroupedMmaOp::evaluate(
 
 NVFUSER_DEFINE_CLONE_AND_CREATE(CutlassNvfp4GroupedMmaOp)
 
-GroupedBlockScalingFactorLayoutOp::GroupedBlockScalingFactorLayoutOp(
+PreprocessGroupedMatmulInputSf::PreprocessGroupedMatmulInputSf(
     IrBuilderPasskey passkey,
     Val* output,
     Val* input,
@@ -6574,11 +6574,11 @@ GroupedBlockScalingFactorLayoutOp::GroupedBlockScalingFactorLayoutOp(
   addDataAttribute(layout);
 }
 
-std::string GroupedBlockScalingFactorLayoutOp::toString(int indent_size) const {
+std::string PreprocessGroupedMatmulInputSf::toString(int indent_size) const {
   std::stringstream ss;
   indent(ss, indent_size) << output(0)->toString() << "\n";
   indent_size++;
-  indent(ss, indent_size) << " = grouped_block_scaling_factor_layout(\n";
+  indent(ss, indent_size) << " = preprocessGroupedMatmulInputSf(\n";
   indent_size++;
   indent(ss, indent_size) << "input = " << in()->toString() << ",\n";
   indent(ss, indent_size) << "expert_offsets = " << expertOffsets()->toString()
@@ -6595,19 +6595,18 @@ std::string GroupedBlockScalingFactorLayoutOp::toString(int indent_size) const {
   return ss.str();
 }
 
-std::string GroupedBlockScalingFactorLayoutOp::toInlineString(
+std::string PreprocessGroupedMatmulInputSf::toInlineString(
     int indent_size) const {
-  NVF_CHECK(
-      false, "GroupedBlockScalingFactorLayoutOp can not be printed inline");
+  NVF_CHECK(false, "PreprocessGroupedMatmulInputSf can not be printed inline");
 }
 
-std::vector<PolymorphicValue> GroupedBlockScalingFactorLayoutOp::evaluate(
+std::vector<PolymorphicValue> PreprocessGroupedMatmulInputSf::evaluate(
     const ExpressionEvaluator& ee,
     const std::vector<PolymorphicValue>& inputs) const {
   // This is a placeholder, currently we don't have a fallback kernel available
-  NVF_THROW("GroupedBlockScalingFactorLayoutOp evaluation not yet implemented");
+  NVF_THROW("PreprocessGroupedMatmulInputSf evaluation not yet implemented");
 }
 
-NVFUSER_DEFINE_CLONE_AND_CREATE(GroupedBlockScalingFactorLayoutOp)
+NVFUSER_DEFINE_CLONE_AND_CREATE(PreprocessGroupedMatmulInputSf)
 
 } // namespace nvfuser

csrc/logical_domain_map.cpp

Lines changed: 1 addition & 1 deletion

@@ -133,7 +133,7 @@ std::pair<std::unordered_set<IterDomain*>, bool> getNonMappingDomainInfo(
     }
   } else if (
       auto* grouped_block_sf_layout =
-          dynamic_cast<GroupedBlockScalingFactorLayoutOp*>(
+          dynamic_cast<PreprocessGroupedMatmulInputSf*>(
              consumer_tv->definition())) {
     if (producer_tv != grouped_block_sf_layout->in()) {
       auto producer_logical =

csrc/ops/indexing.cpp

Lines changed: 14 additions & 14 deletions

@@ -292,10 +292,10 @@ TensorView* takeAlongAxis(TensorView* inp, TensorView* index, int64_t dim) {
   return out_tensor->as<TensorView>();
 }
 
-TensorView* groupedBlockSfLayout(
+TensorView* preprocessGroupedMatmulInputSf(
     TensorView* input,
-    TensorView* expert_offsets,
-    TensorView* sf_offsets,
+    TensorView* input_offsets,
+    TensorView* output_offsets,
     BlockScalingFactorLayout layout) {
   // only support input matrix;
   auto input_logical_dom =
@@ -312,9 +312,6 @@
   });
 
   // Create the logical domain of output.
-  // Note: output logical domain handles potential padding required for the
-  // layout. Since the actual padding size is data-dependent, we allocate for
-  // the maximum padding (reflected on logical/allocation domain).
   std::vector<IterDomain*> out_logical;
   out_logical.reserve(input_logical_dom.size());
 
@@ -325,11 +322,15 @@
 
   auto* one_val = input->fusion()->oneVal(DataType::Index);
   std::vector<IterDomain*> offset_logical_dom =
-      TensorDomain::noReductions(expert_offsets->getLogicalDomain());
+      TensorDomain::noReductions(input_offsets->getLogicalDomain());
   Val* num_groups =
       SimplifyingIrBuilder::subExpr(offset_logical_dom[0]->extent(), one_val);
-  // padded row size:
-  //   num_groups * (row_multiple - 1) + row_size
+
+  // Note: output logical domain handles potential padding required for the
+  // layout. Since the actual padding size is data-dependent, we allocate for
+  // the maximum padding (reflected on logical/allocation domain).
+
+  // pad row size: num_groups * (row_multiple - 1) + row_size
   auto pad_to_max_extent = [&](IterDomain* id, int multiple) -> IterDomain* {
     auto* maximum_pad_value_per_group =
         IrBuilder::create<Val>(multiple - 1, DataType::Index);
@@ -340,8 +341,7 @@
   };
   out_logical.push_back(pad_to_max_extent(out_root[0], row_multiple));
 
-  // padded col size:
-  //   (col_size + col_multiple - 1) / col_multiple * col_multiple
+  // pad col size: (col_size + col_multiple - 1) / col_multiple * col_multiple
   auto pad_to_multiple = [&](IterDomain* id, int multiple) -> IterDomain* {
     Val* ext = id->extent();
     auto* multiple_val = IrBuilder::create<Val>(multiple, DataType::Index);
@@ -370,11 +370,11 @@
           /*skip_checks=*/true),
       input->getDataType().value());
 
-  IrBuilder::create<GroupedBlockScalingFactorLayoutOp>(
+  IrBuilder::create<PreprocessGroupedMatmulInputSf>(
       out_tv,
       input,
-      expert_offsets,
-      sf_offsets,
+      input_offsets,
+      output_offsets,
      layout,
      input_logical_dom[1]->getMaybeExpandedExtent(),
      num_groups);
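To make the two padding rules in the hunk above concrete, here is a small standalone sketch that evaluates them for one made-up case. The multiples are illustrative assumptions (row_multiple = 128 and col_multiple = 4, suggested by the Block128x4 name); the actual values come from the layout handling in the real code.

#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical problem: 3 groups, 1000 scale-factor rows in total, 18 columns.
  const int64_t num_groups = 3;
  const int64_t row_size = 1000;
  const int64_t col_size = 18;

  // Assumed multiples for Block128x4 (illustration only).
  const int64_t row_multiple = 128;
  const int64_t col_multiple = 4;

  // pad row size: num_groups * (row_multiple - 1) + row_size
  // Every group may independently need up to (row_multiple - 1) rows of
  // padding, so the output allocates for the worst case over all groups.
  const int64_t padded_rows = num_groups * (row_multiple - 1) + row_size;

  // pad col size: (col_size + col_multiple - 1) / col_multiple * col_multiple
  // Plain round-up of the column count to the next multiple of col_multiple.
  const int64_t padded_cols =
      (col_size + col_multiple - 1) / col_multiple * col_multiple;

  std::cout << "padded rows: " << padded_rows << "\n"; // 3 * 127 + 1000 = 1381
  std::cout << "padded cols: " << padded_cols << "\n"; // 18 rounds up to 20
  return 0;
}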

csrc/ops/indexing.h

Lines changed: 6 additions & 3 deletions

@@ -83,10 +83,13 @@ NVF_API TensorView* takeAlongAxis(
     TensorView* index,
     int64_t dim);
 
-NVF_API TensorView* groupedBlockSfLayout(
+//! Changes the layout of input to satisfy the requirement of grouped matmul on
+//! block scaling factor. see:
+//! https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#scale-factor-layouts
+NVF_API TensorView* preprocessGroupedMatmulInputSf(
     TensorView* input,
-    TensorView* expert_offsets,
-    TensorView* sf_offsets,
+    TensorView* input_offsets,
+    TensorView* output_offsets,
     BlockScalingFactorLayout layout);
 
 } // namespace nvfuser
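Since this header declares the public entry point, here is a hedged sketch of how a caller might use it when building a fusion. The surrounding setup (Fusion, FusionGuard, TensorViewBuilder, the chosen data types, and the required headers) is assumed from typical nvFuser front-end code rather than taken from this commit.

// Sketch only; assumes the usual nvFuser headers (e.g. fusion.h, ops/all_ops.h)
// and namespace nvfuser. Dtypes for the scaling factor and offsets are guesses.
Fusion fusion;
FusionGuard fg(&fusion);

// 2D block scaling factor accompanying a grouped-matmul operand.
TensorView* sf_in = TensorViewBuilder().ndims(2).dtype(DataType::Float).build();
// Per-group offsets into the input rows and into the (padded) output rows.
TensorView* input_offsets =
    TensorViewBuilder().ndims(1).dtype(DataType::Int32).build();
TensorView* output_offsets =
    TensorViewBuilder().ndims(1).dtype(DataType::Int32).build();

fusion.addInput(sf_in);
fusion.addInput(input_offsets);
fusion.addInput(output_offsets);

// Reorder/pad the scaling factor into the layout the grouped matmul expects.
TensorView* sf_out = preprocessGroupedMatmulInputSf(
    sf_in, input_offsets, output_offsets, BlockScalingFactorLayout::Block128x4);
fusion.addOutput(sf_out);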

csrc/scheduler/registry.cpp

Lines changed: 1 addition & 1 deletion

@@ -50,7 +50,7 @@ bool checkCanSchedule(Fusion* fusion, SchedulerType scheduler_type) {
           ScaledMmaOp,
           CutlassNvfp4GroupedMmaOp,
           // TODO: remove this once we have a scheduler for it
-          GroupedBlockScalingFactorLayoutOp,
+          PreprocessGroupedMatmulInputSf,
           TopKOp,
           ScanOp>(fusion)) {
     scheduler_debug_utils::canScheduleRejectReason(

csrc/type.cpp

Lines changed: 1 addition & 2 deletions

@@ -338,9 +338,8 @@ const char* block_sf_layout2string(BlockScalingFactorLayout t) {
   switch (t) {
     case BlockScalingFactorLayout::Block128x4:
       return "block_128_4";
-    default:
-      NVF_THROW("No string found for layout.");
   }
+  std::unreachable();
 }
 
 const char* predicate_type2string(PredicateType t) {
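On the switch change above: removing the default: branch keeps the switch exhaustive over the enum, so compilers can warn (e.g. -Wswitch) when a new BlockScalingFactorLayout value is added without a matching case, while std::unreachable() (C++23, from <utility>) still tells the optimizer that falling out of the switch cannot happen. A generic sketch of the pattern with a hypothetical enum:

#include <utility> // std::unreachable, C++23

// Hypothetical enum for illustration.
enum class Color { Red, Green };

const char* to_string(Color c) {
  switch (c) {
    case Color::Red:
      return "red";
    case Color::Green:
      return "green";
  }
  // All enumerators are handled above; adding a new Color without a case
  // produces a compiler warning instead of silently hitting a default branch.
  std::unreachable();
}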

csrc/type.h

Lines changed: 3 additions & 0 deletions

@@ -1176,6 +1176,9 @@ std::ostream& operator<<(std::ostream&, TMemRegisterDataPath);
 
 std::ostream& operator<<(std::ostream&, cudaDriverEntryPointQueryResult);
 
+// Layout for block scaling factor used by mx-format with narrow precision, this
+// indicates how to index into block scaling factor. see:
+// https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#scale-factor-layouts
 enum class BlockScalingFactorLayout {
   Block128x4,
 };
