
Conversation

@naoyam (Collaborator) commented Jul 7, 2025

This PR introduces preliminary support for lowering the scatter operation instead of falling back to ATen. The primary motivation is to generate a single fused kernel for fusions like SgLangMoeTest.ComputeArgSort.

It is not yet piped through FusionExecutorCache, so nothing should be impacted as long as nvFuser is used through FusionExecutorCache.

Scatter is inherently in-place, which doesn't mix well with the overall semantics of the Fusion IR. Here, from the user's perspective, scatter is provided as an out-of-place operation, as shown below:

auto tv4 = scatter(tv2, 0, tv1, tv3);

https://github.com/NVIDIA/Fuser/pull/4742/files#diff-a50219bc583905a766ab511e0af91ba8af96a821a93bb19f20d4b550c18a9f5cR49

Here, tv2 and tv4 are different tensors in the fusion, and the user is free to use them separately. However, when generating a CUDA kernel, we want to implement the operation in-place, so at lowering time it is validated that the scatter input is used only by the scatter operation itself. This restriction should be enforced by the fusion segmenter.
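
For illustration, here is a minimal sketch of how such a fusion might be put together with the C++ API. The tensor names, ranks, and dtypes are assumptions for this example, not taken from the PR's tests; only the scatter call mirrors the line above:

// Minimal sketch; shapes, dtypes, and variable names are illustrative only.
Fusion fusion;
FusionGuard fg(&fusion);

TensorView* tv1 = makeSymbolicTensor(1, DataType::Int); // index tensor
TensorView* tv2 = makeSymbolicTensor(1, DataType::Int); // scatter "self"
TensorView* tv3 = makeSymbolicTensor(1, DataType::Int); // values to scatter
fusion.addInput(tv1);
fusion.addInput(tv2);
fusion.addInput(tv3);

// Out-of-place in the Fusion IR: tv4 is a tensor distinct from tv2,
// even though lowering will eventually write into the same buffer.
TensorView* tv4 = scatter(tv2, /*dim=*/0, tv1, tv3);
fusion.addOutput(tv4);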

Before lowering, the loop domain of tv4 is meant to be updated to use the logical domain of the index tensor. This is currently manually done as shown here.
Edit: I decided to do this from the beginning as part of the TensorDomain constructor.

At the time of lowering, once the validation passes, a new lowering pass, setInplaceAlias, modifies the allocation nodes of the scatter input and output such that the output becomes an alias of the input (except when the output is also a fusion output, in which case the input becomes an alias of the output). I initially considered extending the existing memory reuse pass but decided to add a new separate pass for simplicity.
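
Conceptually, the aliasing decision looks something like the sketch below. This is not the actual pass code; allocationOf() and setAlias() are hypothetical stand-ins for however the pass locates and rewrites the kir::Allocate nodes:

// Conceptual sketch only; allocationOf() and setAlias() are placeholders,
// not the real nvFuser API used by the pass.
kir::Allocate* allocationOf(TensorView* tv); // assumed lookup helper

void aliasScatterBuffers(TensorView* scatter_in, TensorView* scatter_out) {
  kir::Allocate* in_alloc = allocationOf(scatter_in);
  kir::Allocate* out_alloc = allocationOf(scatter_out);
  if (scatter_out->isFusionOutput()) {
    // A fusion output's buffer is provided by the caller, so the
    // intermediate scatter input reuses the output's allocation.
    in_alloc->setAlias(out_alloc); // placeholder setter
  } else {
    // Normal case: the scatter output reuses the input's buffer.
    out_alloc->setAlias(in_alloc); // placeholder setter
  }
}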

Once the aliasing is done, the rest is just a matter of minor adjustments here and there.

With this PR, the ComputeArgSort can be manually scheduled as shown here. Similarly, the ComputeProblemSizes can also be lowered when the index size is 1 since in that case there's no accumulation. That should correspond to the decode pass.

Note that this PR does not support scatter with multi-dimensional tensors. This is because, in PyTorch scatter, non-indexed dimensions are not guaranteed to have the same extents across all the tensors, so there's no ID mapping and therefore no indexing path. I think we should represent this as a separate resize op, but that is not done yet.
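
For example, here is what those mismatched extents look like in plain ATen (the shapes are made up for illustration): index must fit within src everywhere and within self on non-indexed dimensions, yet the dim-1 extents (8, 6, 7) all differ, so there is no exact ID mapping between the three tensors.

#include <ATen/ATen.h>

int main() {
  // Scatter along dim 0 with differing non-indexed extents.
  at::Tensor self  = at::zeros({8, 8});
  at::Tensor index = at::randint(/*low=*/0, /*high=*/8, {4, 6}, at::kLong);
  at::Tensor src   = at::randn({5, 7});
  at::Tensor out   = at::scatter(self, /*dim=*/0, index, src); // shape {8, 8}
  return 0;
}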

#4764 is a follow-up PR to extend this for the accumulation case.

More thorough testing as well as actual automatic scheduling support should be done in future PRs.


github-actions bot commented Jul 7, 2025

Review updated until commit 650bf85

Description

  • Implement scatter operation lowering to enable fused kernels

  • Validate scatter constraints for safe in-place code generation

  • Add inplace aliasing pass for scatter input/output tensors

  • Support shared memory and grid-level scatter operations


Changes walkthrough 📝

Relevant files

Enhancement (17 files)

  • codegen.cpp (+6/-1): Update scatter code generation for new interface
  • lower2device.cpp (+5/-0): Add inplace alias pass to lowering pipeline
  • allocation.cpp (+11/-13): Handle scatter allocation domain and output allocations
  • index.cpp (+10/-9): Update scatter index lowering with ID-based override
  • inplace_alias.cpp (+137/-0): New pass for scatter inplace allocation aliasing
  • validation.cpp (+84/-0): Add scatter operation validation logic
  • id_model.cpp (+24/-0): Add scatter loop domain mapping in ID model
  • predicate_indexing.cpp (+10/-0): Add scatter index tensor predicate domains
  • index_compute.cpp (+10/-12): Update consumer indexing with ID-based override
  • nodes.cpp (+15/-7): Add initial loop domain to TensorDomain
  • kernel.cpp (+6/-1): Exclude aliased and output allocations from parameters
  • logical_domain_map.cpp (+9/-0): Handle scatter in pairwise logical domain mapping
  • indexing.cpp (+19/-4): Implement scatter with custom loop domain
  • executor.cpp (+7/-0): Skip aliased allocations in buffer info
  • python_translate.cpp (+1/-1): Update scatter translation for new interface
  • translation.cpp (+3/-3): Update scatter record with new inputs
  • test_gpu2.cpp (+4/-6): Simplify allocation validation in test

Bug fix (3 files)

  • indexing.cpp (+0/-5): Remove unsupported indirect indexing restriction
  • utils.cpp (+6/-5): Fix exprs variable name in compareDomains
  • tensor_view.cpp (+11/-10): Fix cacheBefore for scatter operations

Additional files (12 files)

  • CMakeLists.txt (+1/-0)
  • index.h (+1/-1)
  • inplace_alias.h (+32/-0)
  • validation.h (+3/-0)
  • index_compute.h (+2/-4)
  • internal_base_nodes.h (+2/-1)
  • internal_nodes.h (+26/-6)
  • kernel.h (+1/-1)
  • indexing.h (+29/-1)
  • test_id_model.cpp (+52/-0)
  • test_moe.cpp (+36/-4)
  • test_scatter.cpp (+217/-77)

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Possible Issue

The logic for creating allocation expressions has been modified to remove the is_output flag, but the comment still suggests that outputs do not need allocation, which may be inconsistent with the new logic where Allocate nodes are created for fusion outputs.

kir::Allocate* createAllocExpr(AllocationInformation& info) {
  // Note that Allocate nodes are created for fusion outputs too

  TensorView* tv_to_alloc = info.buffer;
  const MemoryType memory_type = tv_to_alloc->getMemoryType();
Performance Concern

The override_index map in lowerDstIndex is constructed with IterDomain* keys, but the original code used integer indices. This change may affect how index overrides are applied and could impact performance or correctness in complex indexing scenarios.

const std::unordered_map<IterDomain*, Val*> override_index = {
    {sop->getIndexedID(), lowered_index}};
auto lowered_out = lowerDstIndex(sop->out(), override_index);
Design Issue

The scatter operation creates a loop domain based on the index tensor's logical domain but skips loop validation. This deliberate bypass of validation could lead to subtle bugs if the domains are not properly equivalent, especially in edge cases not covered by current tests.

// Create the output tensor. The validation of the loop domain needs
// to be skipped as it is not guaranteed to be equivalent to the
// logical domain.
TensorView* out_tensor = IrBuilder::create<TensorView>(
    IrBuilder::create<TensorDomain>(
        /*logical_domain=*/out_logical,
        /*loop_domain=*/out_loop,
        /*contiguity=*/
        TensorDomain::getContiguityFilledWith(out_logical, true),
        /*skip_loop_validation=*/true),
    self->getDataType().value());

@naoyam (Collaborator Author) commented Jul 9, 2025

!test

@naoyam (Collaborator Author) commented Jul 9, 2025

!test

return !val->isA<TensorView>() ||
    !isSharedMemory(val->as<TensorView>()) ||
    ir_utils::isCpAsyncBulkLoad(val->definition());
}) &&
@naoyam (Collaborator Author) commented on the hunk above:

This is a temporary WAR for #4741. Not sure if this is appropriate. Needs help from @zasdfgbnm

@naoyam changed the title from "[WIP] Scatter codegen support" to "Lowering scatter" on Jul 9, 2025
@naoyam (Collaborator Author) commented:

Just renaming some variables

@naoyam marked this pull request as ready for review on July 30, 2025 01:48
@naoyam (Collaborator Author) commented Jul 30, 2025

@jjsjann123 Ready for review again. Please let me know if anything is ambiguous or unclear. There are still many things to iron out, so it's highly likely I missed explaining something important.

@naoyam (Collaborator Author) commented Jul 30, 2025

!test --diff

@naoyam (Collaborator Author) commented Jul 30, 2025

!test --diff

@jjsjann123 (Collaborator) left a comment:

Still in review.

@jjsjann123 (Collaborator) left a comment:

Only the tests remain to be reviewed.

@jjsjann123 (Collaborator) left a comment:

LGTM overall.

I think my feedback is mostly my own questions and maybe some minor nitpicks on the logic. I'm happy to merge it as-is and follow up with smaller PRs to patch up the nitpicks (since this one is pretty big already).

@naoyam (Collaborator Author) commented Aug 7, 2025

!test

@naoyam requested a review from jjsjann123 on August 7, 2025 07:09
@naoyam (Collaborator Author) commented Aug 7, 2025

!test

@jjsjann123 (Collaborator) left a comment:

LGTM. Sorry about the delay, and thanks a lot for bearing with my naive questions. 🙇

@naoyam merged commit 15bed95 into main on Aug 8, 2025 (55 checks passed)
@naoyam deleted the scatter branch on August 8, 2025 18:33
@naoyam mentioned this pull request on Aug 11, 2025
naoyam added a commit that referenced this pull request on Aug 12, 2025:
Fixes #4929. 

This is a follow-up bug fix
(#4742 (comment)).

There are actually two bugs. One is in index overriding, where the override replacement failed to apply:

Wrong: `T4[((nvfuser_index_t)threadIdx.x)] = T3[0];`
Correct: `T4[__to_index(T0[((nvfuser_index_t)threadIdx.x)])] = T3[0];`

Another bug is missing RAW syncs. The sync analysis needed to be
extended to consider indirect indexing.