Add fmin_fmax_promotion presegmentation pass #5337
Conversation
  return attribute<BinaryOpType>(1);
}

void markUnsafe() {
TODO: Jacob recommended we get rid of this function, and instead replace the entire Expr with a new one.
// Full-size statuses
DEFAULT,
BAD_BROADCAST,
Still trying to understand the analysis, but I'm wondering why we need a separate status for reduction and broadcast. Isn't just having GOOD and BAD enough?
It is still unclear to me why there is both DEFAULT and GOOD. I also don't understand why we need a separate state for broadcasted BAD.
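For reference, the two-state version the reviewers are suggesting might look like this (a hypothetical sketch; the names are illustrative and not the enum in this PR):

// Hypothetical minimal status: a tensor / iter domain either behaves the same
// after the promotion (GOOD) or may have silently dropped a NaN (BAD).
enum class Status { GOOD, BAD };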
for (auto input : expr->inputs()) {
  if (auto* in_tv = dynamic_cast<TensorView*>(input)) {
    for (IterDomain* id : in_tv->getLogicalDomain()) {
      IterDomainStatus status = iterMap[id];
Is iterMap guaranteed to have a mapping for id? If so, let's use at() so that we can make iterMap a const ref.
These have changed names but the question is still valid:
Is the map (NanStatusMap) guaranteed to have a mapping?
No, the mapping may not exist for every node. For example:
TensorView* tv1 = max(in0, {0, 1});
TensorView* tv2 = add(in0, in2);
The add node here has 2 inputs, but only the in0 TensorView will have a mapping during analysis. This is what the None state is for; it's the default state for unmapped TVs.
IterDomainStatus status = iterMap[in_id];
auto out_id = p2c[in_id];
Can you avoid using []? It's not very clear what is intended: are you assuming the key has a mapping, or are you relying on automatic addition of a new mapping?
I am relying on automatic addition of a new mapping to handle unmapped expression inputs. Their states will be "None".
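For context on the [] vs. at() discussion, here is a small standalone sketch of the two behaviors (plain std::unordered_map, not the PR's map types):

#include <cassert>
#include <unordered_map>

enum class Status { NONE, GOOD, BAD }; // NONE == 0 is the value-initialized default

int main() {
  std::unordered_map<int, Status> map;

  // operator[] inserts a value-initialized entry for a missing key; this is the
  // behavior the PR relies on, so unmapped inputs read back as NONE.
  Status s = map[42];
  assert(s == Status::NONE);
  assert(map.size() == 1); // note: the lookup itself grew the map

  // at() never inserts: it throws std::out_of_range for a missing key, so it
  // can be used on a const map when every key is guaranteed to be present.
  // map.at(7); // would throw here
  return 0;
}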
namespace nvfuser::preseg_passes {

// IterDomainStatus are attached to IterDomains and propagated with a
I'm actually not sure iter domains are the right granularity of the analysis. If one iter domain has a bad status, its tensor should be considered bad as well. Also, reductions remove iter domains, so "bad" iter domains would just disappear from the fusion. It seems to me tensors are the right level of this analysis. What do you think?
More specific comments below. I think you should focus on explaining the algorithm and really thinking about what state is needed. I agree with @naoyam that it seems like only "good" and "bad" states are needed. Also, why not have an initialization step where all IDs of fusion inputs are marked GOOD instead of NONE?
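To make the granularity concern concrete, here is a hypothetical fusion (using the same ops that appear elsewhere in this PR) in which a per-ID marker would be reduced away before reaching the output:

// Illustrative only: if the "bad" marker is attached to an iter domain of tv1,
// the second reduction removes that domain, so no ID at the output carries the
// marker, even though tv2's values still depend on the possibly NaN-dropping fmax.
TensorView* tv1 = max(tv0, {1});  // candidate for fmax promotion
TensorView* tv2 = sum(tv1, {0});  // reduces away tv1's remaining iter domain
fusion->addOutput(tv2);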
  expectFMax = true;
}

if (testIndex == 3) {
Tip: in cases like this I typically would create a new class for FMinFMaxPromotionTest instead of an alias. In there I would implement SetUp() and TearDown(), and those would hold everything from this current test other than the if (testIndex == *) parts. That lets you directly give a descriptive name to each test, even without parametrization (i.e. you can use TEST_F instead of TEST_P then, unless you have further parametrizations to do).
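A rough sketch of the fixture shape being suggested (layout only; the base class, members, and test names below are illustrative placeholders, not code from this PR):

#include <gtest/gtest.h>
#include <memory>
// nvfuser types (Fusion, etc.) come from the codebase.

// Shared setup moves into SetUp(); each testIndex branch becomes its own TEST_F.
class FMinFMaxPromotionTest : public NVFuserTest {
 protected:
  void SetUp() override {
    NVFuserTest::SetUp();
    fusion_ = std::make_unique<Fusion>();
    // ... build the inputs and ops shared by every case ...
  }

  std::unique_ptr<Fusion> fusion_;
};

TEST_F(FMinFMaxPromotionTest, SomeDescriptiveNameForCase1) {
  // body of what is currently the testIndex == 1 branch
}

TEST_F(FMinFMaxPromotionTest, SomeDescriptiveNameForCase3) {
  // body of what is currently the testIndex == 3 branch
}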
// Once we identify a target reduction, we perform a downward pass starting from
// the target's direct input. The pass propagates IterDomainStatus information.
// At the end, we check all output TVs for bad statuses. If at any point we
// encounter a node we don't know how to propagate information through, we treat
// it like a graph output and fail if it has any incoming bad statuses.
In this comment, it would be very instructive to add a couple complete examples where you show a fusion and trace down through the fusion showing how the ID statuses propagate from a given max/min reduction.
bool expectFMax = false;

if (testIndex == 1) {
  TensorView* tv3 = add(max(tv0, {0, 1}), tv0);
It is often clearer to put one operation per line. This lets you mark the axes for each tensor in the fusion. For example, in this case you would have something like
TensorView* tv3 = max(tv0, {0, 1}); // [ rS5{i0}, rS6{i1} ]
// Note: the implicit broadcast tv4 here is not shown in your current code
TensorView* tv4 = broadcast(tv3, {true, true}); // [ bS7{1}, bS8{1} ]
TensorView* tv5 = add(tv4, tv0); // [ iS9{i0}, iS10{i1} ]
TensorView* tv6 = sum(tv5, {0, 1}); // [ rS11{i0}, rS12{i1} ]
// NOTE: tv7 below is not shown currently either
TensorView* tv7 = broadcast(tv6, {true, true}); // [ bS13{i0}, bS14{i1} ]
TensorView* tv8 = add(tv5, tv7); // [ iS15{i0}, iS16{i1} ]
fusion->addOutput(tv8);
Also a short comment can help indicate what it is we're testing in each case.
Pushed a new algorithm which should handle a lot of the issues with reductions / broadcasting not being supported. The new algorithm focuses on a single source IterDomain at a time, and then propagates information along TensorViews. This solves the issues that arise when tracking IterDomains through reductions and broadcasts. The current code is messy and needs to be cleaned up. There is one unsolved issue, which is handling sibling rewrites: if we promote one fmax somewhere in the fusion, right now it can break other fmax ops. I thought this could easily be solved by doing rewrites in reverse-topological order; however, that does not handle sibling expressions. This is exercised by test case #8, which is currently the only failing test case.
if (valMap[expr->input(0)->as<TensorView>()] == ValStatus::DEFAULT ||
    valMap[expr->input(0)->as<TensorView>()] == ValStatus::BAD_DEFAULT) {
Nit:
- if (valMap[expr->input(0)->as<TensorView>()] == ValStatus::DEFAULT ||
-     valMap[expr->input(0)->as<TensorView>()] == ValStatus::BAD_DEFAULT) {
+ auto it = valMap.find(expr->input(0)->as<TensorView>());
+ if (it == valMap.end() || it->second == ValStatus::DEFAULT ||
+     it->second == ValStatus::BAD_DEFAULT) {
If we expect input to always be found in valMap, then I'd do this instead:
ValStatus in_status = valMap.at(expr->input(0)->as<TensorView>());
if (in_status == ValStatus::DEFAULT || in_status == ValStatus::BAD_DEFAULT) {
We don't expect there to always be a value in the map; we rely on the default value being "None".
This seems to be a very common comment on this PR; I guess we usually do not rely on the default value with unordered_map. Let me know if you want me to explicitly check whether a mapping exists (e.g. with .contains()), though it's a lot more verbose to do so.
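For comparison, an explicit-lookup version of the quoted condition would look roughly like this (a sketch only; it uses find() so the "unmapped means None" case stays visible, and it evaluates the condition the same way without inserting a new entry):

auto it = valMap.find(expr->input(0)->as<TensorView>());
if (it != valMap.end() &&
    (it->second == ValStatus::DEFAULT || it->second == ValStatus::BAD_DEFAULT)) {
  // ... (an unmapped input is the default "None" state and falls through)
}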
- Function names start with lowercase letters
- Use snake_case instead of camelCase
- Add anonymous namespace to file-scoped things
Rewrite and rebase of #5121. Adds a new presegmentation pass "fmin_fmax_promotion" which replaces min/max reductions with fmin/fmax reductions where possible. Original motivation in #319.
The new pass does dataflow analysis by attaching an enum to IterDomains. It flows these statuses downward and checks whether any corrupted "BAD" states end up in the output. It can currently handle only 4 operator types:
For any other operator type, or if at any point we fail to map an IterDomain through an operator, we treat the operator as a fusion output.
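As a rough illustration of that flow, a simplified, tensor-level version of the propagation might look like the sketch below. This is not the PR's implementation: the Status enum, the isHandledOpType() helper, and the single BAD state are stand-ins for the statuses and per-op rules discussed in the review, and only the control flow is meant to match the description above.

#include <algorithm>
#include <unordered_map>
// nvfuser IR types (Fusion, Expr, Val, TensorView) come from the codebase.

enum class Status { NONE, GOOD, BAD };

bool isHandledOpType(const Expr* expr); // stand-in for the supported op types

// Seed a BAD status at the target reduction's direct input, propagate the worst
// input status to each expression's outputs, and reject the promotion if BAD can
// reach a fusion output or an op we do not know how to propagate through.
bool promotionLooksSafe(Fusion* fusion, TensorView* seed_tv) {
  std::unordered_map<TensorView*, Status> status;
  status[seed_tv] = Status::BAD; // values here may change if max -> fmax drops a NaN

  for (Expr* expr : fusion->exprs()) { // assumed to be in topological order
    Status worst = Status::NONE;
    for (Val* in : expr->inputs()) {
      if (auto* tv = dynamic_cast<TensorView*>(in)) {
        worst = std::max(worst, status[tv]); // unmapped inputs default to NONE
      }
    }
    if (!isHandledOpType(expr)) {
      // Unknown op: treat it like a fusion output.
      if (worst == Status::BAD) {
        return false;
      }
      continue;
    }
    // The real pass propagates per op type (and per iter domain); this sketch
    // just forwards the worst input status to every output.
    for (Val* out : expr->outputs()) {
      if (auto* tv = dynamic_cast<TensorView*>(out)) {
        status[tv] = worst;
      }
    }
  }

  for (Val* out : fusion->outputs()) {
    auto* tv = dynamic_cast<TensorView*>(out);
    if (tv != nullptr && status[tv] == Status::BAD) {
      return false;
    }
  }
  return true;
}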