
Commit 1d72d32

PR1: add layout op
1. Adds the Fusion node `GroupedBlockScalingFactorLayoutOp`:
   - output: `Val* output` (2D tensor)
   - inputs: `TensorView* input` (2D tensor), `TensorView* expert_offsets` (vector), `TensorView* sf_offsets` (vector), `Val* k` (scalar), `Val* g` (scalar)
   - attribute: `BlockScalingFactorLayout layout`
2. Adds the C++ API `groupedBlockSfLayout` (a usage sketch follows below):

   TensorView* groupedBlockSfLayout(
       TensorView* input,
       TensorView* expert_offsets,
       TensorView* sf_offsets,
       BlockScalingFactorLayout layout);
1 parent 9bc115a commit 1d72d32
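For context, here is a minimal sketch of how the new `groupedBlockSfLayout` API might be called when building a fusion. Only the `groupedBlockSfLayout` signature and `BlockScalingFactorLayout::Block128x4` come from this commit; the headers, dtypes, and builder helpers below are assumptions for illustration.

```cpp
// Hypothetical usage sketch; not code from this commit.
#include <fusion.h>
#include <ops/all_ops.h>

using namespace nvfuser;

void buildGroupedSfLayoutFusion(Fusion* fusion) {
  FusionGuard fg(fusion);

  // Block scaling factors for all experts, concatenated along the row axis.
  TensorView* sf =
      TensorViewBuilder().ndims(2).dtype(DataType::Float).build();
  // Per-expert row offsets into `sf`, and per-expert offsets into the
  // padded/re-laid-out output.
  TensorView* expert_offsets =
      TensorViewBuilder().ndims(1).dtype(DataType::Int32).build();
  TensorView* sf_offsets =
      TensorViewBuilder().ndims(1).dtype(DataType::Int32).build();

  fusion->addInput(sf);
  fusion->addInput(expert_offsets);
  fusion->addInput(sf_offsets);

  // New API from this PR: re-lay-out the scaling factors into the
  // Block128x4 format.
  TensorView* sf_out = groupedBlockSfLayout(
      sf, expert_offsets, sf_offsets, BlockScalingFactorLayout::Block128x4);

  fusion->addOutput(sf_out);
}
```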

File tree

13 files changed (+365, -67 lines)


CMakeLists.txt

Lines changed: 3 additions & 0 deletions

```diff
@@ -1138,6 +1138,9 @@ if(BUILD_TEST)
   add_test(test_reshape "${NVFUSER_ROOT}/tests/cpp/test_reshape.cpp" "")
   list(APPEND TEST_BINARIES test_reshape)
 
+  add_test(test_layout_op ${NVFUSER_ROOT}/tests/cpp/test_layout_op.cpp "")
+  list(APPEND TEST_BINARIES test_layout_op)
+
   set(MATMUL_TEST_SRCS)
   list(APPEND MATMUL_TEST_SRCS
     ${NVFUSER_ROOT}/tests/cpp/test_cutlass_scheduler.cpp
```
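The new test binary is registered the same way as the existing ones. A minimal sketch of what a test in tests/cpp/test_layout_op.cpp could look like follows; the fixture and the check mirror the existing C++ test pattern but are assumptions, not the file added by this commit.

```cpp
// Hypothetical test sketch; not the actual tests/cpp/test_layout_op.cpp.
#include <gtest/gtest.h>

#include <fusion.h>
#include <ops/all_ops.h>
#include <tests/cpp/utils.h>

namespace nvfuser {

using LayoutOpTest = NVFuserTest;

TEST_F(LayoutOpTest, GroupedBlockSfLayoutDefinition) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  TensorView* sf = makeContigTensor(2, DataType::Float);
  TensorView* expert_offsets = makeContigTensor(1, DataType::Int32);
  TensorView* sf_offsets = makeContigTensor(1, DataType::Int32);
  fusion->addInput(sf);
  fusion->addInput(expert_offsets);
  fusion->addInput(sf_offsets);

  TensorView* out = groupedBlockSfLayout(
      sf, expert_offsets, sf_offsets, BlockScalingFactorLayout::Block128x4);
  fusion->addOutput(out);

  // This PR only introduces the IR node and API, so only the produced
  // definition type is checked here.
  EXPECT_TRUE(out->definition()->isA<GroupedBlockScalingFactorLayoutOp>());
}

} // namespace nvfuser
```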

csrc/device_lower/utils.cpp

Lines changed: 1 addition & 0 deletions

```diff
@@ -125,6 +125,7 @@ bool isTvOp(const Expr* expr) {
           SliceOp,
           CatOp,
           ScanOp,
+          GroupedBlockScalingFactorLayoutOp,
           kir::AllocTMem,
           kir::GridReduction,
           kir::GroupedGridReduction,
```

csrc/dispatch.h

Lines changed: 54 additions & 53 deletions

GitHub re-emits the entire DISPATCH_FOR_ALL_EXPRS list in this hunk (most likely because the trailing-backslash alignment shifted), which is why the count is 54/53; the only functional change is the single new entry:

```diff
@@ -68,59 +68,60 @@ class Val;
 #define DISPATCH_FOR_ALL_KIR_VALS(f) f(Predicate) f(TensorIndex)
 #define DISPATCH_FOR_ALL_HIR_VALS(f) f(Stream)
 
 #define DISPATCH_FOR_ALL_EXPRS(f) \
   f(FullOp); \
   f(IotaOp); \
   f(EyeOp); \
   f(UnaryOp); \
   f(BinaryOp); \
   f(TernaryOp); \
   f(ArrayConstruct); \
   f(StructConstruct); \
   f(GetAttr); \
   f(GetItem); \
   f(ReverseArray); \
   f(GetMetaData); \
   f(TensorConstruct); \
   f(SelectOp); \
   f(IndexSelectOp); \
   f(IndexPutAccumulateOp); \
   f(GatherOp); \
   f(ScatterOp); \
   f(RNGOp); \
   f(ReductionOp); \
   f(GroupedReductionOp); \
   f(WelfordOp); \
   f(GroupedWelfordOp); \
   f(LoadStoreOp); \
   f(MmaOp); \
   f(BroadcastOp); \
   f(SqueezeOp); \
   f(ExpandOp); \
   f(RepeatOp); \
   f(ViewAsScalar); \
   f(ReshapeOp); \
   f(CatOp); \
   f(PadOp); \
   f(SliceOp); \
   f(Split); \
   f(ArgsortOp); \
   f(GroupedMmaOp); \
   f(ScaledMmaOp); \
   f(CutlassNvfp4GroupedMmaOp); \
+  f(GroupedBlockScalingFactorLayoutOp); \
   f(TopKOp); \
   f(ScanOp); \
   f(Merge); \
   f(Swizzle); \
   f(Swizzle2D); \
   f(Resize); \
   f(MatmulOp); \
   f(LinearOp); \
   f(SdpaFwdOp); \
   f(SdpaBwdOp); \
   f(EmbeddingFwdOp); \
   f(Communication); \
   f(ForLoop); \
   f(P2PCommunication);
 #define DISPATCH_FOR_ALL_KIR_EXPRS(f) \
   f(Allocate); \
```
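For background on why one added line is sufficient: DISPATCH_FOR_ALL_EXPRS is an X-macro, so every site that expands the list picks up the new node automatically. Below is a small, self-contained illustration of the pattern; the visitor class is illustrative only and is not nvfuser's actual dispatch machinery.

```cpp
// Illustrative X-macro sketch; nvfuser's real dispatch classes differ.
class FullOp;
class GroupedBlockScalingFactorLayoutOp;

#define DISPATCH_FOR_ALL_EXPRS(f) \
  f(FullOp);                      \
  f(GroupedBlockScalingFactorLayoutOp);

class ExprVisitor {
 public:
  virtual ~ExprVisitor() = default;
  // Each entry in the list becomes one handle() overload, so adding a node
  // type to DISPATCH_FOR_ALL_EXPRS makes it visitable everywhere the macro
  // is expanded.
#define DECLARE_HANDLER(T) \
  virtual void handle(const T*) {}
  DISPATCH_FOR_ALL_EXPRS(DECLARE_HANDLER)
#undef DECLARE_HANDLER
};
```

With this pattern, the overload set is regenerated from the single list, so individual visitors cannot silently miss the new op.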

csrc/ir/internal_base_nodes.h

Lines changed: 2 additions & 1 deletion

```diff
@@ -458,7 +458,8 @@ class NVF_API TensorDomain : public Val {
       std::vector<IterDomain*> loop_domain,
       std::optional<std::vector<IterDomain*>> alternate_loop_domain,
       std::vector<std::optional<bool>> contiguity = {},
-      std::vector<IterDomain*> additional_ids = {});
+      std::vector<IterDomain*> additional_ids = {},
+      bool skip_checks = false);
 
   TensorDomain(IrBuilderPasskey, const TensorDomain* src);
```

csrc/ir/internal_nodes.h

Lines changed: 57 additions & 0 deletions

```diff
@@ -3442,4 +3442,61 @@ class CutlassNvfp4GroupedMmaOp : public Expr {
   }
 };
 
+class GroupedBlockScalingFactorLayoutOp : public Expr {
+ public:
+  using Expr::Expr;
+
+  GroupedBlockScalingFactorLayoutOp(
+      IrBuilderPasskey,
+      Val* output,
+      Val* input,
+      Val* expert_offsets,
+      Val* sf_offsets,
+      BlockScalingFactorLayout layout,
+      Val* k,
+      Val* g);
+
+  NVFUSER_DECLARE_CLONE_AND_CREATE
+
+  const char* getOpString() const override {
+    return "GroupedBlockScalingFactorLayoutOp";
+  }
+
+  std::string toString(int indent_size = 0) const override;
+  std::string toInlineString(int indent_size = 0) const override;
+  std::vector<PolymorphicValue> evaluate(
+      const ExpressionEvaluator& ee,
+      const std::vector<PolymorphicValue>& inputs) const override;
+
+  // Get output block scaling factor
+  Val* out() const {
+    return output(0);
+  }
+
+  // Get input block scaling factor
+  Val* in() const {
+    return input(0);
+  }
+
+  TensorView* expertOffsets() const {
+    return input(1)->as<TensorView>();
+  }
+
+  TensorView* scalingFactorOffsets() const {
+    return input(2)->as<TensorView>();
+  }
+
+  Val* k() const {
+    return input(3);
+  }
+
+  Val* g() const {
+    return input(4);
+  }
+
+  BlockScalingFactorLayout layout() const {
+    return attribute<BlockScalingFactorLayout>(0);
+  }
+};
+
 } // namespace nvfuser
```
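The accessors above pin down the input order the constructor must use. A short sketch of creating the node and reading it back, assuming it is built through IrBuilder::create like other Exprs (variable names and includes are placeholders, not code from this commit):

```cpp
// Sketch only; mirrors the accessor/input ordering declared above.
#include <ir/all_nodes.h>
#include <ir/builder.h>

namespace nvfuser {

Expr* defineGroupedSfLayout(
    TensorView* out,
    TensorView* sf,
    TensorView* expert_offsets,
    TensorView* sf_offsets,
    Val* k,
    Val* g) {
  auto* op = IrBuilder::create<GroupedBlockScalingFactorLayoutOp>(
      out, sf, expert_offsets, sf_offsets,
      BlockScalingFactorLayout::Block128x4, k, g);
  // Inputs land in the order set by the constructor:
  //   input(0) -> in(), input(1) -> expertOffsets(),
  //   input(2) -> scalingFactorOffsets(), input(3) -> k(), input(4) -> g(),
  // and the layout is stored as attribute(0) -> layout().
  NVF_ERROR(op->in() == sf);
  NVF_ERROR(op->expertOffsets() == expert_offsets);
  NVF_ERROR(op->layout() == BlockScalingFactorLayout::Block128x4);
  return op;
}

} // namespace nvfuser
```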

csrc/ir/nodes.cpp

Lines changed: 73 additions & 13 deletions

```diff
@@ -3349,7 +3349,8 @@ TensorDomain::TensorDomain(
       std::vector<IterDomain*> loop_domain,
       std::optional<std::vector<IterDomain*>> alternate_loop_domain,
       std::vector<std::optional<bool>> contiguity,
-      std::vector<IterDomain*> additional_ids)
+      std::vector<IterDomain*> additional_ids,
+      bool skip_checks)
     : Val(passkey, ValType::TensorDomain, DataType::Null),
       root_domain_(std::move(root_domain)),
       logical_domain_(std::move(logical_domain)),
@@ -3366,18 +3367,21 @@ TensorDomain::TensorDomain(
   NVF_CHECK(
       loop_domain_.empty() == logical_domain_.empty(),
       "logical domain and loop domain can only be both empty or neither empty");
-  validateLoopDomain(logical_domain_, loop_domain_, additional_ids_);
-  if (!root_domain_.empty()) {
-    ir_utils::validateDomainEquivalence(
-        logical_domain_, root_domain_, additional_ids_);
-  }
-  if (!allocation_domain_.empty()) {
-    ir_utils::validateDomainEquivalence(
-        logical_domain_, allocation_domain_, additional_ids_);
-  }
-  if (alternate_loop_domain_.has_value()) {
-    validateLoopDomain(
-        logical_domain_, alternate_loop_domain_.value(), additional_ids_);
+
+  if (!skip_checks) {
+    validateLoopDomain(logical_domain_, loop_domain_, additional_ids_);
+    if (!root_domain_.empty()) {
+      ir_utils::validateDomainEquivalence(
+          logical_domain_, root_domain_, additional_ids_);
+    }
+    if (!allocation_domain_.empty()) {
+      ir_utils::validateDomainEquivalence(
+          logical_domain_, allocation_domain_, additional_ids_);
+    }
+    if (alternate_loop_domain_.has_value()) {
+      validateLoopDomain(
+          logical_domain_, alternate_loop_domain_.value(), additional_ids_);
+    }
   }
 
   // resetDomains initializes other member variables, required by clang-tidy
@@ -6551,4 +6555,60 @@ std::vector<PolymorphicValue> CutlassNvfp4GroupedMmaOp::evaluate(
 
 NVFUSER_DEFINE_CLONE_AND_CREATE(CutlassNvfp4GroupedMmaOp)
 
+GroupedBlockScalingFactorLayoutOp::GroupedBlockScalingFactorLayoutOp(
+    IrBuilderPasskey passkey,
+    Val* output,
+    Val* input,
+    Val* expert_offsets,
+    Val* sf_offsets,
+    BlockScalingFactorLayout layout,
+    Val* k,
+    Val* g)
+    : Expr(passkey) {
+  addInput(input);
+  addInput(expert_offsets);
+  addInput(sf_offsets);
+  addInput(k);
+  addInput(g);
+  addOutput(output);
+  addDataAttribute(layout);
+}
+
+std::string GroupedBlockScalingFactorLayoutOp::toString(int indent_size) const {
+  std::stringstream ss;
+  indent(ss, indent_size) << output(0)->toString() << "\n";
+  indent_size++;
+  indent(ss, indent_size) << " = grouped_block_scaling_factor_layout(\n";
+  indent_size++;
+  indent(ss, indent_size) << "input = " << in()->toString() << ",\n";
+  indent(ss, indent_size) << "expert_offsets = " << expertOffsets()->toString()
+                          << ",\n";
+  indent(ss, indent_size) << "sf_offsets = "
+                          << scalingFactorOffsets()->toString() << ",\n";
+  indent(ss, indent_size) << "layout = "
+                          << (layout() == BlockScalingFactorLayout::Block128x4
+                                  ? "Block128x4"
+                                  : "Unknown")
+                          << "\n";
+  indent_size--;
+  indent(ss, indent_size) << ")\n";
+  return ss.str();
+}
+
+std::string GroupedBlockScalingFactorLayoutOp::toInlineString(
+    int indent_size) const {
+  NVF_CHECK(
+      false, "GroupedBlockScalingFactorLayoutOp can not be printed inline");
+}
+
+std::vector<PolymorphicValue> GroupedBlockScalingFactorLayoutOp::evaluate(
+    const ExpressionEvaluator& ee,
+    const std::vector<PolymorphicValue>& inputs) const {
+  // This is a placeholder implementation - the actual implementation
+  // would depend on the specific block scaling factor layout operation
+  NVF_THROW("GroupedBlockScalingFactorLayoutOp evaluation not yet implemented");
+}
+
+NVFUSER_DEFINE_CLONE_AND_CREATE(GroupedBlockScalingFactorLayoutOp)
+
 } // namespace nvfuser
```
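For reference, with the printer above a node renders roughly as follows (tensor names are placeholders and the exact TensorView formatting is abbreviated; the indent width is an assumption):

```
T4
   = grouped_block_scaling_factor_layout(
    input = T0,
    expert_offsets = T1,
    sf_offsets = T2,
    layout = Block128x4
  )
```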

csrc/logical_domain_map.cpp

Lines changed: 11 additions & 0 deletions

```diff
@@ -131,6 +131,17 @@ std::pair<std::unordered_set<IterDomain*>, bool> getNonMappingDomainInfo(
       non_mapping_ids.insert(producer_logical.at(topk_dim));
       has_consumer_id = true;
     }
+  } else if (
+      auto* grouped_block_sf_layout =
+          dynamic_cast<GroupedBlockScalingFactorLayoutOp*>(
+              consumer_tv->definition())) {
+    if (producer_tv != grouped_block_sf_layout->in()) {
+      auto producer_logical =
+          TensorDomain::noReductions(producer_tv->getLogicalDomain());
+      non_mapping_ids.insert(producer_logical.begin(), producer_logical.end());
+      // we are not mapping anything, `has_consumer_id` doesn't matter.
+      has_consumer_id = false;
+    }
   }
 
   return std::make_pair(non_mapping_ids, has_consumer_id);
```
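The effect of this branch is that the offsets tensors do not map into the layout op's output at all. A small sketch of how that behavior can be observed, assuming it is queried through PairwiseLogicalDomainMap as elsewhere in the codebase (not code from this commit):

```cpp
// Sketch only; illustrates the intended non-mapping behavior.
#include <logical_domain_map.h>

namespace nvfuser {

bool offsetsAreUnmapped(TensorView* expert_offsets, TensorView* layout_out) {
  // With the new branch, every logical IterDomain of the offsets tensor is
  // reported as non-mapping, so the producer->consumer map comes back empty.
  auto p2c = PairwiseLogicalDomainMap(expert_offsets, layout_out)
                 .mapProducerToConsumer();
  return p2c.empty();
}

} // namespace nvfuser
```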
