-
Notifications
You must be signed in to change notification settings - Fork 66
add layout op runtime function #5115
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: jj/layout_op_PR0_ir_node
Are you sure you want to change the base?
add layout op runtime function #5115
Conversation
Review updated until commit c340720 Description
Changes walkthrough 📝
PR Reviewer Guide 🔍Here are some key observations to aid the review process:
|
2c992ab
to
1d72d32
Compare
3ccfbde
to
1d72d32
Compare
ebd03f4
to
5ec9d72
Compare
!test |
1 similar comment
!test |
298ea2f
to
a86508c
Compare
!test |
a86508c
to
7c327f6
Compare
c4e65d3
to
4deb4a9
Compare
f1709fb
to
f5b464f
Compare
!test |
Runtime function signature. template < typename T, typename Index_T, int BLOCK_ROW_OUTER, int BLOCK_ROW_INNER, int BLOCK_COL, int UNROLL_FACTOR> __device__ void groupedBlockLayout( T* output, const T* input, const nvfuser_index_t row_idx, const nvfuser_index_t col_idx, const Index_T* expert_offsets, const Index_T* output_offsets, const nvfuser_index_t col_size, const nvfuser_index_t group_size) where: BLOCK_ROW_OUTER, BLOCK_ROW_INNER, BLOCK_COL will be translated from BlockScalingFactorLayout, e.g. Block128x4 is translated to 32, 4, 4. This function will be used by codegen for `GroupedBlockScalingFactorLayoutOp` `output` is expected to be the beginning of output buffer, indexing will be done inside the function template with help of `row_idx`, `col_idx`, `expert_offsets`, `output_offsets` and `col_size` Meanwhil, indexing on `input` would have been resolved during device lowering.
a3afee4
to
150d3ee
Compare
f5b464f
to
c340720
Compare
#5118 PR3: enable codegen for layout op
#5115 PR2: add layout op runtime function <- this PR
#5114 PR1: add layout op
Todo for future PRs:
Add vectorization support.