Lowering scatter #4742
Conversation
!test
```cpp
return !val->isA<TensorView>() ||
    !isSharedMemory(val->as<TensorView>()) ||
    ir_utils::isCpAsyncBulkLoad(val->definition());
}) &&
```
This is a temporary workaround (WAR) for #4741. Not sure if this is appropriate. Needs help from @zasdfgbnm
Just renaming some variables.
@jjsjann123 Ready for review again. Please let me know if anything is ambiguous or unclear. There are still many things to iron out, so it's highly likely I missed explaining something important.
!test --diff
Still in review.
Only the tests remain to be reviewed.
LGTM overall.
I think my feedback is mostly my own questions and maybe some minor nitpicks on logic. I'm happy to merge it as-is and follow up with smaller PRs to patch up the nitpicks (since this one is pretty big already).
!test
LGTM. Sorry about the delay, and thanks a lot for bearing with my naive questions. 🙇
Fixes #4929. This is a follow-up bug fix (#4742 (comment)). There are actually two bugs. One is index overriding, which caused the overriding replacement to fail:

Wrong: `T4[((nvfuser_index_t)threadIdx.x)] = T3[0];`
Correct: `T4[__to_index(T0[((nvfuser_index_t)threadIdx.x)])] = T3[0];`

Another bug is missing RAW syncs. The sync analysis needed to be extended to consider indirect indexing.
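To see why indirect indexing matters to the sync analysis, here is a minimal sketch (hypothetical, simplified pseudo-generated code; `__to_index`, `T0`, `T3`, and `T4` follow the naming above):

```cpp
// Thread i writes through the index tensor T0, so the target element is not
// known statically. A later read of T4 by any thread may depend on a write
// performed by a different thread, which requires a block-level RAW sync.
T4[__to_index(T0[((nvfuser_index_t)threadIdx.x)])] = T3[0];  // indirect write
__syncthreads();  // RAW sync covering the indirect writes above
float v = T4[((nvfuser_index_t)threadIdx.x)];  // subsequent read
```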
This PR introduces preliminary support for lowering the scatter operation rather than falling back to ATen. The primary motivation is to generate a single fused kernel for fusions like SgLangMoeTest.ComputeArgSort.
It is not yet piped through FusionExecutorCache, so nothing should be impacted as long as nvFuser is used through FusionExecutorCache.
Scatter is inherently in-place, which doesn't mix well with the overall semantics of the Fusion IR. From the user's perspective, scatter is provided as an out-of-place operation, as below:
https://github.com/NVIDIA/Fuser/pull/4742/files#diff-a50219bc583905a766ab511e0af91ba8af96a821a93bb19f20d4b550c18a9f5cR49
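A minimal sketch of what such a fusion can look like (hypothetical code following common nvFuser test patterns; see the linked test for the actual definition):

```cpp
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeContigTensor(1, DataType::Int);  // index tensor
auto tv1 = makeContigTensor(1);                 // src tensor
auto tv2 = makeContigTensor(1);                 // tensor scattered into
fusion.addInput(tv0);
fusion.addInput(tv1);
fusion.addInput(tv2);

// Out-of-place at the IR level: tv4 is a new tensor and tv2 is unmodified.
auto tv4 = scatter(tv2, /*dim=*/0, tv0, tv1);
fusion.addOutput(tv4);
```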
Here, `tv2` and `tv4` are different tensors in the fusion. The user is free to use `tv2` and `tv4` separately in the fusion. However, when generating a CUDA kernel, we want to implement the operation in-place, so at the time of lowering, it is validated that the scatter input is only used by the scatter operation itself. This restriction should be enforced by the fusion segmenter.
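For instance, a fusion like the following hypothetical one would fail the validation, because the scatter input has a second consumer and thus cannot be safely overwritten:

```cpp
auto tv4 = scatter(tv2, /*dim=*/0, tv0, tv1);
auto tv5 = add(tv2, tv2);  // second use of the scatter input tv2
fusion.addOutput(tv4);
fusion.addOutput(tv5);     // in-place lowering of the scatter is now unsafe
```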
Before lowering, the loop domain of `tv4` is meant to be updated to use the logical domain of the index tensor. This is currently done manually as shown here. Edit: I decided to do this from the beginning as part of the `TensorDomain` constructor.

At the time of lowering, once the validation passes, a new lowering pass, `setInplaceAlias`, modifies the allocation nodes of the scatter input and output such that the output becomes an alias of the input (except when the output is also a fusion output, in which case the input becomes an alias of the output). I initially considered extending the existing memory reuse pass but decided to add a new, separate pass for simplicity.

Once the aliasing is done, the rest is just a matter of minor adjustments here and there.
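Conceptually, the aliasing turns the scatter into a single in-place write. A hypothetical, simplified view of the resulting kernel code (not actual nvFuser codegen output):

```cpp
// T4 shares T2's allocation after setInplaceAlias: no new buffer is
// allocated, and the scatter writes directly into T2's storage.
float* T4 = T2;  // alias, not a copy
T4[__to_index(T0[((nvfuser_index_t)threadIdx.x)])]
    = T3[((nvfuser_index_t)threadIdx.x)];
```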
With this PR, `ComputeArgSort` can be manually scheduled as shown here. Similarly, `ComputeProblemSizes` can also be lowered when the index size is 1, since in that case there's no accumulation. That should correspond to the decode pass.

Note that this PR does not support scatter with multi-dimensional tensors. This is because in PyTorch scatter, non-indexed dimensions are not guaranteed to have the same extents between all the tensors, so there's no ID mapping, meaning there's no indexing path. I think we should represent this as a separate resize op, but that is not done yet.
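To illustrate the extent mismatch, here is a small ATen (libtorch) example with hypothetical shapes; PyTorch scatter only requires `index.size(d) <= self.size(d)` for every dimension `d != dim`:

```cpp
#include <torch/torch.h>

// self is (4, 5) while index and src are (2, 2). Scattering along dim 0 is
// legal even though the non-indexed dimension differs in extent (2 vs. 5),
// so there is no one-to-one mapping between the tensors' dimensions.
auto self = torch::zeros({4, 5});
auto index = torch::tensor({{0, 1}, {2, 3}});
auto src = torch::ones({2, 2});
auto out = self.scatter(0, index, src);
```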
#4764 is a follow-up PR to extend this for the accumulation case.
More thorough testing as well as actual automatic scheduling support should be done in future PRs.