Meta28 optimization upstream #113
base: shbiswas/upstream_0815
Conversation
```diff
  const bool use_deterministic_algorithms = at::globalContext().deterministicAlgorithms();
- const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 1024;
+ const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096;
```
**Reviewer:** We need to guard this change with `#ifdef USE_ROCM`.
**Reply:** Added.
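A minimal sketch of how the guard might look, using the names from the diff (the exact surrounding FBGEMM code is assumed):

```cpp
const bool use_deterministic_algorithms =
    at::globalContext().deterministicAlgorithms();
#ifdef USE_ROCM
// Larger per-CTA segment cap, tuned for ROCm.
const int max_segment_length_per_cta =
    use_deterministic_algorithms ? INT_MAX : 4096;
#else
// Original value, kept for CUDA builds.
const int max_segment_length_per_cta =
    use_deterministic_algorithms ? INT_MAX : 1024;
#endif
```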
```diff
  // Compute shared memory size for cta_per_row
  constexpr auto kCacheAccBytes = sizeof(at::acc_type<cache_t, true>);
- int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize;
+ int32_t num_cta_per_row_groups = (kMaxThreads / 4) / kWarpSize;
```
**Reviewer:** This change also needs an `#ifdef USE_ROCM` guard.
**Reply:** Added.
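The same guard pattern applied to the work-group count, as a sketch:

```cpp
#ifdef USE_ROCM
// Quarter-size CTA groups on ROCm.
int32_t num_cta_per_row_groups = (kMaxThreads / 4) / kWarpSize;
#else
int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize;
#endif
```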
```diff
      || dev_weights.scalar_type() == at::ScalarType::Float;

- if (use_hip_kernel && supported_weights_type && !mixed_D && rocm::is_supported_cdna())
+ if (use_hip_kernel && !mixed_D && supported_weights_type && rocm::is_supported_cdna())
```
**Reviewer:** Does the order matter?
**Reply:** The order doesn't matter: all four operands are plain boolean checks, so reordering only changes where `&&` short-circuits, not the result.
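A small standalone illustration of why the reorder is safe when the operands have no side effects (the variables are stand-ins for the checks in the diff, not the FBGEMM code itself):

```cpp
#include <cassert>

int main() {
  // Stand-ins for the four checks in the diff; all are plain bools
  // with no side effects, so conjunction order cannot change the result.
  const bool use_hip_kernel = true;
  const bool supported_weights_type = true;
  const bool mixed_D = false;
  const bool is_supported_cdna = true;

  assert((use_hip_kernel && supported_weights_type && !mixed_D && is_supported_cdna) ==
         (use_hip_kernel && !mixed_D && supported_weights_type && is_supported_cdna));
  return 0;
}
```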
```diff
  // (vector is very large).
  std::sort(
-     std::execution::par,
+     // std::execution::par,
```
**Reviewer:** Uncomment this line.
**Reply:** Removed.
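For reference, a minimal sketch of the parallel sort as it appears on the original line (requires `<execution>`; with libstdc++ the parallel policy also needs linking against TBB):

```cpp
#include <algorithm>
#include <cstdint>
#include <execution>
#include <vector>

// Parallel sort for some extra speed, since the vector is very large.
void sort_tags(std::vector<int64_t>& tags) {
  std::sort(std::execution::par, tags.begin(), tags.end());
}
```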
```diff
  const int32_t cta_per_row_grid_size = std::min(
-     div_round_up(total_unique_indices, kMaxThreads),
+     div_round_up(total_unique_indices, (kMaxThreads / 4)),
```
**Reviewer:** Add `#ifdef USE_ROCM` above this line.
**Reply:** Added.
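A self-contained sketch of the guarded divisor selection; `div_round_up` and the value of `kMaxThreads` are assumptions for illustration, and the second argument of the real `std::min` call is not visible in the diff:

```cpp
#include <cstdint>

// Ceiling division (assumed helper; the real div_round_up lives
// elsewhere in the FBGEMM tree).
constexpr int32_t div_round_up(int32_t a, int32_t b) {
  return (a + b - 1) / b;
}

constexpr int32_t kMaxThreads = 1024;  // assumed value for illustration

int32_t cta_per_row_blocks(int32_t total_unique_indices) {
#ifdef USE_ROCM
  // Quarter-size work groups on ROCm, so four times as many blocks.
  return div_round_up(total_unique_indices, kMaxThreads / 4);
#else
  return div_round_up(total_unique_indices, kMaxThreads);
#endif
}
```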
**Author:** @liligwu, the changes are tested. The build succeeds, and a few backward and forward passes were tested with no regression.
**Reviewer:** Overall LGTM.
fbgemm_gpu/cmake/Hip.cmake (outdated)
```diff
  list(APPEND HIP_CXX_FLAGS -mf16c)
  list(APPEND HIP_CXX_FLAGS -mfma)
  list(APPEND HIP_CXX_FLAGS -std=c++20)
+ list(APPEND HIP_CXX_FLAGS -g)
```
**Reviewer:** I suggest reverting this file's changes entirely, since in this context they only increase build time and diff noise.
```diff
  // Now sort the indices by their tags. Use parallel sort for some extra speed
  // (vector is very large).
  std::sort(
      std::execution::par,
```
**Reviewer:** Should be reverted.
Commits:
- warp per row wg change
- …d opt size and adjusted WG size for L=1 on hip split embed kernel
No description provided.